/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.lute;

import edu.duke.cs.osprey.astar.conf.RCs;
import edu.duke.cs.osprey.confspace.Conf;
import edu.duke.cs.osprey.confspace.ConfDB;
import edu.duke.cs.osprey.confspace.PosMatrixGeneric;
import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.confspace.TupleTree;
import edu.duke.cs.osprey.confspace.TuplesIndex;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import edu.duke.cs.osprey.lute.ConfSampler;
import edu.duke.cs.osprey.lute.LUTEConfEnergyCalculator;
import edu.duke.cs.osprey.lute.LUTEIO;
import edu.duke.cs.osprey.lute.LUTEState;
import edu.duke.cs.osprey.parallelism.TaskExecutor;
import edu.duke.cs.osprey.parallelism.ThreadPoolTaskExecutor;
import edu.duke.cs.osprey.pruning.PruningMatrix;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.Progress;
import edu.duke.cs.osprey.tools.Stopwatch;
import java.io.File;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.ConjugateGradient;
import org.apache.commons.math3.linear.RealLinearOperator;
import org.apache.commons.math3.linear.RealVector;
import smile.data.SparseDataset;
import smile.math.matrix.Matrix;
import smile.math.matrix.SparseMatrix;
import smile.regression.LASSO;

public class LUTE {
    /** Conformation space whose energies the LUTE model approximates. */
    public final SimpleConfSpace confSpace;
    /** Index of the tuples in the model; {@link #addTuple} appends to it. */
    public final TuplesIndex tuplesIndex;
    /** Confs used to fit the tuple energies. */
    public final ConfSampler.Samples trainingSet;
    /** Confs used only to evaluate the fit (test RMS error, overfitting score). */
    public final ConfSampler.Samples testSet;
    private LinearSystem trainingSystem;
    private LinearSystem testSystem;
    private Map<int[], Double> energies;

    /**
     * Creates a LUTE model for {@code confSpace} with a fresh tuples index and training/test
     * sample sets built on that index.
     *
     * @param confSpace the conformation space to model
     */
    public LUTE(SimpleConfSpace confSpace) {
        this.confSpace = confSpace;
        TuplesIndex index = new TuplesIndex(confSpace);
        this.tuplesIndex = index;
        this.trainingSet = new ConfSampler.Samples(index);
        this.testSet = new ConfSampler.Samples(index);
    }

    /**
     * Collects every (position, RC) single that the pruning matrix leaves unpruned.
     *
     * @param pmat pruning matrix to enumerate
     * @return the unpruned singles, in the pruning matrix's iteration order
     */
    public Set<RCTuple> getUnprunedSingleTuples(PruningMatrix pmat) {
        Set<RCTuple> singles = new LinkedHashSet<>();
        pmat.forEachUnprunedSingle((pos, rc) -> {
            singles.add(new RCTuple(pos, rc));
            return PruningMatrix.IteratorCommand.Continue;
        });
        return singles;
    }

    /**
     * Collects every RC pair that the pruning matrix leaves unpruned.
     * <p>
     * The second iterated position is placed first in each tuple. Presumably the iterator yields
     * pos2 &lt; pos1, so this keeps tuple positions ascending (TODO confirm against PruningMatrix).
     *
     * @param pmat pruning matrix to enumerate
     * @return the unpruned pairs, in the pruning matrix's iteration order
     */
    public Set<RCTuple> getUnprunedPairTuples(PruningMatrix pmat) {
        Set<RCTuple> pairs = new LinkedHashSet<>();
        pmat.forEachUnprunedPair((posA, rcA, posB, rcB) -> {
            pairs.add(new RCTuple(posB, rcB, posA, rcA));
            return PruningMatrix.IteratorCommand.Continue;
        });
        return pairs;
    }

    /**
     * Collects every RC triple that the pruning matrix leaves unpruned.
     * <p>
     * Positions are stored in reverse of the iteration order. Presumably this makes them ascending,
     * matching {@link #getUnprunedPairTuples} (TODO confirm against PruningMatrix).
     *
     * @param pmat pruning matrix to enumerate
     * @return the unpruned triples, in the pruning matrix's iteration order
     */
    public Set<RCTuple> getUnprunedTripleTuples(PruningMatrix pmat) {
        Set<RCTuple> triples = new LinkedHashSet<>();
        pmat.forEachUnprunedTriple((posA, rcA, posB, rcB, posC, rcC) -> {
            triples.add(new RCTuple(posC, rcC, posB, rcB, posA, rcA));
            return PruningMatrix.IteratorCommand.Continue;
        });
        return triples;
    }

