/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.astar.conf;

import edu.duke.cs.osprey.Queue;
import edu.duke.cs.osprey.astar.AStarProgress;
import edu.duke.cs.osprey.astar.conf.ConfAStarFactory;
import edu.duke.cs.osprey.astar.conf.ConfAStarNode;
import edu.duke.cs.osprey.astar.conf.ConfIndex;
import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.astar.conf.linked.LinkedConfAStarFactory;
import edu.duke.cs.osprey.astar.conf.order.AStarOrder;
import edu.duke.cs.osprey.astar.conf.order.DynamicHMeanAStarOrder;
import edu.duke.cs.osprey.astar.conf.order.StaticScoreHMeanAStarOrder;
import edu.duke.cs.osprey.astar.conf.pruning.AStarPruner;
import edu.duke.cs.osprey.astar.conf.scoring.AStarScorer;
import edu.duke.cs.osprey.astar.conf.scoring.MPLPPairwiseHScorer;
import edu.duke.cs.osprey.astar.conf.scoring.PairwiseGScorer;
import edu.duke.cs.osprey.astar.conf.scoring.TraditionalPairwiseHScorer;
import edu.duke.cs.osprey.astar.conf.scoring.mplp.MPLPUpdater;
import edu.duke.cs.osprey.astar.conf.scoring.mplp.NodeUpdater;
import edu.duke.cs.osprey.astar.conf.smastar.ConfSMAStarNode;
import edu.duke.cs.osprey.astar.conf.smastar.ConfSMAStarQueue;
import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.confspace.compiled.ConfSpace;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.lute.LUTEConfEnergyCalculator;
import edu.duke.cs.osprey.lute.LUTEGScorer;
import edu.duke.cs.osprey.lute.LUTEHScorer;
import edu.duke.cs.osprey.parallelism.Parallelism;
import edu.duke.cs.osprey.parallelism.TaskExecutor;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import edu.duke.cs.osprey.tools.MathTools;
import edu.duke.cs.osprey.tools.ObjectPool;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;

/**
 * Best-first conformation search over the residue conformations described by {@link RCs}.
 *
 * <p>Two search strategies are available, chosen at construction time:
 * <ul>
 *   <li>with no node limit, an unbounded queue-based A* ({@code UnboundedImpl}) that expands all
 *       unpruned children of a node in parallel using {@link TaskExecutor};</li>
 *   <li>with a node limit, a simplified memory-bounded search ({@code SimplifiedBoundedImpl},
 *       SMA*) that forgets nodes once the limit is exceeded and therefore requires a static
 *       position order.</li>
 * </ul>
 *
 * <p>Instances are normally created with {@link Builder}.
 */
