/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.voxq;

import cern.colt.matrix.DoubleMatrix1D;
import edu.duke.cs.osprey.minimization.MoleculeModifierAndScorer;
import org.apache.commons.math3.special.Erf;

/**
 * One-dimensional sampling prior ("Q function") for a single continuous degree of freedom (DOF)
 * of a {@link MoleculeModifierAndScorer}, supported on the DOF's bounds [xLo, xHi].
 *
 * <p>The density is {@code Q(x) = exp(a*x^2 + b*x + c)}. {@code a} and {@code b} come from a
 * quadratic (or linear) fit of the energy at up to three points, divided by {@code -RT}; {@code c}
 * is then chosen so that Q integrates to 1 over [xLo, xHi]. When {@code a} is non-negligible the
 * erf-based formulas below take {@code sqrt(-a)}, i.e. they assume {@code a < 0} (a truncated
 * Gaussian). When {@code a} is negligible the density is a truncated exponential, treated as
 * uniform when {@code b} is negligible too.
 */
public class QuadraticQFunction {

    /**
     * Divisor converting fitted energies into log-density units. Numerically 1.9891e-3 * 298.15,
     * presumably RT in kcal/mol at 298.15 K -- confirm against the energy units of
     * MoleculeModifierAndScorer.
     */
    private static final double RT = 0.593050165;

    /** |a| below this is treated as zero, i.e. the density is log-linear. */
    private static final double A_ZERO_TOL = 1.0E-14;

    /** With a treated as zero, |b| below this is treated as a (near-)uniform density. */
    private static final double B_UNIFORM_TOL = 1.0E-7;

    /** Maximum allowed deviation of the normalized total probability from 1. */
    private static final double NORM_TOL = 1.0E-5;

    /** Quadratic coefficient of ln Q. */
    double a;
    /** Linear coefficient of ln Q. */
    double b;
    /** Constant term of ln Q; set by normalization so that Q integrates to 1 over [xLo, xHi]. */
    double c;
    /** Lower bound of the DOF, read from the scorer's constraints. */
    double xLo;
    /** Upper bound of the DOF, read from the scorer's constraints. */
    double xHi;
    /** Never read or written in this class; kept in case package-level code refers to it. */
    boolean useLinearPrior;

    /**
     * Fits the prior for one DOF.
     *
     * @param mms     scorer supplying the DOF bounds ({@code getConstraints()}) and energies
     *                ({@code getValForDOF})
     * @param dof     index of the DOF to build a prior for
     * @param origVal current value of the DOF; used as the middle fit point, clamped to lie
     *                between the outer fit points
     * @throws RuntimeException if the DOF's range is (numerically) empty, if the linear prior
     *                          cannot be normalized, or if a fitted coefficient is infinite or NaN
     */
    public QuadraticQFunction(MoleculeModifierAndScorer mms, int dof, double origVal) {
        DoubleMatrix1D[] constr = mms.getConstraints();
        this.xLo = constr[0].get(dof);
        this.xHi = constr[1].get(dof);
        if (this.xHi < this.xLo + 1.0E-14) {
            throw new RuntimeException("ERROR: Trying to sample a rigid DOF!");
        }

        double origValE = mms.getValForDOF(dof, origVal);
        double xLoE = mms.getValForDOF(dof, this.xLo);
        double xHiE = mms.getValForDOF(dof, this.xHi);

        // Outer fit points: the bounds, pulled toward origVal if their energy is too high.
        double x1 = chooseReasonableOuterPt(origVal, origValE, this.xLo, xLoE, mms, dof);
        double x3 = chooseReasonableOuterPt(origVal, origValE, this.xHi, xHiE, mms, dof);
        // Middle fit point: origVal clamped between 0.9*x1 + 0.1*x3 and 0.1*x1 + 0.9*x3.
        double x2 = Math.min(Math.max(origVal, 0.9 * x1 + 0.1 * x3), 0.1 * x1 + 0.9 * x3);

        // Evaluation order (E1, E3, E2) kept as in the original in case getValForDOF has side effects.
        double E1 = getEnergyIfNeeded(x1, mms, dof, this.xLo, xLoE);
        double E3 = getEnergyIfNeeded(x3, mms, dof, this.xHi, xHiE);
        double E2 = getEnergyIfNeeded(x2, mms, dof, origVal, origValE);

        // A quadratic is fitted only if (x2, E2) lies strictly below the chord from (x1, E1) to
        // (x3, E3), i.e. the parabola through the three points opens upward in energy.
        double chordSlope = (E3 - E1) / (x3 - x1);
        if (E2 - E1 < chordSlope * (x2 - x1)) {
            boolean normalized = setupQuadratic(x1, x2, x3, E1, E2, E3);
            // Fall back if normalization failed or erfc inversion (used for sampling) is inaccurate.
            if (!normalized || (this.a != 0.0 && !erfcInvNumericsOK())) {
                setupLinear(x1, x3, E1, E3);
            }
        } else {
            setupLinear(x1, x3, E1, E3);
        }

        for (double coeff : new double[]{this.a, this.b, this.c}) {
            if (Double.isInfinite(coeff) || Double.isNaN(coeff)) {
                throw new RuntimeException("ERROR: Infinite or NaN coefficient in sampling prior");
            }
        }
    }

