/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.energy.approximation;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.QRDecomposition;
import edu.duke.cs.osprey.energy.approximation.ApproximatedObjectiveFunction;
import edu.duke.cs.osprey.minimization.Minimizer;
import edu.duke.cs.osprey.tools.IOable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;

/**
 * Approximates an objective function with a full quadratic polynomial in its degrees of freedom (DOFs).
 *
 * <p>The model is {@code f(x) = c0 + sum_i c_i*x_i + sum_{i >= j} c_ij*x_i*x_j}. All coefficients are
 * stored flat in {@link #coefficients}:
 * <ul>
 * <li>index 0 holds the constant term,</li>
 * <li>indices {@code 1..numDofs} hold the linear terms,</li>
 * <li>the remaining entries hold the lower triangle (diagonal included) of the quadratic terms, row by row.</li>
 * </ul>
 *
 * <p>DOFs are grouped into blocks: block {@code b} has id {@code dofBlockIds.get(b)} and owns
 * {@code dofCounts.get(b)} consecutive DOFs. {@link #add(QuadraticApproximator, QuadraticApproximator, double, double)}
 * uses the block ids to line up DOFs between approximators with different block layouts.
 */
public class QuadraticApproximator implements ApproximatedObjectiveFunction.Approximator.Addable, IOable {

    /** Id of each DOF block, in DOF order. */
    public final List<Integer> dofBlockIds;
    /** Number of DOFs in each block; parallel to {@link #dofBlockIds}. */
    public final List<Integer> dofCounts;
    /** Total number of DOFs over all blocks. */
    public final int numDofs;
    /** Flat polynomial coefficients; see the class comment for the layout. */
    public final DoubleMatrix1D coefficients;

    /** Error estimate: max absolute test residual from the last fit, plus any errors accumulated by add(). */
    private double maxe;
    /** For each DOF, the position (in dofBlockIds) of the block that owns it. */
    private final int[] blockIndicesByDof;
    /** For each block position, the DOF index of that block's first DOF. */
    private final int[] dofOffsetsByBlock;

    /**
     * Creates an approximator whose coefficients (and error) are all zero.
     *
     * @param dofBlockIds id of each DOF block
     * @param dofCounts number of DOFs in each block, parallel to {@code dofBlockIds}
     * @throws IllegalArgumentException if the two lists have different sizes
     */
    public QuadraticApproximator(List<Integer> dofBlockIds, List<Integer> dofCounts) {
        if (dofBlockIds.size() != dofCounts.size()) {
            throw new IllegalArgumentException(String.format(
                    "dofBlockIds (size %d) and dofCounts (size %d) must have the same size",
                    dofBlockIds.size(), dofCounts.size()));
        }
        this.dofBlockIds = dofBlockIds;
        this.dofCounts = dofCounts;
        this.numDofs = dofCounts.stream().mapToInt(i -> i).sum();
        // 1 constant + numDofs linear + numDofs*(numDofs+1)/2 lower-triangular quadratic terms
        this.coefficients = DoubleFactory1D.dense.make(1 + numDofs + numDofs * (numDofs + 1) / 2);
        this.maxe = 0.0;
        this.blockIndicesByDof = new int[numDofs];
        this.dofOffsetsByBlock = new int[dofBlockIds.size()];
        int d = 0;
        for (int b = 0; b < dofBlockIds.size(); b++) {
            dofOffsetsByBlock[b] = d;
            for (int j = 0; j < dofCounts.get(b); j++) {
                blockIndicesByDof[d++] = b;
            }
        }
    }

    @Override
    public int numDofs() {
        return numDofs;
    }

    @Override
    public List<Integer> dofBlockIds() {
        return dofBlockIds;
    }

    @Override
    public List<Integer> dofCounts() {
        return dofCounts;
    }

    /** Returns the number of polynomial coefficients. */
    @Override
    public int numParams() {
        return coefficients.size();
    }