public class ConfAStarTree
implements ConfSearch {

    /**
     * Tolerance used by the SMA* search when checking whether a leaf's score still equals its
     * g-score.
     */
    private static final double SCORE_EQUIVALENCE = 1.0E-12;

    public final AStarOrder order;
    public final AStarScorer gscorer;
    public final AStarScorer hscorer;
    public final MathTools.Optimizer optimizer;
    public final RCs rcs;
    public final ConfAStarFactory factory;
    public final AStarPruner pruner;
    private final AStarImpl impl;
    /** Scratch index shared by the search loop running on the calling thread. */
    private final ConfIndex confIndex;
    private AStarProgress progress;
    private TaskExecutor tasks;
    /** Per-task scoring state, so parallel child scoring never shares a ConfIndex or scorer. */
    private ObjectPool<ScoreContext> contexts;

    /** Returns a new builder for configuring the MPLP heuristic. */
    public static MPLPBuilder MPLPBuilder() {
        return new MPLPBuilder();
    }

    private ConfAStarTree(AStarOrder order, AStarScorer gscorer, AStarScorer hscorer, MathTools.Optimizer optimizer, RCs rcs, ConfAStarFactory factory, AStarPruner pruner, Long maxNumNodes) {
        this.order = order;
        this.gscorer = gscorer;
        this.hscorer = hscorer;
        this.optimizer = optimizer;
        this.rcs = rcs;
        this.factory = factory;
        this.pruner = pruner;
        // the impls read rcs/order/factory in their constructors, so those must be assigned first
        this.impl = maxNumNodes != null ? new SimplifiedBoundedImpl(maxNumNodes) : new UnboundedImpl();
        this.confIndex = new ConfIndex(this.rcs.getNumPos());
        this.progress = null;
        this.order.setScorers(this.gscorer, this.hscorer);
        this.contexts = new ObjectPool<ScoreContext>(ignored -> {
            ScoreContext context = new ScoreContext();
            context.index = new ConfIndex(rcs.getNumPos());
            context.gscorer = gscorer.make();
            context.hscorer = hscorer.make();
            return context;
        });
        this.setParallelism(null);
    }

    /** Starts reporting search progress. */
    public void initProgress() {
        progress = new AStarProgress(rcs.getNumPos());
    }

    /** Returns the progress reporter, or {@code null} if progress reporting is off. */
    public AStarProgress getProgress() {
        return progress;
    }

    /** Stops reporting search progress. */
    public void stopProgress() {
        progress = null;
    }

    /**
     * Sets the parallelism used to score child nodes in the unbounded search.
     *
     * @param val parallelism to use; {@code null} means a single CPU thread
     */
    public void setParallelism(Parallelism val) {
        if (val == null) {
            val = Parallelism.makeCpu(1);
        }
        // NOTE(review): a previously created TaskExecutor is not cleaned up here; confirm whether
        // TaskExecutor needs explicit shutdown when this is called more than once.
        tasks = val.makeTaskExecutor(1000);
        contexts.allocate(val.getParallelism());
    }

    @Override
    public BigInteger getNumConformations() {
        return rcs.getNumConformations();
    }

    @Override
    public ConfSearch.ScoredConf nextConf() {
        return impl.nextConf();
    }

    /**
     * Collects conformations in search order until one is found whose score is not better than
     * {@code thresholdEnergy}.
     *
     * <p>That first non-better conformation is included as the last element. The search also
     * stops, with a shorter list, if the conformations run out.
     *
     * @param thresholdEnergy score threshold, compared using {@link #optimizer}
     * @return the collected conformations, possibly empty
     */
    @Override
    public List<ConfSearch.ScoredConf> nextConfs(double thresholdEnergy) {
        List<ConfSearch.ScoredConf> confs = new ArrayList<>();
        if (progress != null) {
            progress.setGoalScore(thresholdEnergy);
        }
        ConfSearch.ScoredConf conf;
        while ((conf = nextConf()) != null) {
            confs.add(conf);
            if (!optimizer.isBetter(conf.getScore(), thresholdEnergy)) {
                break;
            }
        }
        return confs;
    }

    /**
     * Checks the pruning matrix to see whether assigning {@code nextRc} at {@code nextPos} is
     * pruned, given the assignments already in {@code confIndex}.
     *
     * <p>Each defined (pos, rc) is checked pairwise against the new assignment. When the matrix
     * holds higher-order tuples, every triple formed by two defined assignments plus the new one
     * is also checked.
     *
     * @return {@code true} if any checked tuple is pruned; {@code false} if there is no pruning
     *         matrix or nothing is pruned
     */
    private boolean isPruned(ConfIndex confIndex, int nextPos, int nextRc) {
        PruningMatrix pmat = rcs.getPruneMat();
        if (pmat == null) {
            return false;
        }

        // pairs: (defined assignment, new assignment)
        for (int i = 0; i < confIndex.numDefined; i++) {
            int pos = confIndex.definedPos[i];
            int rc = confIndex.definedRCs[i];
            assert (pos != nextPos || rc != nextRc);
            if (pmat.getPairwise(pos, rc, nextPos, nextRc)) {
                return true;
            }
        }

        // triples: (defined assignment, defined assignment, new assignment)
        if (pmat.hasHigherOrderTuples()) {
            RCTuple tuple = new RCTuple(0, 0, 0, 0, 0, 0);
            for (int i1 = 0; i1 < confIndex.numDefined; i1++) {
                int pos1 = confIndex.definedPos[i1];
                int rc1 = confIndex.definedRCs[i1];
                assert (pos1 != nextPos || rc1 != nextRc);
                for (int i2 = 0; i2 < i1; i2++) {
                    int pos2 = confIndex.definedPos[i2];
                    int rc2 = confIndex.definedRCs[i2];
                    assert (pos2 != nextPos || rc2 != nextRc);
                    tuple.set(pos1, rc1, pos2, rc2, nextPos, nextRc);
                    tuple.sortPositions();
                    if (pmat.getTuple(tuple)) {
                        return true;
                    }
                }
            }
        }

        return false;
    }

    /** Options for the MPLP heuristic installed by {@link Builder#setMPLP(MPLPBuilder)}. */
    public static class MPLPBuilder {
        private MPLPUpdater updater = new NodeUpdater();
        private int numIterations = 1;
        private double convergenceThreshold = 1.0E-4;

        public MPLPBuilder setUpdater(MPLPUpdater val) {
            this.updater = val;
            return this;
        }

        public MPLPBuilder setNumIterations(int val) {
            this.numIterations = val;
            return this;
        }

        public MPLPBuilder setConvergenceThreshold(double val) {
            this.convergenceThreshold = val;
            return this;
        }
    }

    /**
     * Memory-bounded search (SMA*) that keeps at most {@code maxNumNodes} nodes.
     *
     * <p>Each step expands a single child of the queue's lowest/deepest node. When the node budget
     * is exceeded, the highest/shallowest leaf is forgotten. Requires a static position order.
     */
    private class SimplifiedBoundedImpl
    implements AStarImpl {
        private final long maxNumNodes;
        private final ConfSMAStarQueue q;
        private ConfSMAStarNode rootNode = null;
        private long numNodes = 0L;

        SimplifiedBoundedImpl(long maxNumNodes) {
            if (maxNumNodes <= (long)rcs.getNumPos()) {
                throw new IllegalArgumentException(String.format("SMA* needs space for at least %d nodes for this problem (i.e., numPos + 1)", rcs.getNumPos() + 1));
            }
            if (order.isDynamic()) {
                throw new IllegalArgumentException("SMA* can only use static position orders, because of the node forgetting mechanism. If using ConfAStarTree.Builder, call setMaxNumNodes() before setTraditional()/setLUTE() so the builder can choose the correct heuristics.");
            }
            this.maxNumNodes = maxNumNodes;
            this.q = new ConfSMAStarQueue();
        }

        @Override
        public ConfSearch.ScoredConf nextConf() {
            int numPos = rcs.getNumPos();

            if (rootNode == null) {
                // mirror UnboundedImpl: if there are no conformations there is nothing to
                // enumerate, and expanding would index into an empty RC list
                if (!rcs.hasConfs()) {
                    return null;
                }
                rootNode = new ConfSMAStarNode();
                rootNode.index(confIndex);
                rootNode.setGScore(gscorer.calc(confIndex, rcs), optimizer);
                rootNode.setHScore(hscorer.calc(confIndex, rcs), optimizer);
                rootNode.setScore(rootNode.getGScore(optimizer) + rootNode.getHScore(optimizer), optimizer);
                q.add(rootNode);
                // the root counts toward the node budget
                numNodes++;
            }

            while (!q.isEmpty()) {
                ConfSMAStarNode node = q.getLowestDeepest();

                if (node.depth == numPos) {
                    // fully-assigned leaf
                    assert (node.getHScore(optimizer) == 0.0);
                    // only report the conf when the leaf's score still equals its g-score;
                    // otherwise the leaf is retired below without being returned
                    int[] conf = null;
                    if (Math.abs(node.getGScore(optimizer) - node.getScore(optimizer)) < SCORE_EQUIVALENCE) {
                        conf = node.makeConf(numPos);
                    }
                    numNodes -= node.parent.finishChild(node, q);
                    if (conf != null) {
                        return new ConfSearch.ScoredConf(conf, node.getGScore(optimizer));
                    }
                    continue;
                }

                // expand exactly one child of this node
                node.index(confIndex);
                int pos = order.getNextPos(confIndex, rcs);
                int index = node.getNextChildIndex(rcs.getNum(pos));
                int rc = rcs.get(pos)[index];
                ConfSMAStarNode child = node.spawnChild(pos, rc, index);
                child.setGScore(gscorer.calcDifferential(confIndex, rcs, pos, rc), optimizer);
                child.setHScore(hscorer.calcDifferential(confIndex, rcs, pos, rc), optimizer);
                // combine the parent's score with the child's g+h via optimizer.reverse().opt(...);
                // presumably this keeps the child no better than its parent - confirm against
                // MathTools.Optimizer
                child.setScore(optimizer.reverse().opt(node.getScore(optimizer), child.getGScore(optimizer) + child.getHScore(optimizer)), optimizer);
                numNodes++;

                node.backup(q);
                if (!node.canSpawnChildren()) {
                    q.removeOrAssert(node);
                }

                // over budget: forget the highest/shallowest leaf and requeue its parent
                if (numNodes > maxNumNodes) {
                    ConfSMAStarNode highest = q.removeHighestShallowestLeaf();
                    if (highest.parent != null) {
                        highest.parent.forgetChild(highest);
                        q.add(highest.parent);
                    }
                    numNodes--;
                }

                q.add(child);
            }

            return null;
        }
    }

    /** A search strategy that produces conformations one at a time. */
    private interface AStarImpl {
        ConfSearch.ScoredConf nextConf();
    }

    /**
     * Unbounded queue-based A* search. Every unpruned child of an expanded node is scored in
     * parallel on {@link #tasks}, and children with finite scores are pushed onto the queue.
     */
    private class UnboundedImpl
    implements AStarImpl {
        private final Queue<ConfAStarNode> queue;
        private ConfAStarNode rootNode = null;

        UnboundedImpl() {
            this.queue = factory.makeQueue(rcs);
        }

        @Override
        public ConfSearch.ScoredConf nextConf() {
            if (rootNode == null) {
                if (!rcs.hasConfs()) {
                    return null;
                }
                rootNode = factory.makeRootNode(rcs.getNumPos());

                // positions with exactly one RC are assigned up front rather than searched
                ConfAStarNode start = rootNode;
                for (int pos = 0; pos < rcs.getNumPos(); pos++) {
                    if (rcs.getNum(pos) == 1) {
                        start = start.assign(pos, rcs.get(pos)[0]);
                    }
                }
                assert (start.getLevel() == rcs.getNumTrivialPos());

                start.index(confIndex);
                start.setGScore(gscorer.calc(confIndex, rcs), optimizer);
                start.setHScore(hscorer.calc(confIndex, rcs), optimizer);
                queue.push(start);
            }

            while (!queue.isEmpty()) {
                // declared per iteration so the scoring lambda below can capture it
                final ConfAStarNode node = queue.poll();

                if (pruner != null && pruner.isPruned(node)) {
                    continue;
                }

                // fully-assigned node: this is the next conformation
                if (node.getLevel() == rcs.getNumPos()) {
                    if (progress != null) {
                        progress.reportLeafNode(node.getGScore(optimizer), queue.size());
                    }
                    return new ConfSearch.ScoredConf(node.makeConf(rcs.getNumPos()), node.getGScore(optimizer));
                }

                node.index(confIndex);
                final int nextPos = order.getNextPos(confIndex, rcs);
                assert (!confIndex.isDefined(nextPos));
                assert (confIndex.isUndefined(nextPos));

                // NOTE(review): children is filled from TaskExecutor listener callbacks; this
                // assumes the executor does not run listeners concurrently - confirm.
                final List<ConfAStarNode> children = new ArrayList<>();
                for (int nextRc : rcs.get(nextPos)) {
                    if (isPruned(confIndex, nextPos, nextRc)) {
                        continue;
                    }
                    if (pruner != null && pruner.isPruned(node, nextPos, nextRc)) {
                        continue;
                    }
                    tasks.submit(() -> {
                        // each task borrows its own index/scorers so tasks never share state
                        try (ObjectPool.Checkout<ScoreContext> checkout = contexts.autoCheckout()) {
                            ScoreContext context = checkout.get();
                            node.index(context.index);
                            ConfAStarNode child = node.assign(nextPos, nextRc);
                            child.setGScore(context.gscorer.calcDifferential(context.index, rcs, nextPos, nextRc), optimizer);
                            child.setHScore(context.hscorer.calcDifferential(context.index, rcs, nextPos, nextRc), optimizer);
                            return child;
                        }
                    }, child -> {
                        // drop children whose score is infinite or NaN
                        if (Double.isFinite(child.getScore())) {
                            children.add(child);
                        }
                    });
                }
                tasks.waitForFinish();
                queue.pushAll(children);

                if (progress != null) {
                    progress.reportInternalNode(node.getLevel(), node.getGScore(optimizer), node.getHScore(optimizer), queue.size(), children.size());
                }
            }

            return null;
        }
    }

    /** Scratch state checked out by one scoring task at a time. */
    private static class ScoreContext {
        public ConfIndex index;
        public AStarScorer gscorer;
        public AStarScorer hscorer;

        private ScoreContext() {
        }
    }

    /**
     * Configures and builds a {@link ConfAStarTree}. By default, the tree uses the MPLP heuristic
     * with a minimizing optimizer and no node limit.
     *
     * <p>Call {@link #setMaxNumNodes(Long)} before {@link #setTraditional()} /
     * {@link #setTraditionalOpt(MathTools.Optimizer)} / {@link #setLUTE(LUTEConfEnergyCalculator)}.
     * Those methods choose a static position order when a node limit is set, because the bounded
     * (SMA*) search cannot use dynamic orders.
     */
    public static class Builder {
        private EnergyMatrix emat;
        private RCs rcs;
        private AStarOrder order = null;
        private AStarScorer gscorer = null;
        private AStarScorer hscorer = null;
        private MathTools.Optimizer optimizer = MathTools.Optimizer.Minimize;
        private boolean showProgress = false;
        private ConfAStarFactory factory = new LinkedConfAStarFactory();
        private AStarPruner pruner = null;
        private Long maxNumNodes = null;

        public Builder(EnergyMatrix emat, SimpleConfSpace confSpace) {
            this(emat, new RCs(confSpace));
        }

        public Builder(EnergyMatrix emat, ConfSpace confSpace) {
            this(emat, new RCs(confSpace));
        }

        public Builder(EnergyMatrix emat, PruningMatrix pmat) {
            this(emat, new RCs(pmat));
        }

        public Builder(EnergyMatrix emat, RCs rcs) {
            this.emat = emat;
            this.rcs = rcs;
            this.setMPLP();
        }

        /** Uses caller-supplied order and scorers. The optimizer setting is left unchanged. */
        public Builder setCustom(AStarOrder order, AStarScorer gscorer, AStarScorer hscorer) {
            this.order = order;
            this.gscorer = gscorer;
            this.hscorer = hscorer;
            return this;
        }

        /** Traditional pairwise heuristic, minimizing. */
        public Builder setTraditional() {
            return this.setTraditionalOpt(MathTools.Optimizer.Minimize);
        }

        /**
         * Traditional pairwise heuristic with the given optimizer. Uses a dynamic position order,
         * or a static order when a node limit has already been set.
         */
        public Builder setTraditionalOpt(MathTools.Optimizer optimizer) {
            this.order = this.maxNumNodes == null ? new DynamicHMeanAStarOrder(optimizer) : new StaticScoreHMeanAStarOrder();
            this.gscorer = new PairwiseGScorer(this.emat, optimizer);
            this.hscorer = new TraditionalPairwiseHScorer(this.emat, this.rcs, optimizer);
            this.optimizer = optimizer;
            return this;
        }

        /** MPLP heuristic with default options. */
        public Builder setMPLP() {
            return this.setMPLP(new MPLPBuilder());
        }

        /**
         * MPLP heuristic with the given options. The MPLP scorers are built without an optimizer
         * argument, so this also resets the optimizer to {@code Minimize}. That prevents an
         * earlier {@code setTraditionalOpt(Maximize)} call from leaking into this configuration.
         */
        public Builder setMPLP(MPLPBuilder builder) {
            this.order = new StaticScoreHMeanAStarOrder();
            this.gscorer = new PairwiseGScorer(this.emat);
            this.hscorer = new MPLPPairwiseHScorer(builder.updater, this.emat, builder.numIterations, builder.convergenceThreshold);
            this.optimizer = MathTools.Optimizer.Minimize;
            return this;
        }

        /**
         * LUTE scorers. Uses a dynamic position order, or a static order when a node limit has
         * already been set (SMA* rejects dynamic orders). Also resets the optimizer to
         * {@code Minimize}, because the LUTE order and scorers are built without an optimizer
         * argument.
         */
        public Builder setLUTE(LUTEConfEnergyCalculator luteEcalc) {
            this.order = this.maxNumNodes == null ? new DynamicHMeanAStarOrder() : new StaticScoreHMeanAStarOrder();
            this.gscorer = new LUTEGScorer(luteEcalc);
            this.hscorer = new LUTEHScorer(luteEcalc);
            this.optimizer = MathTools.Optimizer.Minimize;
            return this;
        }

        public Builder setShowProgress(boolean val) {
            this.showProgress = val;
            return this;
        }

        public Builder setPruner(AStarPruner val) {
            this.pruner = val;
            return this;
        }

        /**
         * Limits the search to at most {@code val} nodes (SMA*); {@code null} means unbounded.
         * Call this before choosing heuristics so a static position order can be selected.
         */
        public Builder setMaxNumNodes(Long val) {
            this.maxNumNodes = val;
            return this;
        }

        public Builder setMaxNumNodes(int val) {
            return this.setMaxNumNodes(Long.valueOf(val));
        }

        public ConfAStarTree build() {
            ConfAStarTree tree = new ConfAStarTree(this.order, this.gscorer, this.hscorer, this.optimizer, this.rcs, this.factory, this.pruner, this.maxNumNodes);
            if (this.showProgress) {
                tree.initProgress();
            }
            return tree;
        }
    }
}

