/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.ematrix.epic;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.jet.math.Functions;
import edu.duke.cs.osprey.minimization.ObjectiveFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import java.util.Random;

/**
 * Markov-chain sampler for points {@code x} inside the box {@code [DOFmin, DOFmax]} whose
 * objective value {@code of.getValue(x)} is at most {@code thresh}.
 *
 * <p>Typical use: call {@link #burnIn(DoubleMatrix1D)} once with a valid starting point, then
 * call {@link #nextSample()} repeatedly. Candidates are axis-aligned Gaussian steps from the
 * current point. When {@link #adaptiveScale} is set, the per-DOF step size is re-derived at each
 * candidate from how far the valid region extends along each axis ({@link #getScaleAdaptive}),
 * and a Hastings correction accounts for the resulting asymmetric proposal. Otherwise the scale
 * estimated at the starting point is tuned once by {@link #tuneScale()} and then kept fixed.
 *
 * <p>Instances hold mutable, unsynchronized chain state (current point and scale).
 */
public class SubThreshSampler {
    /** Inclusive upper bound on the objective for a point to count as valid. */
    double thresh;
    /** Objective evaluated by {@link #checkValidPt}. */
    ObjectiveFunction of;
    /** Per-DOF inclusive lower bounds. */
    DoubleMatrix1D DOFmin;
    /** Per-DOF inclusive upper bounds. */
    DoubleMatrix1D DOFmax;
    /** Number of degrees of freedom (taken from {@code DOFmin.size()}). */
    int numDOFs;
    /** Current per-DOF standard deviation of the Gaussian proposal. */
    DoubleMatrix1D samplingScale;
    /** Number of candidates drawn per {@link #tryScale()} call. */
    static int numCandScaleTuning = 25;
    /** Target [low, high] acceptance fractions used by {@link #tuneScale()}. */
    static double[] acceptRatioTarget = new double[]{0.2, 0.4};
    /** Accepted moves between samples returned by {@link #nextSample()}; chosen during burn-in. */
    int useFrequency = 1;
    /** Current point of the chain. */
    DoubleMatrix1D x;
    /** Source of randomness for proposals and acceptance tests. */
    Random random = new Random();
    /** If true, re-derive the proposal scale at every candidate; otherwise tune a fixed scale. */
    static boolean adaptiveScale = true;

    public SubThreshSampler(double thresh, ObjectiveFunction of, DoubleMatrix1D DOFmin, DoubleMatrix1D DOFmax) {
        this.thresh = thresh;
        this.of = of;
        this.DOFmin = DOFmin;
        this.DOFmax = DOFmax;
        this.numDOFs = DOFmin.size();
    }

    /** Sets {@link #samplingScale} from the adaptive estimate at the current point {@link #x}. */
    void initScale() {
        this.samplingScale = this.getScaleAdaptive(this.x);
    }

    /**
     * Estimates a per-DOF proposal scale at {@code sp}.
     *
     * <p>Along each axis, probes outward in both directions with doubling distances (first probe
     * at 2e-6) until a probe point is invalid. The scale is one sixth of the sum of the two
     * distances at which the probes first failed.
     *
     * @param sp point around which to estimate the scale (not modified)
     * @return a new vector of per-DOF scales
     */
    DoubleMatrix1D getScaleAdaptive(DoubleMatrix1D sp) {
        DoubleMatrix1D ans = DoubleFactory1D.dense.make(this.numDOFs);
        for (int dim = 0; dim < this.numDOFs; ++dim) {
            DoubleMatrix1D y = sp.copy();
            double upDist = 1.0E-6;
            double downDist = 1.0E-6;
            do {
                upDist *= 2.0;
                y.set(dim, sp.get(dim) + upDist);
            } while (this.checkValidPt(y));
            do {
                downDist *= 2.0;
                y.set(dim, sp.get(dim) - downDist);
            } while (this.checkValidPt(y));
            ans.set(dim, (upDist + downDist) / 6.0);
        }
        return ans;
    }