    /**
     * Returns the energy at {@code x}, reusing {@code eDone} when {@code x} is exactly the point
     * {@code xDone} whose energy is already known.
     */
    private static double getEnergyIfNeeded(double x, MoleculeModifierAndScorer mms, int dof,
            double xDone, double eDone) {
        return x == xDone ? eDone : mms.getValForDOF(dof, x);
    }

    /**
     * Picks an outer fit point between {@code origVal} and the bound {@code outerPt}.
     *
     * <p>If the bound's energy is at most 20 above {@code origValE} (or is NaN), the bound itself is
     * returned. Otherwise the point is repeatedly moved halfway toward {@code origVal} until its
     * energy is within 20 of {@code origValE}. It is then stepped back outward, toward the last
     * rejected point, until its energy exceeds {@code origValE + 10}.
     * NOTE(review): termination of both loops relies on the energy varying continuously along the DOF.
     */
    private static double chooseReasonableOuterPt(double origVal, double origValE, double outerPt,
            double outerPtE, MoleculeModifierAndScorer mms, int dof) {
        final double top = 20.0;
        final double bottom = 10.0;
        if (!(outerPtE > origValE + top)) {
            return outerPt;
        }
        double pt = outerPt;
        double ptE;
        do {
            pt = 0.5 * (origVal + pt);
            ptE = mms.getValForDOF(dof, pt);
        } while (ptE > origValE + top);
        // Undo the final halving to recover the last point whose energy was too high.
        double prevPt = 2.0 * pt - origVal;
        while (ptE <= origValE + bottom) {
            pt = 0.8 * pt + 0.2 * prevPt;
            ptE = mms.getValForDOF(dof, pt);
        }
        return pt;
    }

    /**
     * Sets a, b, c to the parabola through (x1,E1), (x2,E2), (x3,E3), converted to log-density
     * units (divided by -RT), then normalizes.
     *
     * @return {@code true} if normalization succeeded; {@code false} if the caller should fall back
     *         to a linear prior
     */
    boolean setupQuadratic(double x1, double x2, double x3, double E1, double E2, double E3) {
        // Second divided difference gives the quadratic coefficient.
        this.a = ((E2 - E1) / (x2 - x1) - (E3 - E1) / (x3 - x1)) / (x2 - x3);
        this.b = (E3 - E1 - this.a * (x3 * x3 - x1 * x1)) / (x3 - x1);
        this.c = E3 - x3 * (this.b + this.a * x3);
        this.a /= -RT;
        this.b /= -RT;
        this.c /= -RT;
        return this.normalizeDistr();
    }

    /**
     * Sets a = 0 and b, c to the line through (x1,E1) and (x3,E3), converted to log-density units
     * (divided by -RT), then normalizes.
     *
     * @throws RuntimeException if the resulting distribution cannot be normalized
     */
    void setupLinear(double x1, double x3, double E1, double E3) {
        this.a = 0.0;
        double slope = (E3 - E1) / (x3 - x1);
        this.b = -slope / RT;
        this.c = -(E3 - x3 * slope) / RT;
        this.normalizeDistr();
    }

    /** Draws a DOF value from Q by inverse-CDF sampling of a {@link Math#random()} variate. */
    double drawDOFValue() {
        return this.cumulDistrInv(Math.random());
    }