    /**
     * Picks the unpruned triples that carry the largest share of the current training-set fit error.
     * <p>
     * Each training conf gets an error: the absolute difference between its upper-bound energy
     * recorded in {@code confTable} and the energy predicted by the current LUTE fit. Every unpruned
     * triple accumulates the squared errors of the training confs that contain it.
     * <p>
     * Triples are then taken in order of decreasing accumulated squared error until the chosen ones
     * account for at least {@code fractionSqErrorCovered} of the total over all triples.
     * <p>
     * Requires a prior call to {@link #fit}, since it reads the training system.
     *
     * @param pmat pruning matrix used to enumerate the unpruned triples
     * @param confTable conf table holding the energies of the training confs
     * @param fractionSqErrorCovered fraction of the total squared error the chosen triples must cover
     * @return the chosen triples, worst (highest squared error) first
     */
    public Set<RCTuple> sampleTripleTuplesByFitError(PruningMatrix pmat, final ConfDB.ConfTable confTable, double fractionSqErrorCovered) {
        final LUTEConfEnergyCalculator luteConfEcalc = new LUTEConfEnergyCalculator(this.confSpace, new LUTEState(this.trainingSystem));
        class ConfInfo {
            public final int[] conf;
            public final double ffEnergy;
            public final double luteEnergy;
            public final double err;

            public ConfInfo(int[] conf) {
                this.conf = conf;
                this.ffEnergy = confTable.get((int[])conf).upper.energy;
                this.luteEnergy = luteConfEcalc.calcEnergy(conf);
                this.err = Math.abs(this.ffEnergy - this.luteEnergy);
            }
        }

        // measure the fit error of every training conf, worst first
        List<ConfInfo> confInfos = new ArrayList<>();
        for (int[] conf : this.trainingSystem.confs) {
            confInfos.add(new ConfInfo(conf));
        }
        confInfos.sort(Comparator.comparing((ConfInfo info) -> info.err).reversed());

        // collect all the unpruned triples (typed list: the raw list did not compile with RCTuple loops)
        List<RCTuple> triples = new ArrayList<>();
        pmat.forEachUnprunedTriple((pos1, rc1, pos2, rc2, pos3, rc3) -> {
            triples.add(new RCTuple(pos1, rc1, pos2, rc2, pos3, rc3).sorted());
            return PruningMatrix.IteratorCommand.Continue;
        });

        // accumulate each conf's squared error into every triple the conf contains
        TupleTree<Double> triplesSqError = new TupleTree<Double>();
        for (ConfInfo info : confInfos) {
            double errSq = info.err * info.err;
            for (RCTuple triple : triples) {
                if (!Conf.containsTuple(info.conf, triple)) {
                    continue;
                }
                Double errorSqSum = (Double)triplesSqError.get(triple);
                triplesSqError.put(triple, (errorSqSum == null ? 0.0 : errorSqSum) + errSq);
            }
        }
        Function<RCTuple, Double> getSqError = triple -> {
            Double sqError = (Double)triplesSqError.get(triple);
            return sqError == null ? 0.0 : sqError;
        };

        // rank triples so the ones carrying the most error are chosen first
        triples.sort(Comparator.comparing(getSqError).reversed());

        double totalSqError = 0.0;
        for (RCTuple triple : triples) {
            totalSqError += getSqError.apply(triple);
        }

        // take the worst triples until enough of the total squared error is covered
        Set<RCTuple> chosenTriples = new LinkedHashSet<>();
        double sqError = 0.0;
        for (RCTuple triple : triples) {
            chosenTriples.add(triple);
            sqError += getSqError.apply(triple);
            if (sqError / totalSqError >= fractionSqErrorCovered) {
                break;
            }
        }
        return chosenTriples;
    }

    /**
     * Picks the unpruned triples whose positions interact strongly.
     * <p>
     * For each position pair, the interaction strength is the largest |pairwise energy| over all
     * unpruned RC pairs at those positions. Each position then marks its top
     * {@code min(n-1, maxNumPairsPerPosition)} partner positions, ranked by that strength.
     * <p>
     * An unpruned triple is selected when at least two of its three position pairs are marked.
     *
     * @param emat energy matrix supplying pairwise energies
     * @param pmat pruning matrix used to enumerate unpruned pairs and triples
     * @param maxNumPairsPerPosition how many strongest partners to mark per position
     * @return the selected triples (positions sorted), in the pruning matrix's iteration order
     */
    public Set<RCTuple> sampleTripleTuplesByStrongInteractions(EnergyMatrix emat, PruningMatrix pmat, int maxNumPairsPerPosition) {
        int n = emat.getNumPos();
        int numPairsPerPosition = Math.min(n - 1, maxNumPairsPerPosition);

        // strongest |pair energy| over all unpruned RC pairs, for each position pair
        PosMatrixGeneric<Double> strongestInteractions = new PosMatrixGeneric<Double>(this.confSpace);
        strongestInteractions.fill(0.0);
        pmat.forEachUnprunedPair((pos1, rc1, pos2, rc2) -> {
            double interaction = Math.abs(emat.getPairwise(pos1, rc1, pos2, rc2));
            if (interaction > (Double)strongestInteractions.get(pos1, pos2)) {
                strongestInteractions.set(pos1, pos2, interaction);
            }
            return PruningMatrix.IteratorCommand.Continue;
        });

        // for each position, mark its top partners ranked by interaction strength
        PosMatrixGeneric<Boolean> topPositionInteractions = new PosMatrixGeneric<Boolean>(this.confSpace);
        topPositionInteractions.fill(false);
        for (int pos1 = 0; pos1 < n; ++pos1) {
            List<Integer> partners = new ArrayList<>(n);
            for (int other = 0; other < n; ++other) {
                if (other != pos1) {
                    partners.add(other);
                }
            }
            final int fpos1 = pos1;
            // rank by the measured strength (previously ranked by the all-false marker matrix, i.e. arbitrarily)
            partners.sort(Comparator.comparing((Integer partner) -> (Double)strongestInteractions.get(fpos1, partner)).reversed());
            for (int i = 0; i < numPairsPerPosition; ++i) {
                int partner = partners.get(i);
                topPositionInteractions.set(pos1, partner, true);
            }
        }

        // keep triples where at least two of the three position pairs are strong
        LinkedHashSet<RCTuple> triples = new LinkedHashSet<RCTuple>();
        pmat.forEachUnprunedTriple((pos1, rc1, pos2, rc2, pos3, rc3) -> {
            int numStrongInteractions = ((Boolean)topPositionInteractions.get(pos1, pos2) ? 1 : 0)
                + ((Boolean)topPositionInteractions.get(pos1, pos3) ? 1 : 0)
                + ((Boolean)topPositionInteractions.get(pos2, pos3) ? 1 : 0);
            if (numStrongInteractions >= 2) {
                triples.add(new RCTuple(pos1, rc1, pos2, rc2, pos3, rc3).sorted());
            }
            return PruningMatrix.IteratorCommand.Continue;
        });
        return triples;
    }