    /**
     * Fits the coefficients to {@code trainingSet} by linear least squares (QR decomposition), then
     * measures the fit on {@code testSet}.
     *
     * @param trainingSet samples used to fit the coefficients
     * @param testSet samples used only to measure the error of the fit
     * @return the maximum absolute residual over {@code testSet} (0 when it is empty); this value is
     *     also what {@link #error()} returns afterwards
     * @throws IllegalArgumentException if any sample has the wrong number of DOFs, or there are fewer
     *     training samples than coefficients
     */
    @Override
    public double train(List<Minimizer.Result> trainingSet, List<Minimizer.Result> testSet) {
        checkDimensions(trainingSet);
        checkDimensions(testSet);
        if (trainingSet.size() < coefficients.size()) {
            // the QR least-squares solve needs at least as many rows (samples) as columns (coefficients)
            throw new IllegalArgumentException(String.format(
                    "need at least %d training samples to fit %d coefficients, got %d",
                    coefficients.size(), coefficients.size(), trainingSet.size()));
        }

        LinearSystem training = new LinearSystem(this, trainingSet);
        LinearSystem test = new LinearSystem(this, testSet);

        coefficients.assign(new QRDecomposition(training.A).solve(training.b).viewColumn(0));

        // residual_i = |(A*c)_i - b_i| over the test samples
        DoubleMatrix1D residual = new Algebra().mult(test.A, coefficients)
                .assign(test.b.viewColumn(0), (ri, bi) -> Math.abs(ri - bi));
        maxe = 0.0;
        for (int i = 0; i < residual.size(); i++) {
            assert (residual.get(i) >= 0.0);
            maxe = Math.max(maxe, residual.get(i));
        }
        return maxe;
    }

    /** Throws IllegalArgumentException if any sample's DOF vector does not have exactly numDofs entries. */
    private void checkDimensions(List<Minimizer.Result> samples) {
        for (Minimizer.Result sample : samples) {
            if (sample.dofValues.size() != numDofs) {
                throw new IllegalArgumentException("samples have wrong number of dimensions");
            }
        }
    }

    /**
     * Makes this approximator the constant function {@code energy}: the constant term becomes
     * {@code energy} and every other coefficient becomes 0. The stored error is left unchanged.
     */
    @Override
    public void train(double energy) {
        coefficients.assign(0.0);
        coefficients.set(0, energy);
    }

    /** Coefficient index of the linear term for DOF {@code d1}. */
    private int index1(int d1) {
        return 1 + d1;
    }

    /**
     * Coefficient index of the quadratic term for the DOF pair {@code (d1, d2)}. The pair is
     * unordered, because only the lower triangle is stored.
     */
    private int index2(int d1, int d2) {
        int hi = Math.max(d1, d2);
        int lo = Math.min(d1, d2);
        return 1 + numDofs + hi * (hi + 1) / 2 + lo;
    }

    /**
     * Evaluates the full quadratic model at {@code x}.
     *
     * @throws IllegalArgumentException if {@code x} does not have exactly {@link #numDofs} entries
     */
    @Override
    public double getValue(DoubleMatrix1D x) {
        if (x.size() != numDofs) {
            throw new IllegalArgumentException(String.format("x is wrong size (%d), expected %d", x.size(), numDofs));
        }
        double v = coefficients.get(0);
        for (int d1 = 0; d1 < numDofs; d1++) {
            // Horner-style: x_d1 * (c_d1 + sum_{d2 <= d1} c_{d1,d2} * x_d2)
            double v1 = coefficients.get(index1(d1));
            for (int d2 = 0; d2 <= d1; d2++) {
                v1 += coefficients.get(index2(d1, d2)) * x.get(d2);
            }
            v += v1 * x.get(d1);
        }
        return v;
    }