    /** Returns the density {@code exp(a*x^2 + b*x + c)} at {@code x} (not clipped to [xLo, xHi]). */
    double evalQ(double x) {
        return Math.exp(this.c + x * (this.b + this.a * x));
    }

    /** Returns the integral of Q from xLo to {@code x} under the current coefficients. */
    private double cumulDistr(double x) {
        if (Math.abs(this.a) < A_ZERO_TOL) {
            if (Math.abs(this.b) < B_UNIFORM_TOL) {
                // exp(b*xLo + c) * expm1(b*dx) / b avoids cancellation for tiny b and has the
                // exact limit exp(c) * dx at b == 0, where the general formula below is 0/0.
                double base = Math.exp(this.b * this.xLo + this.c);
                double dx = x - this.xLo;
                return this.b == 0.0 ? base * dx : base * Math.expm1(this.b * dx) / this.b;
            }
            return (Math.exp(this.b * x + this.c) - Math.exp(this.b * this.xLo + this.c)) / this.b;
        }
        double C = 0.5 * Math.exp(this.c - 0.25 * this.b * this.b / this.a) * Math.sqrt(-Math.PI / this.a);
        // Erf.erf(lo, hi) is erf(hi) - erf(lo).
        return C * Erf.erf(this.cdErfArg(this.xLo), this.cdErfArg(x));
    }

    /**
     * Chooses c so that Q integrates to 1 over [xLo, xHi], then verifies the result.
     *
     * @return {@code true} on success; {@code false} if a non-linear (a != 0) fit could not be
     *         normalized
     * @throws RuntimeException if a linear (a ~ 0) fit could not be normalized
     */
    private boolean normalizeDistr() {
        if (Math.abs(this.a) < A_ZERO_TOL) {
            if (Math.abs(this.b) < B_UNIFORM_TOL) {
                // Near-uniform: approximate exp(b*x) by its value at the interval midpoint.
                this.c = -this.b * (this.xHi + this.xLo) / 2.0 - Math.log(this.xHi - this.xLo);
            } else if (this.b < 0.0) {
                // Both branches exponentiate a non-positive argument, so neither can overflow.
                double scale = this.b / (Math.exp(this.b * (this.xHi - this.xLo)) - 1.0);
                this.c = Math.log(scale) - this.b * this.xLo;
            } else {
                double scale = this.b / (1.0 - Math.exp(this.b * (this.xLo - this.xHi)));
                this.c = Math.log(scale) - this.b * this.xHi;
            }
        } else {
            double erfDiff = Erf.erf(this.cdErfArg(this.xLo), this.cdErfArg(this.xHi));
            if (erfDiff != 0.0) {
                double exponent = -Math.log(0.5 * erfDiff * Math.sqrt(-Math.PI / this.a));
                this.c = exponent + 0.25 * this.b * this.b / this.a;
            }
            // If erfDiff == 0, c is left unchanged and the check below reports failure.
        }
        double newNorm = this.cumulDistr(this.xHi);
        if (Double.isNaN(newNorm) || Math.abs(newNorm - 1.0) > NORM_TOL) {
            if (Math.abs(this.a) >= A_ZERO_TOL) {
                // A failed quadratic fit is reported to the caller instead of thrown.
                return false;
            }
            throw new RuntimeException("ERROR: Unsuccessful normalization.  a: " + this.a + " b: " + this.b + " c: " + this.c + " newNorm (should be 1): " + newNorm);
        }
        return true;
    }

    /**
     * Maps {@code x} to the erf argument {@code (-b - 2*a*x) / (2*sqrt(-a))}, which increases with x
     * when a &lt; 0. Yields NaN unless a &lt; 0.
     */
    private double cdErfArg(double x) {
        double denom = 2.0 * Math.sqrt(-this.a);
        double num = -this.b - 2.0 * this.a * x;
        return num / denom;
    }