    /**
     * Runs the chain from {@code startingPoint} until burn-in is judged complete, then chooses
     * {@link #useFrequency} from the autocorrelation of the second half of the burn-in samples.
     *
     * <p>Burn-in is complete once, for every DOF, the means of the first and second halves of
     * the accepted samples differ by at most half the smaller of the two halves' standard
     * deviations. The check starts once at least 12 samples have been collected.
     *
     * @param startingPoint a valid starting point; it becomes the chain's current point
     */
    void burnIn(DoubleMatrix1D startingPoint) {
        System.out.println("Starting burn-in for SubThreshSampler.  x=" + String.valueOf(startingPoint));
        this.x = startingPoint;
        this.initScale();
        if (!adaptiveScale) {
            this.tuneScale();
        }

        // Each new sample enters the "second half". On every odd step the oldest second-half
        // sample migrates to the "first half", so both halves always hold the same count.
        List<DoubleMatrix1D> burnInSamp = new ArrayList<DoubleMatrix1D>();
        DoubleMatrix1D firstHalfSum = DoubleFactory1D.dense.make(this.numDOFs);
        DoubleMatrix1D secondHalfSum = DoubleFactory1D.dense.make(this.numDOFs);
        DoubleMatrix1D firstHalfSumSq = DoubleFactory1D.dense.make(this.numDOFs);
        DoubleMatrix1D secondHalfSumSq = DoubleFactory1D.dense.make(this.numDOFs);
        boolean done = false;
        int nhalf = 0;
        int b = 0;
        while (!done) {
            while (!this.checkCandidate(this.nextCandidate())) {
                // keep proposing until a candidate is accepted
            }
            // Safe to store by reference: accepted points are fresh copies that are never mutated.
            burnInSamp.add(this.x);
            secondHalfSum.assign(this.x, Functions.plus);
            secondHalfSumSq.assign(this.x.copy().assign(Functions.square), Functions.plus);
            if (b % 2 == 1) {
                DoubleMatrix1D moved = burnInSamp.get(b / 2);
                DoubleMatrix1D movedSq = moved.copy().assign(Functions.square);
                firstHalfSum.assign(moved, Functions.plus);
                secondHalfSum.assign(moved, Functions.minus);
                firstHalfSumSq.assign(movedSq, Functions.plus);
                secondHalfSumSq.assign(movedSq, Functions.minus);
                if (b > 10) {
                    nhalf = b / 2 + 1;
                    done = this.halvesAgree(firstHalfSum, firstHalfSumSq, secondHalfSum, secondHalfSumSq, nhalf);
                }
            }
            if (done) {
                System.out.println("Burn-in complete at sample " + b);
            } else if ((b + 1) % 1000000 == 0) {
                System.out.println("Burn-in sample " + b + " done.  x: " + String.valueOf(this.x));
            }
            ++b;
        }
        this.chooseUseFrequency(burnInSamp, secondHalfSum, secondHalfSumSq, nhalf);
    }