    /**
     * Adds every tuple in {@code tuples} via {@link #addTuple}, without checking for duplicates.
     *
     * @param tuples tuples to add, in order
     */
    public void addTuples(Iterable<RCTuple> tuples) {
        tuples.forEach(this::addTuple);
    }

    /**
     * Adds each tuple in {@code tuples} that the tuples index does not already contain.
     *
     * @param tuples candidate tuples, added in order
     */
    public void addUniqueTuples(Iterable<RCTuple> tuples) {
        for (RCTuple candidate : tuples) {
            boolean alreadyIndexed = this.tuplesIndex.contains(candidate);
            if (!alreadyIndexed) {
                this.addTuple(candidate);
            }
        }
    }

    /**
     * Adds one tuple to the model.
     * <p>
     * The tuple is appended to the tuples index and registered with both the training and the test
     * sample sets.
     *
     * @param tuple the tuple; {@code checkSortedPositions} is called on it first (presumably throws if
     *              its positions are not sorted)
     */
    public void addTuple(RCTuple tuple) {
        tuple.checkSortedPositions();
        tuplesIndex.appendTuple(tuple);
        for (ConfSampler.Samples samples : new ConfSampler.Samples[] { trainingSet, testSet }) {
            samples.addTuple(tuple);
        }
    }

    /**
     * Grows the LUTE model until the training-set RMS error reaches {@code maxRMSE} or no tuples remain to add.
     * <p>
     * With a single position, only singles are added and fit. Otherwise all unpruned pairs are added
     * and fit first.
     * <p>
     * If that misses the goal, triples are added in rounds. Round k admits triples among each
     * position's top k strongest-interacting partners (see {@link #sampleTripleTuplesByStrongInteractions}).
     * The final round (k = positions - 1) admits every unpruned triple. The model is refit after
     * every round that adds new tuples.
     *
     * @param emat energy matrix used to rank pair interaction strengths
     * @param maxOverfittingScore passed through to {@link #fit}
     * @param maxRMSE training-set RMS error goal
     * @return true if the training-set RMS error reached {@code maxRMSE}
     */
    public boolean sampleTuplesAndFit(ConfEnergyCalculator confEcalc, EnergyMatrix emat, PruningMatrix pmat, ConfDB.ConfTable confTable, ConfSampler sampler, Fitter fitter, double maxOverfittingScore, double maxRMSE) {
        if (this.confSpace.positions.size() == 1) {
            // a single position has no pairs, so singles are the only tuples available
            Log.logf("Sampling all single tuples...", new Object[0]);
            Stopwatch singlesStopwatch = new Stopwatch().start();
            this.addUniqueTuples(this.getUnprunedSingleTuples(pmat));
            Log.log(" done in " + singlesStopwatch.stop().getTime(2), new Object[0]);
            this.fit(confEcalc, confTable, sampler, fitter, maxRMSE, maxOverfittingScore);
            return this.trainingErrorMeetsGoal(maxRMSE);
        }

        Log.logf("Sampling all pair tuples...", new Object[0]);
        Stopwatch pairsStopwatch = new Stopwatch().start();
        this.addUniqueTuples(this.getUnprunedPairTuples(pmat));
        Log.log(" done in " + pairsStopwatch.stop().getTime(2), new Object[0]);
        this.fit(confEcalc, confTable, sampler, fitter, maxRMSE, maxOverfittingScore);
        if (this.trainingErrorMeetsGoal(maxRMSE)) {
            return true;
        }

        // inclusive bound: the last round (every other position is a "top" partner) admits all triples
        int maxNumPairsPerPosition = this.confSpace.positions.size() - 1;
        for (int numPairsPerPosition = 1; numPairsPerPosition <= maxNumPairsPerPosition; ++numPairsPerPosition) {
            Log.logf("Sampling triple tuples (top %d strongly-interacting pairs at each position) to try to reduce error...", numPairsPerPosition);
            Stopwatch triplesStopwatch = new Stopwatch().start();
            int numTuplesBefore = this.tuplesIndex.size();
            this.addUniqueTuples(this.sampleTripleTuplesByStrongInteractions(emat, pmat, numPairsPerPosition));
            Log.log(" done in " + triplesStopwatch.stop().getTime(2), new Object[0]);
            if (this.tuplesIndex.size() == numTuplesBefore) {
                // no new tuples this round, so the model gained no capacity; skip the refit
                continue;
            }
            this.fit(confEcalc, confTable, sampler, fitter, maxRMSE, maxOverfittingScore);
            if (this.trainingErrorMeetsGoal(maxRMSE)) {
                return true;
            }
        }
        Log.log("all triples exhausted. Nothing more to try to improve the fit.", new Object[0]);
        return false;
    }