    /**
     * Checks that erfc followed by {@link #myErfcInv} reproduces the erf arguments of both bounds
     * to within 1% of the argument range spanned by [xLo, xHi].
     */
    private boolean erfcInvNumericsOK() {
        double erfArgRange = Math.abs(this.cdErfArg(this.xHi) - this.cdErfArg(this.xLo));
        for (double bound : new double[]{this.xLo, this.xHi}) {
            if (!this.erfcInvNumericsOK(this.cdErfArg(bound), erfArgRange)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Round-trips |x| through erfc and {@link #myErfcInv}.
     *
     * @return {@code true} if the result is finite and within {@code 0.01 * refRange} of |x};
     *         a NaN result or NaN range counts as failure
     */
    private boolean erfcInvNumericsOK(double x, double refRange) {
        double absArg = Math.abs(x);
        double invResult = myErfcInv(Erf.erfc(absArg));
        // Comparisons with NaN are false, so NaN round trips are rejected here.
        return !Double.isInfinite(invResult) && Math.abs(absArg - invResult) <= 0.01 * refRange;
    }

    /**
     * Inverse CDF: returns x in (approximately) [xLo, xHi] with {@code cumulDistr(x) == F}.
     * Prints a warning if the result falls outside the bounds by more than a small tolerance.
     */
    private double cumulDistrInv(double F) {
        if (Math.abs(this.a) < A_ZERO_TOL) {
            double ans;
            if (Math.abs(this.b) < B_UNIFORM_TOL) {
                // Inverse of the expm1 form used by cumulDistr; avoids 0/0 when b == 0.
                double scaledF = F / Math.exp(this.b * this.xLo + this.c);
                ans = this.b == 0.0
                        ? this.xLo + scaledF
                        : this.xLo + Math.log1p(this.b * scaledF) / this.b;
            } else {
                ans = (Math.log(this.b * F + Math.exp(this.b * this.xLo + this.c)) - this.c) / this.b;
            }
            if (ans < this.xLo - 1.0E-6 || ans > this.xHi + 1.0E-6) {
                System.out.println("Out of range QuadraticQFunction draw...");
            }
            return ans;
        }
        double C = 0.5 * Math.exp(this.c - 0.25 * this.b * this.b / this.a) * Math.sqrt(-Math.PI / this.a);
        double erfArg1 = this.cdErfArg(this.xLo);
        // Solve erf(erfArg2) - erf(erfArg1) = F / C. Away from 0, erf is close to +/-1, so the
        // equation is rewritten in terms of erfc (or erfc of the reflected argument).
        double erfArg2;
        if (erfArg1 > 1.0) {
            double erfcVal = -F / C + Erf.erfc(erfArg1);
            erfArg2 = myErfcInv(erfcVal);
        } else if (erfArg1 < -1.0) {
            double oerfcVal = F / C + Erf.erfc(-erfArg1);
            erfArg2 = -myErfcInv(oerfcVal);
        } else {
            double erfVal = F / C + Erf.erf(erfArg1);
            erfArg2 = Erf.erfInv(erfVal);
        }
        // Invert cdErfArg.
        double denom = 2.0 * Math.sqrt(-this.a);
        double ans = (-erfArg2 * denom - this.b) / (2.0 * this.a);
        if (ans < this.xLo - 1.0 || ans > this.xHi + 1.0) {
            System.out.println("Out of range QuadraticQFunction draw...");
        }
        return ans;
    }

    /**
     * Inverse complementary error function. Uses the identity erfcinv(z) = -erfcinv(2 - z) for
     * z &gt; 1, the library {@code Erf.erfcInv} for z &gt; 1e-15, and a closed-form approximation
     * ({@link #quadApproxErfcInv}) for smaller z.
     */
    private static double myErfcInv(double z) {
        if (z > 1.0) {
            return -myErfcInv(2.0 - z);
        }
        if (z > 1.0E-15) {
            return Erf.erfcInv(z);
        }
        return quadApproxErfcInv(z);
    }

    /**
     * Returns the larger root y of {@code 0.997021118*y^2 + 0.166032837*y + (1.47400792 + ln z) = 0},
     * i.e. inverts the approximation {@code ln erfc(y) ~ -(0.997021118*y^2 + 0.166032837*y + 1.47400792)}.
     * The constants are 0.027566902962268564 ~ 0.166032837^2, 3.988084472 = 4*0.997021118 and
     * 1.994042236 = 2*0.997021118. Returns +Infinity for z == 0.
     *
     * @throws RuntimeException if the discriminant is negative
     */
    private static double quadApproxErfcInv(double z) {
        double discr = 0.027566902962268564 - 3.988084472 * (1.47400792 + Math.log(z));
        if (discr < 0.0) {
            throw new RuntimeException("ERROR: erfc quad approx out of range!");
        }
        return (Math.sqrt(discr) - 0.166032837) / 1.994042236;
    }
}