    /**
     * Burn-in convergence test. Returns true when, for every DOF, the absolute difference
     * between the two half-means is at most half the smaller half standard deviation.
     */
    private boolean halvesAgree(DoubleMatrix1D firstHalfSum, DoubleMatrix1D firstHalfSumSq,
            DoubleMatrix1D secondHalfSum, DoubleMatrix1D secondHalfSumSq, int nhalf) {
        DoubleMatrix1D meanDiff = secondHalfSum.copy().assign(firstHalfSum, Functions.minus);
        meanDiff.assign(Functions.mult(1.0 / nhalf));
        DoubleMatrix1D std1 = this.getStDVec(firstHalfSum, firstHalfSumSq, nhalf);
        DoubleMatrix1D std2 = this.getStDVec(secondHalfSum, secondHalfSumSq, nhalf);
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            // The drift can go in either direction, so compare its magnitude.
            if (Math.abs(meanDiff.get(dof)) > 0.5 * Math.min(std1.get(dof), std2.get(dof))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets {@link #useFrequency} to the smallest lag at which every DOF has reached a lag
     * autocorrelation/variance ratio below 0.1 (at that lag or any smaller one). The estimate
     * uses the second half of {@code burnInSamp} (indices {@code nhalf..2*nhalf-1}). If no lag
     * qualifies, it falls back to {@code nhalf - 1}.
     */
    private void chooseUseFrequency(List<DoubleMatrix1D> burnInSamp, DoubleMatrix1D secondHalfSum,
            DoubleMatrix1D secondHalfSumSq, int nhalf) {
        DoubleMatrix1D var = this.getStDVec(secondHalfSum, secondHalfSumSq, nhalf).assign(Functions.square);
        DoubleMatrix1D mean = secondHalfSum.copy().assign(Functions.mult(1.0 / nhalf));
        BitSet dimGood = new BitSet();
        // Running minimum, so it must start above any achievable ratio.
        double[] bestAutorat = new double[this.numDOFs];
        Arrays.fill(bestAutorat, Double.POSITIVE_INFINITY);
        for (int uf = 1; uf < nhalf; ++uf) {
            DoubleMatrix1D autocorr = DoubleFactory1D.dense.make(this.numDOFs);
            for (int s = nhalf; s < 2 * nhalf - uf; ++s) {
                DoubleMatrix1D crossTerm = burnInSamp.get(s).copy().assign(mean, Functions.minus);
                DoubleMatrix1D relSamp2 = burnInSamp.get(s + uf).copy().assign(mean, Functions.minus);
                crossTerm.assign(relSamp2, Functions.mult);
                autocorr.assign(crossTerm, Functions.plus);
            }
            autocorr.assign(Functions.mult(1.0 / (nhalf - uf)));
            DoubleMatrix1D autorat = autocorr.copy().assign(var, Functions.div);
            for (int dim = 0; dim < this.numDOFs; ++dim) {
                if (autorat.get(dim) < 0.1) {
                    dimGood.set(dim);
                }
                bestAutorat[dim] = Math.min(bestAutorat[dim], autorat.get(dim));
            }
            if (dimGood.cardinality() == this.numDOFs) {
                this.useFrequency = uf;
                System.out.println("Setting useFrequency=" + this.useFrequency + ".  Variance: " + String.valueOf(var) + " Autocorr: " + String.valueOf(autocorr));
                return;
            }
            if (uf == nhalf - 1) {
                this.useFrequency = uf;
                System.out.println("Warning: high autocorrelation detected at all useFrequencies!  Setting useFrequency=" + uf);
                return;
            }
            if ((uf + 1) % 1000000 == 0) {
                System.out.print("Trying useFrequency " + uf + ".  Best autocorrelation/variance ratios so far: ");
                for (int dof = 0; dof < this.numDOFs; ++dof) {
                    System.out.print(bestAutorat[dof] + " ");
                }
                System.out.println();
            }
        }
    }

    /**
     * Per-DOF population standard deviation, sqrt(E[v^2] - E[v]^2), computed from running sums.
     * Note: rounding can make the variance slightly negative, in which case the entry is NaN.
     */
    DoubleMatrix1D getStDVec(DoubleMatrix1D sum, DoubleMatrix1D sumsq, int nsamp) {
        DoubleMatrix1D ans = sumsq.copy().assign(Functions.mult(1.0 / nsamp));
        ans.assign(sum.copy().assign(Functions.mult(1.0 / nsamp)).assign(Functions.square), Functions.minus);
        ans.assign(Functions.sqrt);
        return ans;
    }

    /**
     * Rescales {@link #samplingScale} in place until the number of accepted candidates per
     * {@link #tryScale()} round is no longer below the low target (shrinking by 0.7), or, if it
     * started above the high target, no longer above it (growing by 1.5).
     */
    void tuneScale() {
        double lowTarget = (double) numCandScaleTuning * acceptRatioTarget[0];
        double highTarget = (double) numCandScaleTuning * acceptRatioTarget[1];
        int numAccepted = this.tryScale();
        if (numAccepted < lowTarget) {
            while (numAccepted < lowTarget) {
                this.samplingScale.assign(Functions.mult(0.7));
                numAccepted = this.tryScale();
            }
        } else {
            while (numAccepted > highTarget) {
                this.samplingScale.assign(Functions.mult(1.5));
                numAccepted = this.tryScale();
            }
        }
        System.out.println("Tuned scale for SubThreshSampler: " + String.valueOf(this.samplingScale));
    }

    /**
     * Proposes {@link #numCandScaleTuning} candidates (each acceptance moves the chain).
     *
     * @return how many were accepted
     */
    int tryScale() {
        int numAccepted = 0;
        for (int c = 0; c < numCandScaleTuning; ++c) {
            if (this.checkCandidate(this.nextCandidate())) {
                ++numAccepted;
            }
        }
        return numAccepted;
    }

    /**
     * Advances the chain by {@link #useFrequency} accepted moves.
     *
     * @return a copy of the resulting current point
     */
    DoubleMatrix1D nextSample() {
        for (int u = 0; u < this.useFrequency; ++u) {
            while (!this.checkCandidate(this.nextCandidate())) {
                // keep proposing until a candidate is accepted
            }
        }
        return this.x.copy();
    }

    /**
     * Accepts or rejects candidate {@code y}. An invalid point is always rejected. With
     * {@link #adaptiveScale}, a valid point is accepted with probability
     * min(1, q(y->x) / q(x->y)), using the normalized Gaussian proposal densities at each point's
     * own scale. On acceptance, {@link #x} (and, if adaptive, {@link #samplingScale}) move to y.
     *
     * @return true if the candidate was accepted
     */
    boolean checkCandidate(DoubleMatrix1D y) {
        if (!this.checkValidPt(y)) {
            return false;
        }
        if (adaptiveScale) {
            DoubleMatrix1D ySamplingScale = this.getScaleAdaptive(y);
            // The scales at x and y differ, so the Gaussian normalization must be included.
            // Working in log space also avoids 0/0 when both densities underflow.
            double logAcceptRatio = this.logProposalDensity(y, this.x, ySamplingScale)
                    - this.logProposalDensity(this.x, y, this.samplingScale);
            if (!(Math.log(this.random.nextDouble()) < logAcceptRatio)) {
                return false;
            }
            this.samplingScale = ySamplingScale;
        }
        this.x = y;
        return true;
    }

    /**
     * Log density (up to a constant that does not depend on scale) of proposing {@code to} from
     * {@code from} with independent Gaussian steps of standard deviation {@code scale}.
     */
    private double logProposalDensity(DoubleMatrix1D from, DoubleMatrix1D to, DoubleMatrix1D scale) {
        double logProb = 0.0;
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            double s = scale.get(dof);
            double v = (to.get(dof) - from.get(dof)) / s;
            logProb -= 0.5 * v * v + Math.log(s);
        }
        return logProb;
    }

    /**
     * Unnormalized Gaussian kernel exp(-|(pt2 - pt1)/scale|^2 / 2). It omits the 1/scale
     * normalization, so ratios of Q values are proposal-density ratios only when the scales
     * are equal.
     */
    double Q(DoubleMatrix1D pt1, DoubleMatrix1D pt2, DoubleMatrix1D scale) {
        DoubleMatrix1D diff = pt2.copy().assign(pt1, Functions.minus);
        double prob = 1.0;
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            double v = diff.get(dof) / scale.get(dof);
            prob *= Math.exp(-v * v / 2.0);
        }
        return prob;
    }

    /**
     * Returns true iff every coordinate of {@code y} lies within the inclusive
     * {@code [DOFmin, DOFmax]} bounds and the objective at {@code y} is at most {@link #thresh}.
     * NaN coordinates or a NaN objective value make the point invalid.
     */
    boolean checkValidPt(DoubleMatrix1D y) {
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            double v = y.get(dof);
            if (!(v >= this.DOFmin.get(dof) && v <= this.DOFmax.get(dof))) {
                return false;
            }
        }
        return this.of.getValue(y) <= this.thresh;
    }

    /** Returns a new point that adds an independent Gaussian step, scaled per DOF, to {@link #x}. */
    DoubleMatrix1D nextCandidate() {
        DoubleMatrix1D y = this.x.copy();
        for (int dof = 0; dof < this.numDOFs; ++dof) {
            y.set(dof, this.x.get(dof) + this.random.nextGaussian() * this.samplingScale.get(dof));
        }
        return y;
    }
}