    /**
     * Evaluates the constant term plus every term that involves DOF {@code d1}, with DOF {@code d1}
     * set to {@code val} and every other DOF taken from {@code x}. Terms that do not involve
     * {@code d1} (other than the constant) are omitted.
     *
     * @param d1 the DOF being varied
     * @param val value to use for DOF {@code d1}; {@code x.get(d1)} is not read
     * @param x values for the other DOFs
     */
    @Override
    public double getValForDOF(int d1, double val, DoubleMatrix1D x) {
        double v = coefficients.get(index1(d1));
        for (int d2 = 0; d2 < numDofs; d2++) {
            // the diagonal term (d2 == d1) must use the requested value too
            double x2 = (d2 == d1) ? val : x.get(d2);
            v += coefficients.get(index2(d1, d2)) * x2;
        }
        return coefficients.get(0) + v * val;
    }

    /** Returns the current error estimate (see {@link #train(List, List)} and {@link #add}). */
    @Override
    public double error() {
        return maxe;
    }

    /** Returns a new all-zero approximator with the given block layout. */
    @Override
    public QuadraticApproximator makeIdentity(List<Integer> dofBlockIds, List<Integer> dofCounts) {
        return new QuadraticApproximator(dofBlockIds, dofCounts);
    }

    /**
     * Adds {@code src} into this approximator; see
     * {@link #add(QuadraticApproximator, QuadraticApproximator, double, double)}.
     *
     * @throws IllegalArgumentException if {@code src} is not a QuadraticApproximator
     */
    @Override
    public void add(ApproximatedObjectiveFunction.Approximator.Addable src, double weight, double offset) {
        if (!(src instanceof QuadraticApproximator)) {
            throw new IllegalArgumentException("can't add different approximator types together:\n\t" + this.getClass().getName() + "\n\t" + src.getClass().getName());
        }
        add((QuadraticApproximator) src, this, weight, offset);
    }

    /**
     * Accumulates {@code weight * (src + offset)} into {@code dst}, mapping each of {@code src}'s DOFs
     * onto {@code dst}'s DOFs by block id.
     *
     * <p>The constant term gets {@code (src.c0 + offset) * weight}; every linear and quadratic
     * coefficient gets {@code src.c * weight}. {@code dst}'s error grows by {@code src}'s error.
     *
     * @throws IllegalArgumentException if {@code dst} lacks one of {@code src}'s blocks, or a shared
     *     block has a different DOF count in the two approximators
     */
    public static void add(QuadraticApproximator src, QuadraticApproximator dst, double weight, double offset) {

        // map each src block position to the dst block position with the same id
        int[] dstBlockIndices = new int[src.dofBlockIds.size()];
        for (int srcb = 0; srcb < src.dofBlockIds.size(); srcb++) {
            int blockId = src.dofBlockIds.get(srcb);
            int dstb = dst.dofBlockIds.indexOf(blockId);
            if (dstb < 0) {
                throw new IllegalArgumentException("destination approximator doesn't have dof block " + blockId);
            }
            if (!dst.dofCounts.get(dstb).equals(src.dofCounts.get(srcb))) {
                throw new IllegalArgumentException("block " + blockId + " has different sizes in different approximators");
            }
            dstBlockIndices[srcb] = dstb;
        }

        // map each src DOF to the dst DOF at the same offset within the matching block
        int[] dstDofs = new int[src.numDofs];
        for (int srcd = 0; srcd < src.numDofs; srcd++) {
            int srcb = src.blockIndicesByDof[srcd];
            int offsetInBlock = srcd - src.dofOffsetsByBlock[srcb];
            dstDofs[srcd] = dst.dofOffsetsByBlock[dstBlockIndices[srcb]] + offsetInBlock;
        }

        // constant term
        dst.coefficients.set(0, dst.coefficients.get(0) + (src.coefficients.get(0) + offset) * weight);

        // linear terms
        for (int srcd1 = 0; srcd1 < src.numDofs; srcd1++) {
            int srci = src.index1(srcd1);
            int dsti = dst.index1(dstDofs[srcd1]);
            dst.coefficients.set(dsti, dst.coefficients.get(dsti) + src.coefficients.get(srci) * weight);
        }

        // quadratic terms; index2 is symmetric, so dst's DOF order doesn't matter
        for (int srcd1 = 0; srcd1 < src.numDofs; srcd1++) {
            int dstd1 = dstDofs[srcd1];
            for (int srcd2 = 0; srcd2 <= srcd1; srcd2++) {
                int srci = src.index2(srcd1, srcd2);
                int dsti = dst.index2(dstd1, dstDofs[srcd2]);
                dst.coefficients.set(dsti, dst.coefficients.get(dsti) + src.coefficients.get(srci) * weight);
            }
        }

        // NOTE(review): the error bound is accumulated unscaled; for a weighted sum |weight|*src.maxe
        // would be the tight bound. Confirm the intended semantics before changing.
        dst.maxe += src.maxe;
    }