    /** Logs whether the current training-set RMS error meets {@code maxRMSE}, and returns that verdict. */
    private boolean trainingErrorMeetsGoal(double maxRMSE) {
        double rms = this.trainingSystem.errors.rms;
        if (rms <= maxRMSE) {
            Log.log("training set RMS error %f meets goal of %f", rms, maxRMSE);
            return true;
        }
        Log.log("training set RMS error %f does not meet goal of %f", rms, maxRMSE);
        return false;
    }

    /**
     * Adds every unpruned pair tuple, fits once, and reports whether the training error goal was met.
     *
     * @param emat unused here; kept for signature parity with {@link #sampleTuplesAndFit}
     * @param maxOverfittingScore passed through to {@link #fit}
     * @param maxRMSE training-set RMS error goal
     * @return true if the training-set RMS error is at most {@code maxRMSE}
     * @throws IllegalArgumentException if the conf space has fewer than three positions
     */
    public boolean sampleAllPairsAndFit(ConfEnergyCalculator confEcalc, EnergyMatrix emat, PruningMatrix pmat, ConfDB.ConfTable confTable, ConfSampler sampler, Fitter fitter, double maxOverfittingScore, double maxRMSE) {
        int numPositions = this.confSpace.positions.size();
        if (numPositions < 3) {
            throw new IllegalArgumentException("conf space must have at least three positions");
        }
        Log.logf("Sampling all pairs...", new Object[0]);
        Stopwatch tupleStopwatch = new Stopwatch().start();
        Set<RCTuple> pairs = this.getUnprunedPairTuples(pmat);
        this.addUniqueTuples(pairs);
        Log.log(" done in " + tupleStopwatch.stop().getTime(2), new Object[0]);
        this.fit(confEcalc, confTable, sampler, fitter, maxRMSE, maxOverfittingScore);
        double rms = this.trainingSystem.errors.rms;
        boolean metGoal = rms <= maxRMSE;
        String message = metGoal
            ? "training set RMS error %f meets goal of %f"
            : "training set RMS error %f does not meet goal of %f";
        Log.log(message, rms, maxRMSE);
        return metGoal;
    }

    /**
     * Adds every unpruned pair tuple and then every unpruned triple tuple, fits once, and reports
     * whether the training error goal was met.
     *
     * @param emat unused here; kept for signature parity with {@link #sampleTuplesAndFit}
     * @param maxOverfittingScore passed through to {@link #fit}
     * @param maxRMSE training-set RMS error goal
     * @return true if the training-set RMS error is at most {@code maxRMSE}
     * @throws IllegalArgumentException if the conf space has fewer than three positions
     */
    public boolean sampleAllPairsTriplesAndFit(ConfEnergyCalculator confEcalc, EnergyMatrix emat, PruningMatrix pmat, ConfDB.ConfTable confTable, ConfSampler sampler, Fitter fitter, double maxOverfittingScore, double maxRMSE) {
        int numPositions = this.confSpace.positions.size();
        if (numPositions < 3) {
            throw new IllegalArgumentException("conf space must have at least three positions");
        }
        Log.logf("Sampling all pairs and triples...", new Object[0]);
        Stopwatch tupleStopwatch = new Stopwatch().start();
        Set<RCTuple> pairs = this.getUnprunedPairTuples(pmat);
        this.addUniqueTuples(pairs);
        Set<RCTuple> triples = this.getUnprunedTripleTuples(pmat);
        this.addUniqueTuples(triples);
        Log.log(" done in " + tupleStopwatch.stop().getTime(2), new Object[0]);
        this.fit(confEcalc, confTable, sampler, fitter, maxRMSE, maxOverfittingScore);
        double rms = this.trainingSystem.errors.rms;
        boolean metGoal = rms <= maxRMSE;
        String message = metGoal
            ? "training set RMS error %f meets goal of %f"
            : "training set RMS error %f does not meet goal of %f";
        Log.log(message, rms, maxRMSE);
        return metGoal;
    }

