/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.ematrix;

import edu.duke.cs.osprey.confspace.RCTuple;
import edu.duke.cs.osprey.confspace.SimpleConfSpace;
import edu.duke.cs.osprey.confspace.TupE;
import edu.duke.cs.osprey.ematrix.EnergyMatrix;
import edu.duke.cs.osprey.ematrix.ProxyEnergyMatrix;
import edu.duke.cs.osprey.energy.ConfEnergyCalculator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Energy matrix that adds higher-order (three or more position) corrections on top of the
 * one-body and pairwise terms obtained through {@link ProxyEnergyMatrix}.
 *
 * <p>Corrections are registered with {@link #setHigherOrder(RCTuple, Double)} and kept in a
 * {@link TupleTrie}. When the internal energy of a tuple is requested, every stored correction
 * whose tuple is a subset of the query is collected. A non-overlapping selection of them (chosen
 * greedily, largest correction first) is added to the lower-order energy.
 */
public class UpdatingEnergyMatrix
extends ProxyEnergyMatrix {
    /** Higher-order corrections registered through setHigherOrder. */
    private final TupleTrie corrections;
    /** Number of positions in the conformation space; caps how many positions corrections may cover. */
    private final int numPos;
    /** Conf energy calculator given at construction (null for the two-argument constructor); not read by this class. */
    public final ConfEnergyCalculator sourceECalc;

    public UpdatingEnergyMatrix(SimpleConfSpace confSpace, EnergyMatrix target, ConfEnergyCalculator confECalc) {
        super(confSpace, target);
        this.corrections = new TupleTrie(confSpace.positions);
        this.numPos = confSpace.getNumPos();
        this.sourceECalc = confECalc;
    }

    public UpdatingEnergyMatrix(SimpleConfSpace confSpace, EnergyMatrix target) {
        this(confSpace, target, null);
    }

    /** Returns true once at least one higher-order correction has been inserted. */
    @Override
    public boolean hasHigherOrderTerms() {
        return this.corrections.size() > 0;
    }

    /**
     * Returns true if a correction has been stored for exactly this tuple.
     *
     * <p>The query is sorted before lookup. Prefixes or subsets of a stored tuple do not count.
     */
    @Override
    public boolean hasHigherOrderTermFor(RCTuple query) {
        return this.corrections.contains(query);
    }

    /**
     * Formats corrections as one {@code "<tuple listing>:<energy>"} line per correction.
     *
     * @param corrections corrections to format, in output order
     * @return the concatenated lines, each terminated by a newline; empty for an empty list
     */
    public String formatCorrections(List<TupE> corrections) {
        StringBuilder out = new StringBuilder();
        for (TupE correction : corrections) {
            out.append(correction.tup.stringListing()).append(':').append(correction.E).append('\n');
        }
        return out.toString();
    }

    /**
     * Computes the internal energy of a tuple.
     *
     * <p>The energy is the sum of the one-body terms, the pairwise terms for every unordered pair
     * of tuple elements, and, when any corrections exist, the selected higher-order corrections.
     * The order of summation (one-body, then pairwise, then higher-order) is fixed so
     * floating-point results are reproducible.
     */
    @Override
    public double getInternalEnergy(RCTuple tup) {
        boolean useHigherOrderTerms = this.hasHigherOrderTerms();
        List<Integer> positions = tup.pos;
        List<Integer> rcs = tup.RCs;
        int size = positions.size();
        double energy = 0.0;
        for (int i = 0; i < size; ++i) {
            int pos = positions.get(i);
            int rc = rcs.get(i);
            energy += this.getOneBody(pos, rc);
        }
        for (int i = 0; i < size; ++i) {
            int pos1 = positions.get(i);
            int rc1 = rcs.get(i);
            for (int j = 0; j < i; ++j) {
                int pos2 = positions.get(j);
                int rc2 = rcs.get(j);
                energy += this.getPairwise(pos1, rc1, pos2, rc2);
            }
        }
        if (useHigherOrderTerms) {
            energy += this.internalEHigherOrder(tup);
        }
        return energy;
    }

    /**
     * Returns the summed higher-order correction applicable to {@code tup}, or 0.0 when no stored
     * correction is a subset of it.
     */
    double internalEHigherOrder(RCTuple tup) {
        List<TupE> confCorrections = this.corrections.getCorrections(tup);
        if (confCorrections.isEmpty()) {
            return 0.0;
        }
        return this.processCorrections(confCorrections);
    }

    /**
     * Greedily selects non-overlapping corrections and sums their energies.
     *
     * <p>Corrections are visited from largest to smallest energy. A correction is taken only if
     * none of its positions has been covered by a previously taken one. Selection stops early
     * once every position is covered.
     *
     * <p>Note: {@code confCorrections} is sorted in place.
     */
    private double processCorrections(List<TupE> confCorrections) {
        // Descending by energy; Collections.sort is stable, so ties keep trie traversal order.
        Collections.sort(confCorrections, (a, b) -> Double.compare(b.E, a.E));
        double sum = 0.0;
        HashSet<Integer> usedPositions = new HashSet<Integer>();
        for (TupE correction : confCorrections) {
            if (usedPositions.size() >= this.numPos) {
                break;
            }
            if (!Collections.disjoint(usedPositions, correction.tup.pos)) {
                continue;
            }
            usedPositions.addAll(correction.tup.pos);
            sum += correction.E;
        }
        return sum;
    }

    /**
     * Registers a higher-order correction for {@code tup}.
     *
     * <p>Tuples with fewer than three elements are rejected with a message on stderr and ignored.
     * The tuple is sorted before insertion.
     */
    @Override
    public void setHigherOrder(RCTuple tup, Double val) {
        if (tup.size() < 3) {
            System.err.println("Should not be trying to submit correction of lower-order term.");
            return;
        }
        RCTuple orderedTup = tup.sorted();
        this.corrections.insert(new TupE(orderedTup, val));
    }

    /**
     * Trie of corrections indexed by (position, RC).
     *
     * <p>Depth d of the trie corresponds to entry d of the position list; the root sits at depth
     * -1. Every non-last node has a wildcard child, keyed by {@link #WILDCARD_RC}, meaning "the
     * next position is not part of the tuple". It may also have one child per RC inserted at the
     * next position. A correction is stored at the node reached by following its tuple's RCs,
     * taking wildcard children for every skipped position.
     *
     * <p>Tuples must be sorted by position when inserted. Queries are sorted via
     * {@code RCTuple.sorted()}, which is assumed to order by position.
     *
     * <p>NOTE(review): depth comes from the list index, but traversal compares against
     * {@code Position.index}; the two are assumed to coincide.
     */
    public static class TupleTrie {
        /** Child key of the wildcard ("position not in tuple") branch. */
        public static final int WILDCARD_RC = -123;
        final TupleTrieNode root;
        final List<SimpleConfSpace.Position> positions;
        private int numCorrections;

        public TupleTrie(List<SimpleConfSpace.Position> positions) {
            this.positions = positions;
            this.root = new TupleTrieNode(positions, -1, WILDCARD_RC);
        }

        /**
         * Stores a correction and increments size().
         *
         * @throws IllegalArgumentException if the correction's tuple is not sorted by position
         *     or refers to a position outside the trie
         */
        public void insert(TupE correction) {
            this.root.insert(correction, 0);
            ++this.numCorrections;
        }

        /** Returns every stored correction whose tuple is a subset of {@code query}. */
        public List<TupE> getCorrections(RCTuple query) {
            List<TupE> found = new ArrayList<TupE>();
            this.root.populateCorrections(query.sorted(), found);
            return found;
        }

        /** Returns true if a correction has been stored for exactly {@code query}. */
        public boolean contains(RCTuple query) {
            return this.root.contains(query.sorted(), 0);
        }

        /** Number of insert() calls that completed (duplicates are counted). */
        public int size() {
            return this.numCorrections;
        }

        private static class TupleTrieNode {
            /** RC represented by this node, or WILDCARD_RC for the root and wildcard nodes. */
            final int rc;
            /** Index of this node's position in the positions list; -1 for the root. */
            final int positionIndex;
            /** Position.index of this node's position; -1 for the root. */
            final int position;
            final List<SimpleConfSpace.Position> positions;
            /** Corrections whose tuple ends exactly at this node. */
            final List<TupE> corrections = new ArrayList<TupE>();
            /** Children keyed by the RC at the next position, plus WILDCARD_RC. */
            final Map<Integer, TupleTrieNode> children = new HashMap<Integer, TupleTrieNode>();

            private TupleTrieNode(List<SimpleConfSpace.Position> positions, int positionIndex, int rc) {
                this.positions = positions;
                this.positionIndex = positionIndex;
                this.rc = rc;
                this.position = positionIndex >= 0 ? positions.get(positionIndex).index : -1;
                // Eagerly build the wildcard chain down to the last position so every non-last
                // node can skip its next position.
                if (positionIndex + 1 < positions.size()) {
                    this.children.put(WILDCARD_RC, new TupleTrieNode(positions, positionIndex + 1, WILDCARD_RC));
                }
            }

            /**
             * Returns true if a correction is stored for exactly the remainder of {@code query},
             * starting at {@code tupleIndex}.
             */
            public boolean contains(RCTuple query, int tupleIndex) {
                if (tupleIndex >= query.size()) {
                    // The whole query has been matched; intermediate path nodes (prefixes of
                    // stored tuples) carry no correction of their own.
                    return !this.corrections.isEmpty();
                }
                int currentPos = query.pos.get(tupleIndex);
                if (this.position + 1 < currentPos) {
                    // The next position is not in the query: follow the wildcard branch.
                    TupleTrieNode wildcard = this.children.get(WILDCARD_RC);
                    return wildcard != null && wildcard.contains(query, tupleIndex);
                }
                TupleTrieNode match = this.children.get(query.RCs.get(tupleIndex));
                return match != null && match.contains(query, tupleIndex + 1);
            }

            @Override
            public String toString() {
                String rcString = this.rc > -1 ? Integer.toString(this.rc) : "*";
                return this.position + ":" + rcString;
            }

            /**
             * Inserts {@code correction} below this node, consuming its tuple from
             * {@code tupIndex} on.
             *
             * @throws IllegalArgumentException if the tuple is unsorted or uses a position the
             *     trie cannot reach
             */
            public void insert(TupE correction, int tupIndex) {
                RCTuple tup = correction.tup;
                if (tupIndex >= tup.size()) {
                    this.corrections.add(correction);
                    return;
                }
                int childPos = tup.pos.get(tupIndex);
                if (this.position + 1 != childPos) {
                    // The next position is not part of the tuple: descend through the wildcard.
                    TupleTrieNode wildcard = this.children.get(WILDCARD_RC);
                    if (wildcard == null) {
                        throw new IllegalArgumentException("Cannot insert correction for " + tup.stringListing()
                                + ": tuple positions are unsorted or outside the conformation space.");
                    }
                    wildcard.insert(correction, tupIndex);
                    return;
                }
                int childRC = tup.RCs.get(tupIndex);
                TupleTrieNode child = this.children.computeIfAbsent(childRC,
                        key -> new TupleTrieNode(this.positions, this.positionIndex + 1, key));
                child.insert(correction, tupIndex + 1);
            }

            /** Appends to {@code output} every correction in this subtree whose tuple is a subset of {@code query}. */
            public void populateCorrections(RCTuple query, List<TupE> output) {
                this.populateCorrections(query, output, 0);
            }

            private void populateCorrections(RCTuple query, List<TupE> output, int tupleIndex) {
                // Tuples stored here consist of the RCs matched along the path, i.e. a subset of the query.
                output.addAll(this.corrections);
                if (tupleIndex >= query.size()) {
                    return;
                }
                int currentPos = query.pos.get(tupleIndex);
                int currentRC = query.RCs.get(tupleIndex);
                if (tupleIndex > 0) {
                    // Sanity check: the last consumed query element must not lie beyond this
                    // node, and must agree with this node's RC when it sits at this position.
                    int consumedPos = query.pos.get(tupleIndex - 1);
                    int consumedRC = query.RCs.get(tupleIndex - 1);
                    boolean rcMismatch = consumedPos == this.position && consumedRC != this.rc && this.rc != WILDCARD_RC;
                    if (consumedPos > this.position || rcMismatch) {
                        System.err.println("Error in trie traversal.");
                    }
                }
                int nextPosition = this.position + 1;
                // A query element beyond the next trie level stays unconsumed for now.
                int nextIndex = nextPosition < currentPos ? tupleIndex : tupleIndex + 1;
                if (nextPosition == currentPos) {
                    TupleTrieNode match = this.children.get(currentRC);
                    if (match != null) {
                        match.populateCorrections(query, output, nextIndex);
                    }
                }
                // Also collect corrections that do not include the next position. The last
                // position has no wildcard child; the trie is never modified while reading.
                TupleTrieNode wildcard = this.children.get(WILDCARD_RC);
                if (wildcard != null) {
                    wildcard.populateCorrections(query, output, nextIndex);
                }
            }
        }
    }
}

