/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.multistatekstar;

import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.gmec.ConfSearchFactory;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.multistatekstar.ConfComparator;
import edu.duke.cs.osprey.multistatekstar.KStarScore;
import edu.duke.cs.osprey.multistatekstar.LMB;
import edu.duke.cs.osprey.multistatekstar.MSKStarFactory;
import edu.duke.cs.osprey.multistatekstar.MSKStarSettings;
import edu.duke.cs.osprey.multistatekstar.PartitionFunctionMinimized;
import edu.duke.cs.osprey.multistatekstar.PruningMatrixInverted;
import edu.duke.cs.osprey.multistatekstar.PruningMatrixNull;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.PriorityQueue;

/**
 * Multistate K* score built from one {@link PartitionFunctionMinimized} per state.
 *
 * <p>The score is {@code q*(last state) / (q*(state 0) * ... * q*(state numStates-2))}: the last
 * state's partition function forms the numerator and every other state contributes a factor to
 * the denominator. Linear constraints ({@link LMB}) from the settings are evaluated against the
 * states' current q* values.
 */
public class KStarScoreMinimized
implements KStarScore {
    /**
     * Precision used for score divisions. {@code BigDecimal.divide(divisor, RoundingMode)} keeps
     * the dividend's scale, so a numerator with a small or negative scale would round the score to
     * a coarse grid (possibly to zero); a fixed number of significant digits avoids that.
     */
    private static final MathContext SCORE_CONTEXT = new MathContext(64, RoundingMode.HALF_UP);

    /** Design settings; {@code settings.search} has one entry per state. */
    public MSKStarSettings settings;
    /** Partition function per state; an entry stays {@code null} until that state is initialized. */
    public PartitionFunctionMinimized[] partitionFunctions;
    /** Whether {@link #init(int)} has already run for each state. */
    protected boolean[] initialized;
    /** Number of states, equal to {@code settings.search.length}. */
    public int numStates;
    /** Becomes false once a checked constraint is violated; remaining states are then skipped. */
    protected boolean constrSatisfied;

    /**
     * Creates a score with no partition functions computed yet.
     *
     * @param settings design settings; the number of states is {@code settings.search.length}
     */
    public KStarScoreMinimized(MSKStarSettings settings) {
        this.settings = settings;
        this.numStates = settings.search.length;
        // Freshly allocated arrays are already all-null / all-false: nothing is initialized yet.
        this.partitionFunctions = new PartitionFunctionMinimized[this.numStates];
        this.initialized = new boolean[this.numStates];
        this.constrSatisfied = true;
    }

    @Override
    public MSKStarSettings getSettings() {
        return this.settings;
    }

    /** Returns the q* of {@code state}, or zero if that state has no partition function yet. */
    private BigDecimal qstarOf(int state) {
        PartitionFunctionMinimized pf = this.partitionFunctions[state];
        return pf == null ? BigDecimal.ZERO : pf.getValues().qstar;
    }

    /** Returns the current q* of every state (zero for states without a partition function). */
    private BigDecimal[] currentQStars() {
        BigDecimal[] vals = new BigDecimal[this.numStates];
        for (int s = 0; s < this.numStates; ++s) {
            vals[s] = this.qstarOf(s);
        }
        return vals;
    }

    /**
     * Computes the score denominator: the product of q* over all states except the last.
     *
     * @return the product (starting from 1 at scale 64), or zero if any of those states has no
     *     partition function or a q* of zero
     */
    protected BigDecimal getDenom() {
        BigDecimal ans = BigDecimal.ONE.setScale(64, RoundingMode.HALF_UP);
        for (int state = 0; state < this.numStates - 1; ++state) {
            PartitionFunctionMinimized pf = this.partitionFunctions[state];
            if (pf == null || pf.getValues().qstar.signum() == 0) {
                return BigDecimal.ZERO;
            }
            ans = ans.multiply(pf.getValues().qstar);
        }
        return ans;
    }

    /**
     * Returns {@code q*(last state) / getDenom()}.
     *
     * @return the score, or zero if the denominator is zero or the last state has no partition
     *     function
     */
    @Override
    public BigDecimal getScore() {
        BigDecimal den = this.getDenom();
        if (den.signum() == 0) {
            return BigDecimal.ZERO;
        }
        PartitionFunctionMinimized pf = this.partitionFunctions[this.numStates - 1];
        return pf == null ? BigDecimal.ZERO : pf.getValues().qstar.divide(den, SCORE_CONTEXT);
    }

    /** Returns the same value as {@link #getScore()}. */
    @Override
    public BigDecimal getLowerBoundScore() {
        return this.getScore();
    }

    /**
     * Returns the score with the numerator widened to {@code q* + q' + p*} while the last state's
     * partition function is not yet {@code Estimated}; once it is, this equals {@link #getScore()}.
     *
     * @return the upper-bound score, or zero if the denominator is zero or the last state has no
     *     partition function
     */
    @Override
    public BigDecimal getUpperBoundScore() {
        BigDecimal den = this.getDenom();
        if (den.signum() == 0) {
            return BigDecimal.ZERO;
        }
        PartitionFunctionMinimized pf = this.partitionFunctions[this.numStates - 1];
        if (pf == null) {
            return BigDecimal.ZERO;
        }
        BigDecimal num = pf.getValues().qstar;
        if (pf.getStatus() != PartitionFunction.Status.Estimated) {
            num = num.add(pf.getValues().qprime).add(pf.getValues().pstar);
        }
        return num.divide(den, SCORE_CONTEXT);
    }

    /** Formats the sequence, the score, and each state's q* on one line. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Seq: " + this.settings.search[this.numStates - 1].settings.getFormattedSequence() + ", ");
        sb.append(String.format("score: %12e, ", this.getScore()));
        for (int state = 0; state < this.numStates; ++state) {
            sb.append(String.format("pf: %2d, q*: %12e, ", state, this.qstarOf(state)));
        }
        String ans = sb.toString().trim();
        // trim() removed the trailing space; drop the trailing comma of the last ", " separator.
        return ans.substring(0, ans.length() - 1);
    }

    /**
     * Prunes the state's search problem and creates its partition function, configured with the
     * pruning matrix and an inverted copy of it. When the state is fully assigned and top
     * conformations are requested, the partition function also records its best conformations.
     *
     * @param state index of the state to initialize
     * @return always {@code true}
     */
    protected boolean init(int state) {
        this.settings.search[state].prunePmat();
        ConfSearchFactory confSearchFactory = MSKStarFactory.makeConfSearchFactory(this.settings.search[state], this.settings.cfp);
        PartitionFunctionMinimized pf = (PartitionFunctionMinimized)MSKStarFactory.makePartitionFunction(
                this.settings.pfTypes[state],
                this.settings.search[state].emat,
                this.settings.search[state].pruneMat,
                new PruningMatrixInverted(this.settings.search[state], this.settings.search[state].pruneMat),
                confSearchFactory,
                this.settings.ecalcs[state]);
        this.partitionFunctions[state] = pf;
        pf.setReportProgress(this.settings.isReportingProgress);
        pf.init(this.settings.targetEpsilon);
        if (this.settings.search[state].isFullyAssigned() && this.settings.numTopConfsToSave > 0) {
            pf.topConfs = new PriorityQueue<ConfSearch.ScoredConf>(this.settings.numTopConfsToSave, new ConfComparator());
            pf.maxNumTopConfs = this.settings.numTopConfsToSave;
            pf.setConfListener(conf -> pf.saveConf(conf));
        }
        return true;
    }

    /**
     * Computes every state's partition function in order, initializing states on first use.
     *
     * <p>Stops at the first state whose state-specific constraints are violated. For a final
     * score whose state-specific checks all passed, all constraints are then checked. Partition
     * functions are cleaned up in either case.
     *
     * @param maxNumConfs conformation budget passed to each state's partition function
     */
    @Override
    public void compute(int maxNumConfs) {
        for (int state = 0; state < this.numStates && this.constrSatisfied; ++state) {
            if (!this.initialized[state]) {
                this.initialized[state] = this.init(state);
            }
            this.compute(state, maxNumConfs);
        }
        if (this.settings.isFinal && this.constrSatisfied) {
            this.constrSatisfied = this.checkConstraints();
        }
        // Runs on the constraint-failure path too, so already-computed states are released.
        this.cleanup();
    }

    /** Calls {@code cleanup()} on every partition function that has been created. */
    private void cleanup() {
        for (PartitionFunctionMinimized pf : this.partitionFunctions) {
            if (pf != null) {
                pf.cleanup();
            }
        }
    }

    /**
     * Runs a second-phase partition function for {@code state}.
     *
     * <p>It uses the phase-one partition function's {@code invmat} as its pruning matrix, presumably
     * so it enumerates conformations that phase one pruned. TODO confirm against
     * {@code PruningMatrixInverted}. The second-phase function is seeded with phase one's q* and
     * computed up to a target weight. If phase one's effective epsilon is exactly 1, the target is
     * {@code p*}. Otherwise it is {@code p* + q' - targetEpsilon / (1 - targetEpsilon) * q*}.
     *
     * <p>Callers invoke this only when the effective epsilon exceeds {@code targetEpsilon}, so
     * {@code targetEpsilon < 1} and the ratio above is finite.
     *
     * @param state index of the state whose partition function fell short of the target epsilon
     * @return the computed second-phase partition function
     */
    private PartitionFunctionMinimized phase2(int state) {
        PartitionFunctionMinimized pf = this.partitionFunctions[state];
        double epsilon = pf.getValues().getEffectiveEpsilon();
        double targetEpsilon = this.settings.targetEpsilon;
        BigDecimal qstar = pf.getValues().qstar;
        BigDecimal qprime = pf.getValues().qprime;
        BigDecimal pstar = pf.getValues().pstar;
        BigDecimal targetScoreWeights;
        if (epsilon == 1.0) {
            targetScoreWeights = pstar;
        } else {
            BigDecimal allowed = BigDecimal.valueOf(targetEpsilon / (1.0 - targetEpsilon)).multiply(qstar);
            targetScoreWeights = pstar.add(qprime).subtract(allowed);
        }
        ConfSearchFactory confSearchFactory = MSKStarFactory.makeConfSearchFactory(this.settings.search[state], this.settings.cfp);
        PruningMatrix invmat = pf.invmat;
        PartitionFunctionMinimized p2pf = (PartitionFunctionMinimized)MSKStarFactory.makePartitionFunction(
                this.settings.pfTypes[state],
                this.settings.search[state].emat,
                invmat,
                new PruningMatrixNull(invmat),
                confSearchFactory,
                this.settings.ecalcs[state]);
        p2pf.init(targetEpsilon);
        p2pf.getValues().qstar = qstar;
        p2pf.compute(targetScoreWeights);
        return p2pf;
    }

    /**
     * Computes one state's partition function and marks it {@code Estimated}.
     *
     * <p>If the effective epsilon is a number above the target, a second phase ({@link #phase2})
     * runs and its q* replaces this state's q*. For a final score, the state-specific
     * constraints are checked, and the top conformations are written when any were requested.
     *
     * @param state index of an already-initialized state
     * @param maxNumConfs conformation budget for this state's partition function
     */
    protected void compute(int state, int maxNumConfs) {
        if (this.settings.isReportingProgress) {
            System.out.println("state" + state + ": " + this.settings.search[state].settings.getFormattedSequence());
        }
        PartitionFunctionMinimized pf = this.partitionFunctions[state];
        pf.compute(maxNumConfs);
        double effectiveEpsilon = pf.getValues().getEffectiveEpsilon();
        if (!Double.isNaN(effectiveEpsilon) && effectiveEpsilon > this.settings.targetEpsilon) {
            PartitionFunctionMinimized p2pf = this.phase2(state);
            pf.getValues().qstar = p2pf.getValues().qstar;
            if (this.settings.search[state].isFullyAssigned() && this.settings.numTopConfsToSave > 0) {
                // NOTE(review): phase2 never sets up topConfs or a conf listener on p2pf, so this
                // passes whatever PartitionFunctionMinimized leaves in topConfs by default.
                // Confirm saveEConfs handles that.
                pf.saveEConfs(p2pf.topConfs);
            }
        }
        pf.setStatus(PartitionFunction.Status.Estimated);
        if (this.settings.isFinal) {
            if (this.constrSatisfied) {
                this.constrSatisfied = this.checkConstraints(state);
            }
            if (this.settings.numTopConfsToSave > 0) {
                pf.writeTopConfs(this.settings.state, this.settings.search[state]);
            }
        }
    }

    /**
     * Returns the state-specific constraints for {@code state}, filtered by the sign of their
     * coefficient for that state.
     *
     * @param state state index
     * @param negCoeff if true, keep constraints whose coefficient is negative; otherwise keep those
     *     whose coefficient is positive
     * @return the matching constraints (possibly empty)
     */
    protected ArrayList<LMB> getLMBsForState(int state, boolean negCoeff) {
        ArrayList<LMB> ans = new ArrayList<LMB>();
        for (LMB constr : this.getLMBsForState(state)) {
            int sign = constr.getCoeffs()[state].signum();
            if (negCoeff ? sign < 0 : sign > 0) {
                ans.add(constr);
            }
        }
        ans.trimToSize();
        return ans;
    }

    /**
     * Returns the constraints whose only nonzero coefficient belongs to {@code state}.
     *
     * @param state state index
     * @return the matching constraints; empty if the settings have no constraints
     */
    protected ArrayList<LMB> getLMBsForState(int state) {
        ArrayList<LMB> ans = new ArrayList<LMB>();
        if (this.settings.constraints == null) {
            return ans;
        }
        for (LMB constr : this.settings.constraints) {
            BigDecimal[] coeffs = constr.coeffs;
            if (coeffs[state].signum() == 0) {
                continue;
            }
            boolean onlyThisState = true;
            for (int c = 0; c < coeffs.length; ++c) {
                if (c != state && coeffs[c].signum() != 0) {
                    onlyThisState = false;
                    break;
                }
            }
            if (onlyThisState) {
                ans.add(constr);
            }
        }
        ans.trimToSize();
        return ans;
    }

    /**
     * Checks the given constraints against the current q* values. States without a partition
     * function count as zero.
     *
     * @return true iff every constraint evaluates to a value {@code <= 0}
     */
    private boolean checkConstraints(ArrayList<LMB> constraints) {
        // q* values do not change while the constraints are evaluated, so build the vector once.
        BigDecimal[] stateVals = this.currentQStars();
        for (LMB constr : constraints) {
            if (constr.eval(stateVals).signum() > 0) {
                return false;
            }
        }
        return true;
    }

    /** Checks the sign-filtered state-specific constraints of {@code state}. */
    protected boolean checkConstraints(int state, boolean negCoeff) {
        return this.checkConstraints(this.getLMBsForState(state, negCoeff));
    }

    /** Checks the state-specific constraints of {@code state}. */
    protected boolean checkConstraints(int state) {
        return this.checkConstraints(this.getLMBsForState(state));
    }

    /**
     * Checks every constraint in the settings.
     *
     * @return false if a constraint was already known to be violated; true if there are no
     *     constraints; otherwise whether all constraints hold for the current q* values
     */
    private boolean checkConstraints() {
        if (!this.constrSatisfied) {
            return false;
        }
        if (this.settings.constraints == null) {
            return true;
        }
        return this.checkConstraints(new ArrayList<LMB>(Arrays.asList(this.settings.constraints)));
    }

    @Override
    public boolean constrSatisfied() {
        return this.constrSatisfied;
    }

    /**
     * Reports whether a final score has finished computing.
     *
     * @return false if the score is not final or any created partition function is not yet
     *     {@code Estimated}; true if every state was computed, or if some states were skipped
     *     because a constraint was violated
     * @throws RuntimeException if states were skipped although no constraint was violated
     */
    @Override
    public boolean isFullyProcessed() {
        if (!this.settings.isFinal) {
            return false;
        }
        int nulls = 0;
        for (PartitionFunctionMinimized pf : this.partitionFunctions) {
            if (pf == null) {
                ++nulls;
            } else if (pf.getStatus() != PartitionFunction.Status.Estimated) {
                return false;
            }
        }
        if (nulls == 0) {
            // Every state has an estimated partition function.
            return true;
        }
        if (!this.constrSatisfied) {
            // Missing states are legitimate only when a violated constraint stopped computation early.
            return true;
        }
        throw new RuntimeException("ERROR: illegally skipped a partition function computation");
    }

    @Override
    public boolean isFinal() {
        return this.settings.isFinal;
    }

    /** Delegates to the last state's search problem. */
    @Override
    public boolean isFullyAssigned() {
        return this.settings.search[this.numStates - 1].isFullyAssigned();
    }
}