    /**
     * Samples confs, computes their energies, and fits the tuple energies, adding samples between rounds.
     * <p>
     * Each round does the following:
     * <ul>
     * <li>samples at least {@code samplesPerTuple} confs per tuple into both the training and test sets;</li>
     * <li>requests an energy from {@code confEcalc} for every sampled conf;</li>
     * <li>fits the training system, warm-started from the previous fit's tuple energies if any;</li>
     * <li>evaluates those tuple energies on the test system.</li>
     * </ul>
     * Rounds stop when the training RMS error exceeds {@code maxTrainingRMSE} or the overfitting score
     * (see {@link #calcOverfittingScore}) is at most {@code maxOverfittingScore}. Otherwise
     * {@code samplesPerTuple} is incremented and another round runs. A summary of the final fit,
     * including a cumulative histogram of test-set absolute errors, is logged at the end.
     *
     * @param confEcalc calculator for the true conf energies; its task executor is awaited each round
     * @param confTable conf table passed to the energy calculator
     * @param sampler sampler that adds confs to the training and test sets
     * @param fitter strategy used to solve the training system
     * @param maxTrainingRMSE training RMS error above which sampling stops
     * @param maxOverfittingScore overfitting score at or below which sampling stops
     */
    public void fit(ConfEnergyCalculator confEcalc, ConfDB.ConfTable confTable, ConfSampler sampler, Fitter fitter, double maxTrainingRMSE, double maxOverfittingScore) {
        double overfittingScore;
        this.energies = new Conf.Map<Double>();
        Stopwatch sw = new Stopwatch().start();

        // register confs sampled before this call so the first round counts only new ones
        for (int[] conf : this.trainingSet.getAllConfs()) {
            this.energies.put(conf, null);
        }
        for (int[] conf : this.testSet.getAllConfs()) {
            this.energies.put(conf, null);
        }
        int numSamples = this.energies.size();
        int samplesPerTuple = 1;
        while (true) {
            Log.logf("\nsampling at least %d confs per tuple for %d tuples...", samplesPerTuple, this.tuplesIndex.size());
            Stopwatch samplingSw = new Stopwatch().start();
            sampler.sampleConfsForTuples(this.trainingSet, samplesPerTuple);
            sampler.sampleConfsForTuples(this.testSet, samplesPerTuple);
            Log.log(" done in %s", samplingSw.stop().getTime(2));
            for (int[] conf : this.trainingSet.getAllConfs()) {
                this.energies.put(conf, null);
            }
            for (int[] conf : this.testSet.getAllConfs()) {
                this.energies.put(conf, null);
            }
            int numAdditionalSamples = this.energies.size() - numSamples;
            numSamples = this.energies.size();

            // energies are requested for every sampled conf, not only the newly added ones
            Progress progress = new Progress(numSamples);
            progress.setReportMemory(true);
            Log.log("calculating energies for %d more samples...", numAdditionalSamples);
            for (Map.Entry<int[], Double> entry : this.energies.entrySet()) {
                int[] conf = entry.getKey();
                confEcalc.calcEnergyAsync(new RCTuple(conf), confTable, energy -> {
                    entry.setValue((Double)energy);
                    progress.incrementProgress();
                });
            }
            confEcalc.tasks.waitForFinish();

            try (ThreadPoolTaskExecutor tasks = new ThreadPoolTaskExecutor()) {
                tasks.start(confEcalc.ecalc.parallelism.numThreads);
                // warm-start from the previous round's solution, if there is one
                double[] oldTupleEnergies = null;
                if (this.trainingSystem != null) {
                    oldTupleEnergies = this.trainingSystem.tupleEnergies;
                }
                Log.logf("fitting %d confs to %d tuples ...", numSamples, this.tuplesIndex.size());
                Stopwatch trainingSw = new Stopwatch().start();
                this.trainingSystem = new LinearSystem(this.tuplesIndex, this.trainingSet, this.energies);
                this.trainingSystem.fit(fitter, oldTupleEnergies, tasks);
                Log.logf(" done in %s", trainingSw.stop().getTime(2));
                this.testSystem = new LinearSystem(this.tuplesIndex, this.testSet, this.energies);
                this.testSystem.setTupleEnergies(this.trainingSystem.tupleEnergies, this.trainingSystem.tupleEnergyOffset, tasks);
            }

            overfittingScore = this.calcOverfittingScore();
            Log.log("    RMS errors:  train %.4f    test %.4f    overfitting score: %.4f", this.trainingSystem.errors.rms, this.testSystem.errors.rms, overfittingScore);
            if (this.trainingSystem.errors.rms > maxTrainingRMSE || overfittingScore <= maxOverfittingScore) {
                break;
            }
            ++samplesPerTuple;
        }

        Log.log("\nLUTE fitting finished in %s:\n", sw.stop().getTime(2));
        Log.log("sampled at least %d confs per tuple for %d tuples", samplesPerTuple, this.tuplesIndex.size());
        Log.log("sampled %d training confs, %d test confs, %d confs total (%.2f%% overlap)", this.trainingSet.size(), this.testSet.size(), this.energies.size(), 100.0 * (double)(this.trainingSet.size() + this.testSet.size() - this.energies.size()) / (double)this.energies.size());
        Log.log("training errors: %s", this.trainingSystem.errors);
        Log.log("    test errors: %s", this.testSystem.errors);

        // cumulative histogram: each bucket counts test residuals with |error| <= its top
        double[] bucketTops = new double[]{0.001, 0.01, 0.1, 0.2, 0.4, 1.0, 2.0, 4.0, 10.0};
        int[] counts = new int[bucketTops.length];
        for (double error : this.testSystem.errors.residual) {
            double absError = Math.abs(error);
            for (int i = 0; i < counts.length; ++i) {
                if (absError <= bucketTops[i]) {
                    counts[i]++;
                }
            }
        }
        for (int i = 0; i < counts.length; ++i) {
            Log.log("               : %5.1f%% <= %s", 100.0 * (double)counts[i] / (double)this.testSystem.confs.size(), Double.toString(bucketTops[i]));
        }
        Log.log("total energy calculations: %d    overfitting score: %.4f <= %.4f", this.energies.size(), overfittingScore, maxOverfittingScore);
        Log.log("sample energies:    training set: [%.3f,%.3f]   test set: [%.3f,%.3f]", Arrays.stream(this.trainingSystem.confEnergies).min().orElse(Double.NaN), Arrays.stream(this.trainingSystem.confEnergies).max().orElse(Double.NaN), Arrays.stream(this.testSystem.confEnergies).min().orElse(Double.NaN), Arrays.stream(this.testSystem.confEnergies).max().orElse(Double.NaN));
        Log.log("", new Object[0]);
    }

