/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.kstar.pfunc;

import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.confspace.ConfDB;
import edu.duke.cs.osprey.confspace.ConfSearch;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.confspace.compiled.ConfSpace;
import edu.duke.cs.osprey.confspace.compiled.PosInter;
import edu.duke.cs.osprey.energy.compiled.ConfEnergyCalculator;
import edu.duke.cs.osprey.energy.compiled.PosInterGen;
import edu.duke.cs.osprey.kstar.pfunc.BoltzmannCalculator;
import edu.duke.cs.osprey.kstar.pfunc.PartitionFunction;
import edu.duke.cs.osprey.kstar.pfunc.PfuncSurface;
import edu.duke.cs.osprey.parallelism.TaskExecutor;
import edu.duke.cs.osprey.tools.BigMath;
import edu.duke.cs.osprey.tools.JvmMem;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.MathTools;
import edu.duke.cs.osprey.tools.Stopwatch;
import edu.duke.cs.osprey.tools.TimeTools;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class NewGradientDescentPfunc
implements PartitionFunction.WithConfDB,
PartitionFunction.WithExternalMemory {
    public final ConfEnergyCalculator ecalc;
    public final BigInteger numConfsBeforePruning;
    public final PosInterGen posInterGen;
    private double targetEpsilon = Double.NaN;
    private BigDecimal stabilityThreshold = BigDecimal.ZERO;
    private PartitionFunction.ConfListener confListener = null;
    private boolean isReportingProgress = false;
    private Stopwatch stopwatch = new Stopwatch().start();
    private ConfSearch scoreConfs;
    private ConfSearch energyConfs;
    private static BoltzmannCalculator bcalc = new BoltzmannCalculator(PartitionFunction.decimalPrecision);
    private boolean usePreciseBcalc = true;
    private PartitionFunction.Status status = null;
    private PartitionFunction.Values values = null;
    private State state = null;
    private boolean hasEnergyConfs = true;
    private boolean hasScoreConfs = true;
    private long numEnergyConfsEnumerated = 0L;
    private long numScoreConfsEnumerated = 0L;
    private ConfDB confDB = null;
    private ConfDB.Key confDBKey = null;
    private boolean useExternalMemory = false;
    private RCs rcs = null;
    private PfuncSurface surf = null;
    private PfuncSurface.Trace trace = null;
    private final TaskExecutor te;
    private Integer instanceId = null;

    private static BigMath bigMath() {
        return new BigMath(PartitionFunction.decimalPrecision);
    }

    public NewGradientDescentPfunc(ConfEnergyCalculator ecalc, ConfSearch upperBoundConfs, ConfSearch lowerBoundConfs, BigInteger numConfsBeforePruning, PosInterGen posInterGen, TaskExecutor te) {
        this.ecalc = ecalc;
        this.scoreConfs = upperBoundConfs;
        this.energyConfs = lowerBoundConfs;
        this.numConfsBeforePruning = numConfsBeforePruning;
        this.posInterGen = posInterGen;
        this.te = te;
    }

    @Override
    public void setInstanceId(int instanceId) {
        this.instanceId = instanceId;
    }

    public int instanceIdOrThrow() {
        if (this.instanceId == null) {
            throw new IllegalStateException("no instance ID set, task doesn't know what context to use");
        }
        return this.instanceId;
    }

    public NewGradientDescentPfunc setPreciseBcalc(boolean val) {
        this.usePreciseBcalc = val;
        return this;
    }

    ConfDB.ConfTable getConfTable() {
        if (this.confDB == null) {
            return null;
        }
        return this.confDB.get(this.confDBKey);
    }

    BigDecimal bcalc(double val) {
        if (this.usePreciseBcalc) {
            return bcalc.calcPrecise(val);
        }
        return bcalc.calc(val);
    }

    private void saveToConfDb(ConfSearch.EnergiedConf econf) {
        ConfDB.ConfTable confTable = this.getConfTable();
        if (confTable != null) {
            confTable.setBounds(econf, TimeTools.getTimestampNs());
        }
    }

    @Override
    public void setReportProgress(boolean val) {
        this.isReportingProgress = val;
    }

    @Override
    public void setConfListener(PartitionFunction.ConfListener val) {
        this.confListener = val;
    }

    @Override
    public PartitionFunction.Status getStatus() {
        return this.status;
    }

    @Override
    public PartitionFunction.Values getValues() {
        return this.values;
    }

    @Override
    public int getNumConfsEvaluated() {
        return (int)this.state.numEnergiedConfs;
    }

    @Override
    public int getParallelism() {
        return 1;
    }

    @Override
    public void setConfDB(ConfDB confDB, ConfDB.Key key) {
        this.confDB = confDB;
        this.confDBKey = key;
    }

    @Override
    public void setUseExternalMemory(boolean val, RCs rcs) {
        this.useExternalMemory = val;
        this.rcs = rcs;
    }

    public void traceTo(PfuncSurface val) {
        this.surf = val;
    }

    @Override
    public void init(double targetEpsilon) {
        if (targetEpsilon < 0.0) {
            throw new IllegalArgumentException("target epsilon must at least zero");
        }
        this.targetEpsilon = targetEpsilon;
        this.status = PartitionFunction.Status.Estimating;
        this.state = new State(this.numConfsBeforePruning);
        this.values = PartitionFunction.Values.makeFullRange();
        this.values.pstar = BigDecimal.ZERO;
        this.hasEnergyConfs = true;
        this.hasScoreConfs = true;
        this.numEnergyConfsEnumerated = 0L;
        this.numScoreConfsEnumerated = 0L;
        if (this.energyConfs == null) {
            ConfSearch.Splitter confsSplitter = new ConfSearch.Splitter(this.scoreConfs, this.useExternalMemory, this.rcs);
            this.scoreConfs = confsSplitter.first;
            this.energyConfs = confsSplitter.second;
        }
    }

    @Override
    public void setStabilityThreshold(BigDecimal val) {
        this.stabilityThreshold = val;
    }

    private EnergyResult computeEnergyResult(ConfSearch.ScoredConf conf) {
        EnergyResult result = new EnergyResult();
        result.stopwatch.start();
        ConfSpace confSpace = this.ecalc.confSpace();
        int[] assignments = conf.getAssignments();
        List<PosInter> inters = this.posInterGen.all(confSpace, assignments);
        result.econf = this.ecalc.minimizeEnergy(conf, inters);
        result.scoreWeight = this.bcalc(result.econf.getScore());
        result.energyWeight = this.bcalc(result.econf.getEnergy());
        this.saveToConfDb(result.econf);
        result.stopwatch.stop();
        return result;
    }

    private ScoreResult computeScoreConfs(Iterable<ConfSearch.ScoredConf> confs) {
        ScoreResult result = new ScoreResult();
        result.stopwatch.start();
        for (ConfSearch.ScoredConf conf : confs) {
            result.scoreWeights.add(this.bcalc(conf.getScore()));
            result.scores.add(conf.getScore());
        }
        result.stopwatch.stop();
        return result;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void compute(int maxNumConfs) {
        if (this.status == null) {
            throw new IllegalStateException("pfunc was not initialized. Call init() before compute()");
        }
        if (!this.status.canContinue()) {
            return;
        }
        if (this.surf != null) {
            this.trace = new PfuncSurface.Trace(this.surf);
        }
        boolean keepStepping = true;
        int numConfsEnergied = 0;
        while (numConfsEnergied < maxNumConfs) {
            Step step = Step.None;
            int numScores = 0;
            NewGradientDescentPfunc newGradientDescentPfunc = this;
            synchronized (newGradientDescentPfunc) {
                boolean energySteeperThanScore;
                boolean bl = keepStepping = keepStepping && !this.state.epsilonReached(this.targetEpsilon) && this.state.isStable(this.stabilityThreshold) && this.state.hasLowEnergies();
                if (!keepStepping) {
                    break;
                }
                if (Double.isNaN(this.state.dEnergy) || Double.isNaN(this.state.dScore)) {
                    throw new Error("Can't determine gradient of delta surface. This is a bug.");
                }
                boolean scoreAheadOfEnergy = this.numEnergyConfsEnumerated < this.numScoreConfsEnumerated;
                boolean bl2 = energySteeperThanScore = this.state.dEnergy <= this.state.dScore;
                if (this.hasEnergyConfs && (scoreAheadOfEnergy && energySteeperThanScore || !this.hasScoreConfs)) {
                    step = Step.Energy;
                } else if (this.hasScoreConfs) {
                    step = Step.Score;
                    double scoringSeconds = Math.max(0.1 / this.state.energyOps, 0.01);
                    numScores = Math.max((int)(scoringSeconds * this.state.scoreOps), 10);
                }
            }
            switch (step) {
                case Energy: {
                    ConfSearch.ScoredConf conf = this.energyConfs.nextConf();
                    if (conf != null) {
                        ++this.numEnergyConfsEnumerated;
                    }
                    if (conf == null || conf.getScore() == Double.POSITIVE_INFINITY) {
                        this.hasEnergyConfs = false;
                        keepStepping = false;
                        break;
                    }
                    ++numConfsEnergied;
                    this.te.submit(() -> this.computeEnergyResult(conf), result -> this.onEnergy(result.econf, result.scoreWeight, result.energyWeight, result.stopwatch.getTimeS()));
                    break;
                }
                case Score: {
                    ArrayList<ConfSearch.ScoredConf> confs = new ArrayList<ConfSearch.ScoredConf>();
                    for (int i = 0; i < numScores; ++i) {
                        ConfSearch.ScoredConf conf = this.scoreConfs.nextConf();
                        if (conf != null) {
                            ++this.numScoreConfsEnumerated;
                        }
                        if (conf == null || conf.getScore() == Double.POSITIVE_INFINITY) {
                            this.hasScoreConfs = false;
                            break;
                        }
                        confs.add(conf);
                    }
                    if (confs.isEmpty()) break;
                    this.te.submit(() -> this.computeScoreConfs(confs), result -> this.onScores(result.scoreWeights, result.stopwatch.getTimeS()));
                    break;
                }
                case None: {
                    keepStepping = false;
                }
            }
        }
        this.te.waitForFinish();
        this.values.qstar = this.state.getLowerBound();
        this.values.qprime = NewGradientDescentPfunc.bigMath().set(this.state.getUpperBound()).sub(this.state.getLowerBound()).get();
        if (!this.state.hasLowEnergies()) {
            this.status = PartitionFunction.Status.OutOfLowEnergies;
        }
        if (!this.hasEnergyConfs) {
            this.status = PartitionFunction.Status.OutOfConformations;
        }
        if (this.state.epsilonReached(this.targetEpsilon)) {
            this.status = PartitionFunction.Status.Estimated;
            if (this.isReportingProgress) {
                Log.log("Total Z upper bound reduction through minimizations: %12.6e", this.state.cumulativeZReduction);
                Log.log("Average Z upper bound reduction per minimizations: %12.6e", NewGradientDescentPfunc.bigMath().set(this.state.cumulativeZReduction).div(this.state.numEnergiedConfs).get());
            }
        }
        if (!this.state.isStable(this.stabilityThreshold)) {
            this.status = PartitionFunction.Status.Unstable;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void onEnergy(ConfSearch.EnergiedConf econf, BigDecimal scoreWeight, BigDecimal energyWeight, double seconds) {
        NewGradientDescentPfunc newGradientDescentPfunc = this;
        synchronized (newGradientDescentPfunc) {
            this.state.energyWeightSum = NewGradientDescentPfunc.bigMath().set(this.state.energyWeightSum).add(energyWeight).get();
            this.state.lowerScoreWeightSum = NewGradientDescentPfunc.bigMath().set(this.state.lowerScoreWeightSum).add(scoreWeight).get();
            ++this.state.numEnergiedConfs;
            this.state.energyOps = 1.0 / seconds;
            if (MathTools.isLessThan(scoreWeight, this.state.minLowerScoreWeight)) {
                this.state.minLowerScoreWeight = scoreWeight;
            }
            double delta = this.state.calcDelta();
            this.state.dEnergy = NewGradientDescentPfunc.calcSlope(delta, this.state.prevDelta, this.state.dScore);
            this.state.prevDelta = delta;
            this.state.cumulativeZReduction = NewGradientDescentPfunc.bigMath().set(this.state.cumulativeZReduction).add(scoreWeight).sub(energyWeight).get();
            int minimizationSize = econf.getAssignments().length;
            if (this.state.minList.size() < minimizationSize) {
                this.state.minList.addAll(new ArrayList<Integer>(Collections.nCopies(minimizationSize - this.state.minList.size(), 0)));
            }
            if (minimizationSize > 0) {
                this.state.minList.set(minimizationSize - 1, this.state.minList.get(minimizationSize - 1) + 1);
            }
            this.state.dScore *= 2.0;
            if (this.isReportingProgress) {
                Log.log("[%s] scores:%8d, confs:%4d, score:%12.6f, energy:%12.6f, bounds:[%12f,%12f] (log10p1), delta:%.6f, time:%10s, heapMem:%s", SimpleConfSpace.formatConfRCs(econf), this.state.numScoredConfs, this.state.numEnergiedConfs, econf.getScore(), econf.getEnergy(), MathTools.log10p1(this.state.getLowerBound()), MathTools.log10p1(this.state.getUpperBound()), this.state.calcDelta(), this.stopwatch.getTime(2), JvmMem.getOldPool());
                this.state.lastReportNs = System.nanoTime();
            }
            if (this.trace != null) {
                this.trace.step(this.state.numScoredConfs, this.state.numEnergiedConfs, this.state.calcDelta());
            }
        }
        if (this.confListener != null) {
            this.confListener.onConf(econf);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void onScores(List<BigDecimal> scoreWeights, double seconds) {
        NewGradientDescentPfunc newGradientDescentPfunc = this;
        synchronized (newGradientDescentPfunc) {
            long nowNs;
            if (this.state.numScoredConfs == 0L) {
                this.state.firstScoreWeight = scoreWeights.get(0);
            }
            for (BigDecimal weight : scoreWeights) {
                this.state.upperScoreWeightSum = NewGradientDescentPfunc.bigMath().set(this.state.upperScoreWeightSum).add(weight).get();
                if (!MathTools.isLessThan(weight, this.state.minUpperScoreWeight)) continue;
                this.state.minUpperScoreWeight = weight;
            }
            this.state.numScoredConfs += (long)scoreWeights.size();
            this.state.scoreOps = (double)scoreWeights.size() / seconds;
            double delta = this.state.calcDelta();
            this.state.dScore = NewGradientDescentPfunc.calcSlope(delta, this.state.prevDelta, this.state.dEnergy);
            this.state.prevDelta = delta;
            this.state.dEnergy *= 2.0;
            if (this.isReportingProgress && (nowNs = System.nanoTime()) - this.state.lastReportNs > 1000000000L) {
                Log.log("[%s] scores:%8d, confs:%4d, score:%12s, energy:%12s, bounds:[%12f,%12f] (log10p1), delta:%.6f, time:%10s, heapMem:%s, energyOps:%.6f, scoreOps:%.6f", String.format("%" + (this.ecalc.confSpace().numPos() * 6 - 1) + "s", ""), this.state.numScoredConfs, this.state.numEnergiedConfs, "", "", MathTools.log10p1(this.state.getLowerBound()), MathTools.log10p1(this.state.getUpperBound()), this.state.calcDelta(), this.stopwatch.getTime(2), JvmMem.getOldPool(), this.state.energyOps, this.state.scoreOps);
                this.state.lastReportNs = nowNs;
            }
            if (this.trace != null) {
                this.trace.step(this.state.numScoredConfs, this.state.numEnergiedConfs, this.state.calcDelta());
            }
        }
    }

    private static double calcSlope(double delta, double prevDelta, double otherSlope) {
        double slope = delta - prevDelta;
        if (slope >= 0.0) {
            slope = otherSlope / 10.0;
        }
        return slope;
    }

    @Override
    public PartitionFunction.Result makeResult() {
        BigDecimal startLowerBound = BigDecimal.ZERO;
        BigDecimal startUpperBound = this.state.numConfs.multiply(this.state.firstScoreWeight);
        BigDecimal lowerFullMin = this.state.getLowerBound();
        BigDecimal lowerConfUpperBound = BigDecimal.ZERO;
        BigDecimal upperFullMin = this.state.cumulativeZReduction;
        BigDecimal upperPartialMin = BigDecimal.ZERO;
        BigDecimal finalUpperBoundNoEnergies = this.state.getUpperBoundNoE();
        BigDecimal upperConfLowerBound = startUpperBound.subtract(finalUpperBoundNoEnergies);
        return new PartitionFunction.Result(this.getStatus(), this.getValues(), this.getNumConfsEvaluated());
    }

    private static class State {
        BigDecimal numConfs;
        long numScoredConfs = 0L;
        BigDecimal upperScoreWeightSum = BigDecimal.ZERO;
        BigDecimal minUpperScoreWeight = MathTools.BigPositiveInfinity;
        long numEnergiedConfs = 0L;
        BigDecimal lowerScoreWeightSum = BigDecimal.ZERO;
        BigDecimal energyWeightSum = BigDecimal.ZERO;
        BigDecimal minLowerScoreWeight = MathTools.BigPositiveInfinity;
        BigDecimal cumulativeZReduction = BigDecimal.ZERO;
        ArrayList<Integer> minList = new ArrayList();
        BigDecimal firstScoreWeight = BigDecimal.ZERO;
        double scoreOps = 100.0;
        double energyOps = 1.0;
        double prevDelta = 1.0;
        double dEnergy = -1.0;
        double dScore = -1.0;
        long lastReportNs = 0L;

        State(BigInteger numConfs) {
            this.numConfs = new BigDecimal(numConfs);
        }

        double calcDelta() {
            BigDecimal upperBound = this.getUpperBound();
            if (MathTools.isZero(upperBound) || MathTools.isInf(upperBound)) {
                return 1.0;
            }
            return NewGradientDescentPfunc.bigMath().set(upperBound).sub(this.getLowerBound()).div(upperBound).get().doubleValue();
        }

        public BigDecimal getLowerBound() {
            return this.energyWeightSum;
        }

        public void printBoundStats() {
            System.out.println("Num confs: " + String.format("%12e", this.numConfs));
            System.out.println("Num Scored confs: " + String.format("%4d", this.numScoredConfs));
            String upperScoreString = this.minUpperScoreWeight.toString();
            String upperSumString = this.upperScoreWeightSum.toString();
            if (!MathTools.isInf(this.minUpperScoreWeight)) {
                upperScoreString = String.format("%12e", this.minUpperScoreWeight);
            }
            if (!MathTools.isInf(this.upperScoreWeightSum)) {
                upperSumString = String.format("%12e", this.upperScoreWeightSum);
            }
            System.out.println("Conf bound: " + upperScoreString);
            System.out.println("Scored weight bound:" + upperSumString);
        }

        public BigDecimal getUpperBound() {
            return NewGradientDescentPfunc.bigMath().set(this.numConfs).sub(this.numScoredConfs).mult(this.minUpperScoreWeight).add(this.upperScoreWeightSum).sub(this.lowerScoreWeightSum).add(this.energyWeightSum).get();
        }

        public BigDecimal getUpperBoundNoE() {
            return NewGradientDescentPfunc.bigMath().set(this.numConfs).sub(this.numScoredConfs).mult(this.minUpperScoreWeight).add(this.upperScoreWeightSum).get();
        }

        boolean epsilonReached(double targetEpsilon) {
            return this.calcDelta() <= targetEpsilon;
        }

        boolean isStable(BigDecimal stabilityThreshold) {
            return this.numEnergiedConfs <= 0L || stabilityThreshold == null || MathTools.isGreaterThanOrEqual(this.getUpperBound(), stabilityThreshold);
        }

        boolean hasLowEnergies() {
            return MathTools.isGreaterThan(this.minLowerScoreWeight, BigDecimal.ZERO);
        }

        public String toString() {
            return String.format("upper: count %d  sum %s  min %s     lower: count %d  score sum %s  energy sum %s", this.numScoredConfs, Log.formatBig(this.upperScoreWeightSum), Log.formatBig(this.minUpperScoreWeight), this.numEnergiedConfs, Log.formatBig(this.lowerScoreWeightSum), Log.formatBig(this.energyWeightSum));
        }
    }

    static class EnergyResult {
        public ConfSearch.EnergiedConf econf;
        public BigDecimal scoreWeight;
        public BigDecimal energyWeight;
        public Stopwatch stopwatch = new Stopwatch();

        EnergyResult() {
        }
    }

    static class ScoreResult {
        public List<Double> scores = new ArrayList<Double>();
        public List<BigDecimal> scoreWeights = new ArrayList<BigDecimal>();
        public Stopwatch stopwatch = new Stopwatch();

        ScoreResult() {
        }
    }

    private static enum Step {
        None,
        Score,
        Energy;

    }
}

