/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.astar;

import edu.duke.cs.osprey.astar.AStarNode;
import edu.duke.cs.osprey.astar.AStarTree;
import edu.duke.cs.osprey.astar.FullAStarNode;
import edu.duke.cs.osprey.astar.GMECMutSpace;
import edu.duke.cs.osprey.confspace.HigherTupleFinder;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SearchProblem;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.epic.EPICMatrix;
import edu.duke.cs.osprey.ematrix.epic.EPICSettings;
import edu.duke.cs.osprey.ematrix.epic.NewEPICMatrix;
import edu.duke.cs.osprey.gmec.PrecomputedMatrices;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * A* search tree whose nodes assign one RC (an integer index) to each of {@code numPos} positions.
 *
 * <p>Node scores are computed from an {@link EnergyMatrix}: the energy of the assigned part plus,
 * for every unassigned position, the minimum over its unpruned RCs of a per-RC energy bound
 * (see {@code scoreConfDifferential}). Scoring relies on scratch fields that are filled by
 * {@code splitPositions} and cleared by {@code resetSplitPositions}, so an instance keeps
 * mutable per-call state in its fields.
 *
 * @param <T> node type created by the supplied {@link AStarNode.Factory}
 */
public class ConfTree<T extends AStarNode>
extends AStarTree<T> {
    private static final long serialVersionUID = -2776047349227703884L;
    /** Creates the root node and child nodes of this tree. */
    private AStarNode.Factory<T> nodeFactory;
    /** Number of positions, taken from the pruning matrix in the constructor. */
    protected int numPos;
    /** Energy matrix used for all scoring (the tuple-expansion matrix when that mode is selected). */
    protected EnergyMatrix emat;
    /** For each position, the RCs that the pruning matrix reported as unpruned at construction time. */
    protected int[][] unprunedRCsAtPos;
    /** The constructor throws if this is false; only the traditional score is implemented. */
    protected boolean traditionalScore = true;
    /** Set to true when EPIC is used; copied onto every new node via setScoreNeedsRefinement. */
    protected boolean useRefinement = false;
    /** When true, nextLevelToExpand picks the position by scoreExpansionLevel; cleared when a GMECMutSpace is supplied. */
    protected boolean useDynamicAStar = true;
    /** Old-style EPIC matrix; when EPIC is used exactly one of this and newEPICMat is non-null. */
    protected EPICMatrix epicMat = null;
    /** New-style EPIC matrix; when EPIC is used exactly one of this and epicMat is non-null. */
    protected NewEPICMatrix newEPICMat = null;
    /** Copied from EPICSettings; when true refineScore also refines nodes that are not fully assigned. */
    protected boolean minPartialConfs = false;
    /** Reusable tuple that scoreConfDifferential overwrites with the conformation being scored. */
    protected RCTuple rcTuple;
    /** Scratch: number of valid entries in definedPos/definedRCs (0 when no split is active). */
    protected int numDefined;
    /** Scratch: number of valid entries in undefinedPos (0 when no split is active). */
    protected int numUndefined;
    /** Scratch: positions with an assigned RC, in increasing position order. */
    protected int[] definedPos;
    /** Scratch: the RC assigned at the matching entry of definedPos. */
    protected int[] definedRCs;
    /** Scratch: positions without an assigned RC, in increasing position order. */
    protected int[] undefinedPos;
    /** Scratch buffer holding a parent's assignments with one extra position filled in. */
    protected int[] childConf;
    /** Optional filter consulted in getChildren; null means every unpruned RC is allowed. */
    private GMECMutSpace mutSpace;

    /**
     * Builds a tree of {@link FullAStarNode}s for a search problem, using the problem's own
     * pruning matrix.
     *
     * @param search2 search problem; its {@code pruneMat} field is used for pruning
     * @return the tree built by {@link #makeFull(SearchProblem, PruningMatrix)}
     */
    public static ConfTree<FullAStarNode> makeFull(SearchProblem search2) {
        PruningMatrix ownPruneMat = search2.pruneMat;
        return makeFull(search2, ownPruneMat);
    }

    /**
     * Builds a tree of {@link FullAStarNode}s for a search problem with an explicit pruning matrix.
     *
     * @param search2 search problem; {@code confSpace.numPos} sizes the node factory
     * @param pmat    pruning matrix handed to the tree constructor
     * @return a new tree
     */
    public static ConfTree<FullAStarNode> makeFull(SearchProblem search2, PruningMatrix pmat) {
        FullAStarNode.Factory fullNodeFactory = new FullAStarNode.Factory(search2.confSpace.numPos);
        return new ConfTree<FullAStarNode>(fullNodeFactory, search2, pmat);
    }

    /**
     * Builds a tree of {@link FullAStarNode}s for a search problem with an explicit pruning matrix
     * and a mutation-space argument.
     *
     * @param search2 search problem; {@code confSpace.numPos} sizes the node factory and
     *                {@code useEPIC} is passed on as the EPIC flag
     * @param pmat    pruning matrix handed to the tree constructor
     * @param gms     mutation-space argument handed to the tree constructor
     * @return a new tree
     */
    public static ConfTree<FullAStarNode> makeFull(SearchProblem search2, PruningMatrix pmat, GMECMutSpace gms) {
        FullAStarNode.Factory fullNodeFactory = new FullAStarNode.Factory(search2.confSpace.numPos);
        boolean epicFlag = search2.useEPIC;
        return new ConfTree<FullAStarNode>(fullNodeFactory, search2, pmat, epicFlag, gms);
    }

    /**
     * Builds a tree of {@link FullAStarNode}s directly from precomputed matrices.
     *
     * <p>The old-style EPIC matrix argument of the main constructor is always {@code null} here;
     * {@code precompMat.getEpicMat()} is supplied in the new-style EPIC slot and
     * {@code precompMat.getLuteMat()} in the tuple-expansion slot.
     *
     * @param precompMat         source of the energy, EPIC, tuple-expansion and pruning matrices
     * @param gms                optional mutation-space filter (may be {@code null})
     * @param useTupExpForSearch whether the tree should score with the tuple-expansion matrix
     * @param useEPIC            whether EPIC refinement is enabled
     * @param epicSettings       EPIC settings passed to the main constructor
     * @param numPos             number of positions used to size the node factory
     * @return a new tree
     */
    public static ConfTree<FullAStarNode> makeFull(PrecomputedMatrices precompMat, GMECMutSpace gms, boolean useTupExpForSearch, boolean useEPIC, EPICSettings epicSettings, int numPos) {
        FullAStarNode.Factory fullNodeFactory = new FullAStarNode.Factory(numPos);
        return new ConfTree<FullAStarNode>(
                fullNodeFactory,
                useTupExpForSearch,
                precompMat.getEmat(),
                null,
                precompMat.getEpicMat(),
                precompMat.getLuteMat(),
                precompMat.getPruneMat(),
                useEPIC,
                epicSettings,
                gms);
    }

    /**
     * Creates a tree for a search problem using the problem's own pruning matrix and EPIC flag.
     * Delegates to the five-argument constructor with {@code sp.pruneMat}, {@code sp.useEPIC}
     * and a {@code null} mutation space.
     *
     * @param nodeFactory factory used to create nodes
     * @param sp          search problem supplying matrices and settings
     */
    public ConfTree(AStarNode.Factory<T> nodeFactory, SearchProblem sp) {
        this(nodeFactory, sp, sp.pruneMat, sp.useEPIC, null);
    }

    /**
     * Creates a tree for a search problem with an explicit pruning matrix.
     * Delegates to the five-argument constructor with {@code sp.useEPIC} and a {@code null}
     * mutation space.
     *
     * @param nodeFactory factory used to create nodes
     * @param sp          search problem supplying matrices and settings
     * @param pruneMat    pruning matrix to use instead of {@code sp.pruneMat}
     */
    public ConfTree(AStarNode.Factory<T> nodeFactory, SearchProblem sp, PruningMatrix pruneMat) {
        this(nodeFactory, sp, pruneMat, sp.useEPIC, null);
    }

    /**
     * Creates a tree for a search problem with an explicit pruning matrix, EPIC flag and
     * optional mutation-space filter.
     *
     * <p>The energy matrix, old-style EPIC matrix, tuple-expansion matrix and EPIC settings are
     * read from {@code sp}; the new-style EPIC matrix slot is always {@code null}.
     *
     * @param nodeFactory factory used to create nodes
     * @param sp          search problem supplying matrices and settings
     * @param pruneMat    pruning matrix that defines the unpruned RCs per position
     * @param useEPIC     whether EPIC refinement is enabled for this tree
     * @param gms         optional mutation-space filter; {@code null} disables filtering
     */
    public ConfTree(AStarNode.Factory<T> nodeFactory, SearchProblem sp, PruningMatrix pruneMat, boolean useEPIC, GMECMutSpace gms) {
        // Forward the caller's useEPIC and gms as given: makeFull(search, pmat, gms) depends on
        // gms reaching the main constructor, otherwise the mutation-space filter is never applied.
        this(nodeFactory, sp.useTupExpForSearch, sp.emat, sp.epicMat, null, sp.tupExpEMat, pruneMat, useEPIC, sp.epicSettings, gms);
    }

    /**
     * Main constructor: sizes the scratch buffers, snapshots the unpruned RCs of every position
     * and selects the matrices used for scoring and refinement.
     *
     * @param nodeFactory        factory used to create nodes
     * @param useTupExpForSearch when true, {@code tupExpEMat} is the scoring matrix and EPIC is not set up
     * @param emat               scoring matrix used when {@code useTupExpForSearch} is false
     * @param epicMat            old-style EPIC matrix (only read when EPIC is set up)
     * @param newEPICMat         new-style EPIC matrix (only read when EPIC is set up)
     * @param tupExpEMat         scoring matrix used when {@code useTupExpForSearch} is true
     * @param pruneMat           supplies the number of positions and the unpruned RCs per position
     * @param useEPIC            enables score refinement when not searching on the tuple expansion
     * @param epicSettings       read for {@code minPartialConfs} when EPIC is set up
     * @param gms                optional mutation-space filter; a non-null value turns off dynamic A*
     * @throws RuntimeException if {@code traditionalScore} is false, or if EPIC is set up and the
     *                          number of non-null EPIC matrices is not exactly one
     */
    public ConfTree(AStarNode.Factory<T> nodeFactory, boolean useTupExpForSearch, EnergyMatrix emat, EPICMatrix epicMat, NewEPICMatrix newEPICMat, EnergyMatrix tupExpEMat, PruningMatrix pruneMat, boolean useEPIC, EPICSettings epicSettings, GMECMutSpace gms) {
        if (!traditionalScore) {
            throw new RuntimeException("Advanced A* scoring methods not implemented yet!");
        }
        this.nodeFactory = nodeFactory;

        // Scratch state shared by all scoring calls.
        final int positionCount = pruneMat.getNumPos();
        this.numPos = positionCount;
        this.rcTuple = new RCTuple();
        this.numDefined = 0;
        this.numUndefined = 0;
        this.definedPos = new int[positionCount];
        this.definedRCs = new int[positionCount];
        this.undefinedPos = new int[positionCount];
        this.childConf = new int[positionCount];

        // Copy the unpruned RC lists into primitive arrays, one per position.
        this.unprunedRCsAtPos = new int[positionCount][];
        for (int pos = 0; pos < positionCount; pos++) {
            ArrayList<Integer> unpruned = pruneMat.unprunedRCsAtPos(pos);
            int[] rcs = new int[unpruned.size()];
            int next = 0;
            for (int rc : unpruned) {
                rcs[next++] = rc;
            }
            this.unprunedRCsAtPos[pos] = rcs;
        }

        this.emat = useTupExpForSearch ? tupExpEMat : emat;
        if (!useTupExpForSearch && useEPIC) {
            this.useRefinement = true;
            this.epicMat = epicMat;
            this.newEPICMat = newEPICMat;
            boolean hasOldEpic = epicMat != null;
            boolean hasNewEpic = newEPICMat != null;
            if (hasOldEpic == hasNewEpic) {
                throw new RuntimeException("ERROR: to use EPIC in A* need exactly one of old and new EPIC matrices");
            }
            this.minPartialConfs = epicSettings.minPartialConfs;
        }

        this.mutSpace = gms;
        if (gms != null) {
            // With a mutation-space filter, positions are expanded in fixed order.
            this.useDynamicAStar = false;
        }
    }

    /**
     * Counts the conformations in the tree as the product of the number of unpruned RCs at each
     * position.
     *
     * @return the product over all {@code numPos} positions (1 when there are no positions)
     */
    @Override
    public BigInteger getNumConformations() {
        BigInteger total = BigInteger.ONE;
        int pos = 0;
        while (pos < this.numPos) {
            int rcCount = this.unprunedRCsAtPos[pos].length;
            total = total.multiply(BigInteger.valueOf(rcCount));
            pos++;
        }
        return total;
    }

    /**
     * Expands a node: picks the next position with {@link #nextLevelToExpand} and creates one
     * scored child per unpruned RC at that position (skipping RCs rejected by the mutation-space
     * filter, when one is set).
     *
     * @param curNode node to expand; must not be fully assigned
     * @return the scored children; empty when the node's score is {@code Double.POSITIVE_INFINITY}
     * @throws RuntimeException if {@code curNode} is fully assigned
     */
    @Override
    public ArrayList<T> getChildren(T curNode) {
        if (this.isFullyAssigned(curNode)) {
            throw new RuntimeException("ERROR: Can't expand a fully assigned A* node");
        }
        if (curNode.getScore() == Double.POSITIVE_INFINITY) {
            // An infinitely bad node has no children worth creating.
            return new ArrayList<T>();
        }
        ArrayList<T> ans = new ArrayList<T>();
        // Must run before splitPositions below: it performs (and clears) its own split.
        int nextLevel = this.nextLevelToExpand(curNode);
        this.splitPositions(curNode);
        try {
            for (int rc : this.unprunedRCsAtPos[nextLevel]) {
                // NOTE(review): the filter is given curNode.getLevel(), not nextLevel — confirm
                // that this is what GMECMutSpace.isNewRCAllowed expects.
                if (this.mutSpace != null && !this.mutSpace.isNewRCAllowed(curNode.getNodeAssignments(), curNode.getLevel(), rc)) {
                    continue;
                }
                T childNode = this.nodeFactory.make(curNode, nextLevel, rc);
                childNode.setScoreNeedsRefinement(this.useRefinement);
                this.scoreNodeDifferential(curNode, childNode, nextLevel, rc);
                ans.add(childNode);
            }
        } finally {
            // Always clear the scratch split so a failure here cannot trip the
            // precondition assertion of the next splitPositions call.
            this.resetSplitPositions();
        }
        return ans;
    }

    /**
     * Creates and scores the root node.
     *
     * @return the root produced by the node factory, with its refinement flag and score set
     */
    @Override
    public T rootNode() {
        T root = this.nodeFactory.makeRoot();
        root.setScoreNeedsRefinement(this.useRefinement);
        this.splitPositions(root);
        try {
            this.scoreNode(root);
        } finally {
            // Leave the scratch split cleared even if scoring fails.
            this.resetSplitPositions();
        }
        return root;
    }

    /**
     * Reports whether a node has nothing left to expand, as decided by
     * {@code node.isFullyDefined()}.
     *
     * @param node node to test
     * @return the value of {@code node.isFullyDefined()}
     */
    @Override
    public boolean isFullyAssigned(T node) {
        return node.isFullyDefined();
    }

    /**
     * Chooses the position to assign next when expanding a node.
     *
     * <p>With dynamic A* the unassigned position with the strictly largest
     * {@link #scoreExpansionLevel} value wins (the first such position on ties); otherwise the
     * first unassigned position is used. The scratch split is set up and cleared inside this call.
     *
     * @param parentNode node about to be expanded
     * @return the chosen position
     * @throws RuntimeException if no position was chosen
     */
    public int nextLevelToExpand(T parentNode) {
        int chosen = -1;
        this.splitPositions(parentNode);
        if (!this.useDynamicAStar) {
            if (this.numUndefined > 0) {
                chosen = this.undefinedPos[0];
            }
        } else {
            double highest = Double.NEGATIVE_INFINITY;
            for (int idx = 0; idx < this.numUndefined; idx++) {
                int candidate = this.undefinedPos[idx];
                double candidateScore = this.scoreExpansionLevel(parentNode, candidate);
                if (candidateScore > highest) {
                    highest = candidateScore;
                    chosen = candidate;
                }
            }
        }
        this.resetSplitPositions();
        if (chosen == -1) {
            throw new RuntimeException("ERROR: No next expansion level found for dynamic A*");
        }
        return chosen;
    }

    /**
     * Rates a position as an expansion candidate: the reciprocal of the sum, over the position's
     * unpruned RCs, of {@code 1 / (childScore - parentScore)}.
     *
     * <p>Requires an active split for {@code parentNode} (see {@link #splitPositions}).
     *
     * @param parentNode node being expanded
     * @param level      candidate position
     * @return the rating; larger values are preferred by {@link #nextLevelToExpand}
     */
    double scoreExpansionLevel(T parentNode, int level) {
        final double baseline = parentNode.getScore();
        final int[] candidates = this.unprunedRCsAtPos[level];
        double sumOfInverses = 0.0;
        for (int idx = 0; idx < candidates.length; idx++) {
            double scoreOfChild = this.scoreConfDifferential(parentNode, level, candidates[idx]);
            sumOfInverses += 1.0 / (scoreOfChild - baseline);
        }
        return 1.0 / sumOfInverses;
    }

    /**
     * Fills the scratch arrays from a node's assignments: positions whose RC is {@code >= 0} go
     * to {@code definedPos}/{@code definedRCs}, all others to {@code undefinedPos}. Both lists
     * end up in increasing position order.
     *
     * <p>With assertions enabled, the previous split must have been cleared first.
     *
     * @param node node whose assignments are split
     */
    protected void splitPositions(T node) {
        assert (this.numDefined == 0 && this.numUndefined == 0);
        int[] assignments = node.getNodeAssignments();
        this.numDefined = 0;
        this.numUndefined = 0;
        for (int pos = 0; pos < this.numPos; pos++) {
            int assigned = assignments[pos];
            if (assigned < 0) {
                this.undefinedPos[this.numUndefined] = pos;
                this.numUndefined++;
            } else {
                this.definedPos[this.numDefined] = pos;
                this.definedRCs[this.numDefined] = assigned;
                this.numDefined++;
            }
        }
        assert (this.numDefined + this.numUndefined == this.numPos);
    }

    /** Marks the scratch split as inactive by zeroing both counters. */
    private void resetSplitPositions() {
        this.numDefined = this.numUndefined = 0;
    }

    /**
     * Checks (only when assertions are enabled) that a split is active, i.e. that the defined and
     * undefined counts add up to {@code numPos}.
     */
    protected void assertSplitPositions() {
        assert (this.numDefined + this.numUndefined == this.numPos) : "call splitPositions(node) before calling this function!";
    }

    /**
     * Scores a child node from its parent plus the one new assignment, and stores the score on
     * the child. Requires an active split for {@code parent}.
     *
     * @param parent   node the child was derived from
     * @param child    node that receives the score
     * @param childPos position newly assigned in the child
     * @param childRc  RC assigned at {@code childPos}
     * @return the score that was stored on {@code child}
     */
    protected double scoreNodeDifferential(T parent, T child, int childPos, int childRc) {
        this.assertSplitPositions();
        final double childScore = this.scoreConfDifferential(parent, childPos, childRc);
        child.setScore(childScore);
        return childScore;
    }

    /**
     * Scores a node from its own assignments (no extra child assignment) and stores the score on
     * it. Requires an active split for {@code node}.
     *
     * @param node node to score
     * @return the score that was stored on {@code node}
     */
    protected double scoreNode(T node) {
        this.assertSplitPositions();
        // -1/-1 means "no additional position is being assigned".
        final double ownScore = this.scoreConfDifferential(node, -1, -1);
        node.setScore(ownScore);
        return ownScore;
    }

    /**
     * Computes the score of a parent's assignments, optionally extended by one extra assignment.
     *
     * <p>The result is the matrix constant term plus the internal energy of the assigned RCs,
     * plus, for every unassigned position other than {@code childPos}, the minimum over its
     * unpruned RCs of {@code getUndefinedRCEnergy}. Requires an active split for
     * {@code parentNode}; overwrites {@code rcTuple} and, when {@code childPos >= 0},
     * {@code childConf}.
     *
     * @param parentNode node supplying the base assignments
     * @param childPos   position to assign additionally, or a negative value for none
     * @param childRc    RC to assign at {@code childPos} (ignored when {@code childPos < 0})
     * @return the combined score
     */
    protected double scoreConfDifferential(T parentNode, int childPos, int childRc) {
        this.assertSplitPositions();
        int[] assignments = parentNode.getNodeAssignments();
        if (childPos >= 0) {
            assert (assignments[childPos] < 0);
            // Work on the shared buffer so the parent's own array is left untouched.
            System.arraycopy(assignments, 0, this.childConf, 0, this.numPos);
            this.childConf[childPos] = childRc;
            assignments = this.childConf;
        }

        this.rcTuple.set(assignments);
        double assignedEnergy = this.emat.getConstTerm() + this.emat.getInternalEnergy(this.rcTuple);

        double unassignedBound = 0.0;
        int u = 0;
        while (u < this.numUndefined) {
            int openPos = this.undefinedPos[u];
            u++;
            if (openPos != childPos) {
                int[] options = this.unprunedRCsAtPos[openPos];
                int optionCount = options.length;
                double lowest = Double.POSITIVE_INFINITY;
                for (int k = 0; k < optionCount; k++) {
                    double optionEnergy = this.getUndefinedRCEnergy(assignments, openPos, options[k], k, childPos, childRc);
                    lowest = Math.min(lowest, optionEnergy);
                }
                unassignedBound += lowest;
            }
        }
        return assignedEnergy + unassignedBound;
    }

    /**
     * Computes the energy contribution of placing {@code rc1} at the unassigned position
     * {@code pos1}.
     *
     * <p>The sum consists of: the one-body term; pairwise plus higher-order terms against every
     * position in the active split's defined list; the same against the extra child assignment
     * when {@code childPos >= 0}; and, for every unassigned position that precedes {@code pos1}
     * (other than {@code childPos}), the minimum pairwise plus higher-order term over that
     * position's unpruned RCs.
     *
     * @param conf     assignments used for the higher-order terms
     * @param pos1     unassigned position being evaluated
     * @param rc1      candidate RC at {@code pos1}
     * @param rc1i     index of {@code rc1} within the position's RC list (not used here)
     * @param childPos extra assigned position, or a negative value for none
     * @param childRc  RC at {@code childPos}
     * @return the contribution described above
     */
    private double getUndefinedRCEnergy(int[] conf, int pos1, int rc1, int rc1i, int childPos, int childRc) {
        this.assertSplitPositions();
        final EnergyMatrix energies = this.emat;
        double total = energies.getOneBody(pos1, rc1);

        // Interactions with positions that are already assigned in the split.
        for (int d = 0; d < this.numDefined; d++) {
            int assignedPos = this.definedPos[d];
            int assignedRc = this.definedRCs[d];
            assert (assignedPos != childPos);
            total += energies.getPairwise(pos1, rc1, assignedPos, assignedRc).doubleValue();
            total += this.higherOrderContribLB(conf, pos1, rc1, assignedPos, assignedRc);
        }

        // Interaction with the extra child assignment, if any.
        if (childPos >= 0) {
            total += energies.getPairwise(pos1, rc1, childPos, childRc).doubleValue();
            total += this.higherOrderContribLB(conf, pos1, rc1, childPos, childRc);
        }

        // Best-case interactions with earlier unassigned positions; undefinedPos is in
        // increasing order, so the scan stops at the first position that is not before pos1.
        for (int u = 0; u < this.numUndefined; u++) {
            int openPos = this.undefinedPos[u];
            if (openPos >= pos1) {
                break;
            }
            if (openPos == childPos) {
                continue;
            }
            int[] openRcs = this.unprunedRCsAtPos[openPos];
            double lowest = Double.POSITIVE_INFINITY;
            for (int k = 0; k < openRcs.length; k++) {
                int openRc = openRcs[k];
                double pairEnergy = energies.getPairwise(pos1, rc1, openPos, openRc);
                pairEnergy += this.higherOrderContribLB(conf, pos1, rc1, openPos, openRc);
                lowest = Math.min(lowest, pairEnergy);
            }
            total += lowest;
        }
        return total;
    }

    /**
     * Lists the RCs a position may take under a partial conformation: the single assigned RC
     * when the entry is {@code >= 0}, or every unpruned RC when the entry is {@code -1}.
     *
     * @param level       position to query
     * @param partialConf assignments indexed by position
     * @return a new list of allowed RCs
     * @throws UnsupportedOperationException if the entry is less than {@code -1}
     */
    ArrayList<Integer> allowedRCsAtLevel(int level, int[] partialConf) {
        final int assigned = partialConf[level];
        if (assigned < -1) {
            throw new UnsupportedOperationException("ERROR: Partially assigned position not yet supported in A*");
        }
        ArrayList<Integer> allowed = new ArrayList<Integer>();
        if (assigned >= 0) {
            allowed.add(assigned);
            return allowed;
        }
        int[] unpruned = this.unprunedRCsAtPos[level];
        for (int k = 0; k < unpruned.length; k++) {
            allowed.add(unpruned[k]);
        }
        return allowed;
    }

    /**
     * Computes the higher-order contribution attached to the pair
     * {@code (pos1, rc1), (pos2, rc2)}.
     *
     * @param partialConf assignments indexed by position
     * @param pos1        first position of the pair
     * @param rc1         RC at {@code pos1}
     * @param pos2        second position of the pair
     * @param rc2         RC at {@code pos2}
     * @return 0 when the energy matrix has no higher-order terms for the pair; otherwise the
     *         result of the finder-based overload called with {@code pos2} as the reference level
     */
    double higherOrderContribLB(int[] partialConf, int pos1, int rc1, int pos2, int rc2) {
        HigherTupleFinder<Double> finder = this.emat.getHigherOrderTerms(pos1, rc1, pos2, rc2);
        return finder == null ? 0.0 : this.higherOrderContribLB(partialConf, finder, pos2);
    }

    /**
     * Sums, over the finder's interacting positions accepted by {@code posComesBefore}, the
     * minimum interaction energy across the RCs allowed at that position. Each interaction
     * energy includes, recursively, the contribution of any nested higher-order terms.
     *
     * @param partialConf assignments indexed by position
     * @param htf         finder holding the higher-order terms to evaluate
     * @param level2      reference position passed to {@code posComesBefore}
     * @return the summed contribution (0 when no interacting position qualifies)
     */
    double higherOrderContribLB(int[] partialConf, HigherTupleFinder<Double> htf, int level2) {
        double total = 0.0;
        for (int otherPos : htf.getInteractingPos()) {
            if (this.posComesBefore(otherPos, level2, partialConf)) {
                double lowest = Double.POSITIVE_INFINITY;
                for (int rc : this.allowedRCsAtLevel(otherPos, partialConf)) {
                    double energy = htf.getInteraction(otherPos, rc);
                    HigherTupleFinder<Double> deeper = htf.getHigherInteractions(otherPos, rc);
                    if (deeper != null) {
                        energy += this.higherOrderContribLB(partialConf, deeper, otherPos);
                    }
                    lowest = Math.min(lowest, energy);
                }
                total += lowest;
            }
        }
        return total;
    }

    /**
     * Ordering test used when summing higher-order terms.
     *
     * <p>When {@code pos2} is assigned, {@code pos1} comes before it only if it has a smaller
     * index and is itself assigned. When {@code pos2} is unassigned, {@code pos1} comes before
     * it if it has a smaller index or is assigned.
     *
     * @param pos1        position being tested
     * @param pos2        reference position
     * @param partialConf assignments indexed by position ({@code >= 0} means assigned)
     * @return whether {@code pos1} is ordered before {@code pos2}
     */
    private boolean posComesBefore(int pos1, int pos2, int[] partialConf) {
        final boolean referenceAssigned = partialConf[pos2] >= 0;
        final boolean lowerIndex = pos1 < pos2;
        if (referenceAssigned) {
            if (!lowerIndex) {
                return false;
            }
            return partialConf[pos1] >= 0;
        }
        if (lowerIndex) {
            return true;
        }
        return partialConf[pos1] >= 0;
    }

    /**
     * Refines a node's score with EPIC. Nothing happens unless {@code minPartialConfs} is set or
     * the node is fully assigned.
     *
     * <p>With an old-style EPIC matrix the score is replaced by its {@code minimizeEnergy}
     * result; otherwise the new-style matrix's {@code minContE} value is added to the current
     * score. The node is then marked as no longer needing refinement.
     *
     * @param node node whose score is refined in place
     */
    @Override
    public void refineScore(T node) {
        if (!(this.minPartialConfs || this.isFullyAssigned(node))) {
            return;
        }
        if (this.epicMat != null) {
            RCTuple assignedTuple = new RCTuple(node.getNodeAssignments());
            node.setScore(this.epicMat.minimizeEnergy(assignedTuple, true));
        } else {
            node.setScore(node.getScore() + this.newEPICMat.minContE(node.getNodeAssignments()));
        }
        node.setScoreNeedsRefinement(false);
    }

    /**
     * Finds the minimum internal energy over all completions of a partial conformation by
     * recursion: the first position equal to {@code -1} is tried with every allowed RC, and a
     * conformation without any {@code -1} entry is evaluated with
     * {@code emat.getInternalEnergy}. The matrix constant term is not added here.
     *
     * @param partialConf assignments indexed by position; not modified
     * @return the minimum energy found, or {@code Double.POSITIVE_INFINITY} when the first
     *         unassigned position has no allowed RC
     */
    double exhaustiveScore(int[] partialConf) {
        int firstOpen = -1;
        for (int pos = 0; pos < partialConf.length; pos++) {
            if (partialConf[pos] == -1) {
                firstOpen = pos;
                break;
            }
        }
        if (firstOpen < 0) {
            return this.emat.getInternalEnergy(new RCTuple(partialConf));
        }

        double lowest = Double.POSITIVE_INFINITY;
        for (int rc : this.allowedRCsAtLevel(firstOpen, partialConf)) {
            int[] extended = partialConf.clone();
            extended[firstOpen] = rc;
            lowest = Math.min(lowest, this.exhaustiveScore(extended));
        }
        return lowest;
    }
}