    /**
     * Ratio of the test-set RMS error to the training-set RMS error, from the last fit.
     *
     * @return test RMS / training RMS, or 1.0 when both errors are exactly zero
     */
    public double calcOverfittingScore() {
        double testRms = this.testSystem.errors.rms;
        double trainRms = this.trainingSystem.errors.rms;
        // two perfect fits count as "no overfitting" rather than 0/0 = NaN
        boolean bothPerfect = testRms == 0.0 && trainRms == 0.0;
        return bothPerfect ? 1.0 : testRms / trainRms;
    }

    /**
     * Logs the full (unpruned) conf space size and the percentage of it that LUTE sampled.
     * Requires a prior call to {@link #fit}.
     */
    public void reportConfSpaceSize() {
        BigInteger numConfs = new RCs(this.confSpace).getNumConformations();
        int numSampled = this.energies.size();
        double sampledPercent = 100.0 * numSampled / numConfs.doubleValue();
        Log.log("conf space (no pruning was reported) has exactly %s conformations", Log.formatBig(numConfs));
        Log.log("LUTE sampled %.1f percent of those confs", sampledPercent);
    }

    /**
     * Logs bounds on the pruned conf space size and the range of percentages LUTE sampled.
     * <p>
     * The lower bound is raised to at least the number of sampled confs (this assumes every sampled
     * conf is unpruned). Requires a prior call to {@link #fit}.
     *
     * @param pmat pruning matrix providing lower/upper bounds on the number of unpruned confs
     */
    public void reportConfSpaceSize(PruningMatrix pmat) {
        BigInteger sizeLower = pmat.calcUnprunedConfsLowerBound();
        BigInteger sizeUpper = pmat.calcUnprunedConfsUpperBound();
        // clamp with BigInteger arithmetic instead of longValueExact + an empty overflow catch
        BigInteger numSampled = BigInteger.valueOf(this.energies.size());
        if (sizeLower.compareTo(numSampled) < 0) {
            sizeLower = numSampled;
        }
        double percentLower = 100.0 * (double)this.energies.size() / sizeUpper.doubleValue();
        double percentUpper = 100.0 * (double)this.energies.size() / sizeLower.doubleValue();
        Log.log("conf space (after singles and pairs pruning) has somewhere between %s and %s conformations", Log.formatBig(sizeLower), Log.formatBig(sizeUpper));
        Log.log("LUTE sampled somewhere between %.1f%% and %.1f%% of those conformations", percentLower, percentUpper);
    }

    /** @return the training linear system from the last {@link #fit}, or null before any fit */
    public LinearSystem getTrainingSystem() {
        return trainingSystem;
    }

    /** @return the test linear system from the last {@link #fit}, or null before any fit */
    public LinearSystem getTestSystem() {
        return testSystem;
    }

    /**
     * Writes the current training fit to {@code file} via {@link LUTEIO}.
     *
     * @param file destination file
     */
    public void save(File file) {
        LUTEState state = new LUTEState(this.trainingSystem);
        LUTEIO.write(state, file);
    }

    /**
     * The LUTE linear system: one row per sampled conf, one column per tuple.
     * <p>
     * Row c of A has a 1 in each column whose tuple occurs in conf c (see {@link #multA}), and b holds
     * the conf energies. A conf's predicted energy is the sum of its tuples' energies plus
     * {@link #tupleEnergyOffset}.
     */
    public static class LinearSystem {
        public final TuplesIndex tuples;
        public final List<int[]> confs;
        public final double[] confEnergies;
        public double[] tupleEnergies;
        public double tupleEnergyOffset;
        public Errors errors = null;