    /**
     * Writes every coefficient, then the error estimate, as doubles. The block layout is NOT
     * written, so the reader must already have the same layout.
     */
    @Override
    public void writeTo(DataOutput out) throws IOException {
        for (int i = 0; i < coefficients.size(); i++) {
            out.writeDouble(coefficients.get(i));
        }
        out.writeDouble(maxe);
    }

    /**
     * Reads the format produced by {@link #writeTo}. This instance must have been constructed with
     * the same block layout as the writer, since the coefficient count comes from this instance.
     */
    @Override
    public void readFrom(DataInput in) throws IOException {
        for (int i = 0; i < coefficients.size(); i++) {
            coefficients.set(i, in.readDouble());
        }
        maxe = in.readDouble();
    }

    @Override
    public boolean equals(Object other) {
        return other instanceof QuadraticApproximator && equals((QuadraticApproximator) other);
    }

    /** Equal when block layout, every coefficient, and the error estimate all match exactly. */
    public boolean equals(QuadraticApproximator other) {
        return other != null
                && dofBlockIds.equals(other.dofBlockIds)
                && dofCounts.equals(other.dofCounts)
                && coefficients.equals((Object) other.coefficients)
                && maxe == other.maxe;
    }

    /**
     * Hashes only the block layout. Equal instances always share it, and the coefficient matrix
     * type is not relied on for a value-based hash.
     */
    @Override
    public int hashCode() {
        return 31 * dofBlockIds.hashCode() + dofCounts.hashCode();
    }

    /**
     * Least-squares design matrix for a set of samples: one row per sample and one column per
     * coefficient. Each row holds the monomials (1, x_i, x_i*x_j) laid out to match the coefficient
     * indices; {@code b} holds the sample energies.
     */
    private static class LinearSystem {
        public final List<Minimizer.Result> samples;
        public final DoubleMatrix2D A;
        public final DoubleMatrix2D b;

        public LinearSystem(QuadraticApproximator approximator, List<Minimizer.Result> samples) {
            this.samples = samples;
            this.A = DoubleFactory2D.dense.make(samples.size(), approximator.coefficients.size());
            this.b = DoubleFactory2D.dense.make(samples.size(), 1);
            for (int i = 0; i < samples.size(); i++) {
                Minimizer.Result sample = samples.get(i);
                DoubleMatrix1D x = sample.dofValues;
                b.set(i, 0, sample.energy);
                A.set(i, 0, 1.0);
                for (int d1 = 0; d1 < approximator.numDofs; d1++) {
                    A.set(i, approximator.index1(d1), x.get(d1));
                }
                for (int d1 = 0; d1 < approximator.numDofs; d1++) {
                    for (int d2 = 0; d2 <= d1; d2++) {
                        A.set(i, approximator.index2(d1, d2), x.get(d1) * x.get(d2));
                    }
                }
            }
        }
    }
}