        /**
         * Snapshots the sampled confs and looks up their energies.
         *
         * @param confEnergies must hold a non-null energy for every sampled conf; a missing one fails
         *                     unboxing with a NullPointerException
         */
        public LinearSystem(TuplesIndex tuples, ConfSampler.Samples samples, Map<int[], Double> confEnergies) {
            this.tuples = tuples;
            this.confs = new ArrayList<int[]>(samples.getAllConfs());
            this.confEnergies = new double[this.confs.size()];
            for (int c = 0; c < this.confs.size(); ++c) {
                this.confEnergies[c] = confEnergies.get(this.confs.get(c));
            }
        }

        /** Calls {@code callback} with the index of each indexed tuple contained in conf {@code c}. */
        private void forEachTupleIn(int c, Consumer<Integer> callback) {
            // flag names as declared originally; presumably they control whether TuplesIndex.forEachIn
            // throws when a single/pair of the conf is missing from the index (TODO confirm)
            final boolean throwIfMissingSingle = false;
            final boolean throwIfMissingPair = true;
            this.tuples.forEachIn(this.confs.get(c), throwIfMissingSingle, throwIfMissingPair, callback);
        }

        /**
         * Solves for the tuple energies with {@code fitter}, then computes the fit errors.
         * <p>
         * When {@code fitter.normalize} is set, the energies are shifted by their minimum and divided by
         * their range before fitting. A zero range (all energies equal, e.g. a single conf) uses scale 1
         * instead of dividing by zero.
         *
         * @param oldTupleEnergies previous solution used as the starting point (tuples added since then
         *                         start at 0), or null to start from all zeros
         */
        public void fit(Fitter fitter, double[] oldTupleEnergies, TaskExecutor tasks) {
            BInfo binfo = new BInfo();
            if (fitter.normalize && !this.confs.isEmpty()) {
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (double e : this.confEnergies) {
                    min = Math.min(min, e);
                    max = Math.max(max, e);
                }
                binfo.offset = min;
                double range = max - min;
                binfo.scale = range > 0.0 ? range : 1.0;
                for (int c = 0; c < binfo.b.length; ++c) {
                    binfo.b[c] = (binfo.b[c] - binfo.offset) / binfo.scale;
                }
            }

            // new Java arrays are zero-filled, so only the warm-start values need copying
            double[] x0 = new double[this.tuples.size()];
            if (oldTupleEnergies != null) {
                System.arraycopy(oldTupleEnergies, 0, x0, 0, oldTupleEnergies.length);
                if (fitter.normalize) {
                    for (int t = 0; t < oldTupleEnergies.length; ++t) {
                        x0[t] /= binfo.scale;
                    }
                }
            }
            double[] x = fitter.fit(this, binfo, x0, tasks);
            this.calcTupleEnergies(x, binfo, tasks);
        }

        /** Serial A x: one output per conf, summing x over the tuples in that conf. */
        private double[] multA(double[] x) {
            double[] out = new double[this.confs.size()];
            for (int c = 0; c < this.confs.size(); ++c) {
                final int fc = c;
                this.forEachTupleIn(c, t -> out[fc] += x[t]);
            }
            return out;
        }

        /** Serial A^T x: one output per tuple, summing x over the confs containing that tuple. */
        private double[] multAt(double[] x) {
            double[] out = new double[this.tuples.size()];
            for (int c = 0; c < this.confs.size(); ++c) {
                final double xc = x[c];
                this.forEachTupleIn(c, t -> out[t] += xc);
            }
            return out;
        }

        /**
         * Parallel A x. Confs are split into contiguous ranges, one per thread, with the last range
         * taking the remainder. Each range writes only its own output slots.
         */
        private double[] parallelMultA(double[] x, TaskExecutor tasks) {
            int numConfs = this.confs.size();
            int numThreads = tasks.getParallelism();
            int partitionSize = numConfs / numThreads;
            double[] out = new double[numConfs];
            for (int i = 0; i < numThreads; ++i) {
                int startC = i * partitionSize;
                int stopC = i == numThreads - 1 ? numConfs : (i + 1) * partitionSize;
                tasks.submit(() -> {
                    for (int c = startC; c < stopC; ++c) {
                        final int fc = c;
                        this.forEachTupleIn(fc, t -> out[fc] += x[t]);
                    }
                    return null;
                }, ignored -> {});
            }
            tasks.waitForFinish();
            return out;
        }

        /**
         * Parallel A^T x. Each task accumulates into its own per-tuple buffer; the buffers are summed
         * into the result in the task listener.
         */
        private double[] parallelMultAt(double[] x, TaskExecutor tasks) {
            int numConfs = this.confs.size();
            int numThreads = tasks.getParallelism();
            int partitionSize = numConfs / numThreads;
            int numTuples = this.tuples.size();
            double[] out = new double[numTuples];
            for (int i = 0; i < numThreads; ++i) {
                int startC = i * partitionSize;
                int stopC = i == numThreads - 1 ? numConfs : (i + 1) * partitionSize;
                tasks.submit(() -> {
                    double[] threadOut = new double[numTuples];
                    for (int c = startC; c < stopC; ++c) {
                        final double xc = x[c];
                        this.forEachTupleIn(c, t -> threadOut[t] += xc);
                    }
                    return threadOut;
                }, threadOut -> {
                    for (int t = 0; t < numTuples; ++t) {
                        out[t] += threadOut[t];
                    }
                });
            }
            tasks.waitForFinish();
            return out;
        }

        /** Undoes the normalization scale on solution {@code x} and stores it as tuple energies. */
        private void calcTupleEnergies(double[] x, BInfo binfo, TaskExecutor tasks) {
            double[] energies = new double[this.tuples.size()];
            for (int t = 0; t < this.tuples.size(); ++t) {
                energies[t] = x[t] * binfo.scale;
            }
            this.setTupleEnergies(energies, binfo.offset, tasks);
        }

        /**
         * Installs tuple energies and an offset, then recomputes {@link #errors}.
         * Residual = A * energies + offset - confEnergies.
         */
        public void setTupleEnergies(double[] energies, double offset, TaskExecutor tasks) {
            this.tupleEnergies = energies;
            this.tupleEnergyOffset = offset;
            double[] residual = this.parallelMultA(this.tupleEnergies, tasks);
            for (int c = 0; c < this.confs.size(); ++c) {
                residual[c] = residual[c] + this.tupleEnergyOffset - this.confEnergies[c];
            }
            this.errors = new Errors(residual);
        }

        /** Right-hand side b of the system (a copy of the conf energies), plus its normalization. */
        private class BInfo {
            double[] b;
            double offset;
            double scale;

            private BInfo() {
                this.b = (double[])LinearSystem.this.confEnergies.clone();
                this.offset = 0.0;
                this.scale = 1.0;
            }
        }
    }

    public static enum Fitter {
        OLSCG(true){

            @Override
            public double[] fit(final LinearSystem system, LinearSystem.BInfo binfo, double[] x0, final TaskExecutor tasks) {
                RealLinearOperator AtA = new RealLinearOperator(this){

                    public int getRowDimension() {
                        return system.tuples.size();
                    }

                    public int getColumnDimension() {
                        return system.tuples.size();
                    }

                    public RealVector operate(RealVector vx) throws DimensionMismatchException {
                        double[] x = ((ArrayRealVector)vx).getDataRef();
                        double[] AtAx = system.parallelMultAt(system.parallelMultA(x, tasks), tasks);
                        return new ArrayRealVector(AtAx, false);
                    }
                };
                ArrayRealVector Atb = new ArrayRealVector(system.parallelMultAt(binfo.b, tasks), false);
                ArrayRealVector rx0 = new ArrayRealVector(x0, false);
                ConjugateGradient cg = new ConjugateGradient(100000, 1.0E-6, false);
                return ((ArrayRealVector)cg.solve(AtA, (RealVector)Atb, (RealVector)rx0)).getDataRef();
            }
        }
        ,
        LASSO(false){

            @Override
            public double[] fit(LinearSystem system, LinearSystem.BInfo binfo, double[] x0, TaskExecutor tasks) {
                SparseDataset data = new SparseDataset(system.tuples.size());
                for (int c = 0; c < system.confs.size(); ++c) {
                    int fc = c;
                    system.forEachTupleIn(c, t -> data.set(fc, t.intValue(), 1.0));
                }
                SparseMatrix A = data.toSparseMatrix();
                double lambda = 0.1;
                double tolerance = 0.1;
                int maxIterations = 200;
                LASSO lasso = new LASSO((Matrix)A, binfo.b, lambda, tolerance, maxIterations);
                binfo.offset += lasso.intercept();
                return lasso.coefficients();
            }
        };

        public final boolean normalize;

        private Fitter(boolean normalize) {
            this.normalize = normalize;
        }

        public abstract double[] fit(LinearSystem var1, LinearSystem.BInfo var2, double[] var3, TaskExecutor var4);
    }

    /**
     * Summary statistics of the absolute values of a residual vector.
     * An empty residual yields min = +Inf, max = -Inf, and NaN for avg and rms.
     */
    public static class Errors {
        public final double[] residual;
        public final double min;
        public final double max;
        public final double avg;
        public final double rms;

        /** @param residual per-conf residuals; kept by reference, not copied */
        public Errors(double[] residual) {
            this.residual = residual;
            double lo = Double.POSITIVE_INFINITY;
            double hi = Double.NEGATIVE_INFINITY;
            double absSum = 0.0;
            double sqSum = 0.0;
            for (double r : residual) {
                double mag = Math.abs(r);
                absSum += mag;
                sqSum += mag * mag;
                lo = Math.min(lo, mag);
                hi = Math.max(hi, mag);
            }
            int count = residual.length;
            this.min = lo;
            this.max = hi;
            this.avg = absSum / (double)count;
            this.rms = Math.sqrt(sqSum / (double)count);
        }

        @Override
        public String toString() {
            return String.format("range [%8.4f,%8.4f]   avg %8.4f   rms %8.4f", this.min, this.max, this.avg, this.rms);
        }
    }
}

