/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.dof.deeper;

import edu.duke.cs.osprey.confspace.Strand;
import edu.duke.cs.osprey.dof.ResidueTypeDOF;
import edu.duke.cs.osprey.dof.deeper.PertSet;
import edu.duke.cs.osprey.dof.deeper.RamachandranChecker;
import edu.duke.cs.osprey.dof.deeper.perts.Perturbation;
import edu.duke.cs.osprey.multistatekstar.ResidueTermini;
import edu.duke.cs.osprey.restypes.HardCodedResidueInfo;
import edu.duke.cs.osprey.restypes.ResidueTemplateLibrary;
import edu.duke.cs.osprey.structure.Atom;
import edu.duke.cs.osprey.structure.PDBIO;
import edu.duke.cs.osprey.structure.Residue;
import edu.duke.cs.osprey.tools.FileTools;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.TreeSet;

/**
 * Chooses backbone perturbations for a list of flexible residues and enumerates, for each
 * flexible residue, the combinations of perturbation intervals ("perturbation states") that
 * pass a feasibility check.
 *
 * <p>A perturbation state is a list of {@code int[]{pertIndex, intervalIndex}} pairs, one per
 * perturbation that moves the residue. A state whose interval indices are all 0 is treated as
 * unperturbed and is always accepted.
 *
 * <p>This class keeps per-call working state in unsynchronized fields (for example
 * {@link #failingPertIndex}) and mutates its own {@link Strand}, so an instance should not be
 * shared across threads.
 */
public class PerturbationSelector {
    // Perturbation file loaded first, or "none" (case-insensitive) to skip loading.
    String startingPertFile;
    // If true, only the starting perturbations are used (no autogeneration).
    boolean onlyStarting;
    // Bound of the default shear interval [-maxShearParam, maxShearParam].
    double maxShearParam;
    // Bound of the default backrub interval [-maxBackrubParam, maxBackrubParam].
    double maxBackrubParam;
    // Whether autogeneration adds loop closure adjustments.
    boolean selectLCAs;
    // Whether feasible states must also pass the Ramachandran check (see ramaCheck).
    boolean doRamaCheck;
    // PDB residue numbers of the flexible residues; list index == "position".
    ArrayList<String> flexibleRes;
    // Structure read from the PDB file in the constructor; perturbations are tested on it.
    Strand strand;
    // Perturbation objects built from ps (assumed indexed like ps.pertTypes).
    ArrayList<Perturbation> perts;
    // Perturbation set being assembled by selectPerturbations.
    PertSet ps;
    // Per perturbation index: residues it moves, plus residues of overlapping later perturbations.
    ArrayList<TreeSet<String>> resMovedByPert;
    ArrayList<double[]> defaultLCAIntervals;
    ArrayList<double[]> defaultShearIntervals;
    ArrayList<double[]> defaultBackrubIntervals;
    // Index (within the most recently tested state) of the perturbation whose motion failed, or -1.
    int failingPertIndex = -1;

    /**
     * Stores the selection settings and builds a strand from a PDB file.
     *
     * @param startingPertFile perturbation file to load first, or "none"
     * @param onlyStarting if true, skip autogeneration of perturbations
     * @param maxShearParam bound of the default symmetric shear interval
     * @param maxBackrubParam bound of the default symmetric backrub interval
     * @param selectLCAs whether autogeneration adds loop closure adjustments
     * @param flexibleRes PDB residue numbers of the flexible residues
     * @param PDBFile path of the PDB file, read with {@code FileTools.readFile}
     * @param termini residue range passed to {@code Strand.Builder.setResidues}
     * @param doRamaCheck whether feasible states must also pass the Ramachandran check
     * @param templateLib template library used for the strand (and for GLY mutations)
     */
    public PerturbationSelector(String startingPertFile, boolean onlyStarting, double maxShearParam, double maxBackrubParam, boolean selectLCAs, ArrayList<String> flexibleRes, String PDBFile, ResidueTermini termini, boolean doRamaCheck, ResidueTemplateLibrary templateLib) {
        this.startingPertFile = startingPertFile;
        this.onlyStarting = onlyStarting;
        this.maxShearParam = maxShearParam;
        this.maxBackrubParam = maxBackrubParam;
        this.selectLCAs = selectLCAs;
        this.flexibleRes = flexibleRes;
        // Must be stored: ramaCheck consults this field, and leaving it unset disables the check.
        this.doRamaCheck = doRamaCheck;
        this.strand = new Strand.Builder(PDBIO.read(FileTools.readFile(PDBFile)))
                .setTemplateLibrary(templateLib)
                .setResidues(termini)
                .build();
    }

    /**
     * Builds the perturbation set and the feasible perturbation states for every flexible residue.
     *
     * <p>Side effect: flexible amino-acid residues of this selector's strand are mutated to GLY.
     *
     * @param termini residue range passed to {@code PertSet.loadPertFile}
     * @return the assembled {@link PertSet}, with {@code pertStates} filled in per position
     * @throws RuntimeException if a starting perturbation file is named but cannot be loaded
     */
    public PertSet selectPerturbations(ResidueTermini termini) {
        this.ps = new PertSet();
        if (!this.startingPertFile.equalsIgnoreCase("none") && !this.ps.loadPertFile(this.startingPertFile, false, termini)) {
            throw new RuntimeException("ERROR: Can't find starting perturbation file " + this.startingPertFile);
        }
        if (!this.onlyStarting) {
            this.autogeneratePerturbations();
        }
        this.calcResMovedByPert();
        this.mutateFlexResToGly();
        this.perts = this.ps.makePerturbations(this.strand.mol);
        this.ps.pertStates = new ArrayList<>();
        for (int pos = 0; pos < this.flexibleRes.size(); ++pos) {
            this.ps.pertStates.add(this.enumerateReasonableStates(pos));
        }
        this.removeIncompatiblePertStates(this.ps.pertStates);
        return this.ps;
    }

    /**
     * Enumerates, in lexicographic interval order, every state of the perturbations moving the
     * residue at {@code pos}, keeping those that {@link #isStateReasonable} accepts.
     */
    private ArrayList<ArrayList<int[]>> enumerateReasonableStates(int pos) {
        ArrayList<ArrayList<int[]>> resPertStates = new ArrayList<>();
        ArrayList<int[]> state = new ArrayList<>();
        for (int pertInd : this.pertIndicesForPos(pos)) {
            state.add(new int[]{pertInd, 0});
        }
        while (state != null) {
            if (this.isStateReasonable(state, pos)) {
                resPertStates.add(state);
            }
            state = this.nextPossibleState(state);
        }
        return resPertStates;
    }

    /**
     * Mutates every flexible residue with an amino-acid backbone to GLY. Residues whose
     * {@code fullName} starts with "FOL" are left unchanged.
     * NOTE(review): presumably this keeps side chains from affecting the perturbation checks;
     * confirm intent.
     */
    void mutateFlexResToGly() {
        for (String resNum : this.flexibleRes) {
            Residue res = this.strand.mol.getResByPDBResNumber(resNum);
            if (!HardCodedResidueInfo.hasAminoAcidBB(res) || res.fullName.startsWith("FOL")) {
                continue;
            }
            new ResidueTypeDOF(this.strand.templateLib, res).mutateTo("GLY");
        }
    }

    /**
     * Fills {@link #resMovedByPert}. For each perturbation, the set starts with its own residues.
     * Each later perturbation whose residues intersect the growing set then has its residues
     * added. This is a single forward pass in index order.
     */
    private void calcResMovedByPert() {
        int numPerts = this.ps.pertTypes.size();
        this.resMovedByPert = new ArrayList<>();
        for (int pertIndex = 0; pertIndex < numPerts; ++pertIndex) {
            TreeSet<String> curPertRes = new TreeSet<>(this.ps.resNums.get(pertIndex));
            for (int laterPert = pertIndex + 1; laterPert < numPerts; ++laterPert) {
                ArrayList<String> laterRes = this.ps.resNums.get(laterPert);
                for (String resNum : laterRes) {
                    if (curPertRes.contains(resNum)) {
                        curPertRes.addAll(laterRes);
                        break;
                    }
                }
            }
            this.resMovedByPert.add(curPertRes);
        }
    }

    /**
     * Returns the indices of perturbations that move the flexible residue at {@code pos},
     * in increasing order.
     */
    ArrayList<Integer> pertIndicesForPos(int pos) {
        ArrayList<Integer> ans = new ArrayList<>();
        String resNum = this.flexibleRes.get(pos);
        for (int pertIndex = 0; pertIndex < this.perts.size(); ++pertIndex) {
            if (this.resMovedByPert.get(pertIndex).contains(resNum)) {
                ans.add(pertIndex);
            }
        }
        return ans;
    }

    /** Removes, in place, every state flagged by {@link #getIncompatiblePertStates}. */
    private void removeIncompatiblePertStates(ArrayList<ArrayList<ArrayList<int[]>>> pertStates) {
        boolean[][] prunedStates = this.getIncompatiblePertStates(pertStates);
        for (int pos = 0; pos < pertStates.size(); ++pos) {
            ArrayList<ArrayList<int[]>> posStates = pertStates.get(pos);
            // Remove from the end so earlier indices stay aligned with prunedStates.
            for (int state = posStates.size() - 1; state >= 0; --state) {
                if (prunedStates[pos][state]) {
                    posStates.remove(state);
                }
            }
        }
    }

    /**
     * Flags states to prune, iterating to a fixed point. A state is pruned when some other
     * position has no remaining unpruned state compatible with it.
     *
     * @return {@code pruned[pos][stateIndex]}, aligned with {@code pertStates}
     */
    private boolean[][] getIncompatiblePertStates(ArrayList<ArrayList<ArrayList<int[]>>> pertStates) {
        int numPos = pertStates.size();
        boolean[][] prunedStates = new boolean[numPos][];
        for (int pos = 0; pos < numPos; ++pos) {
            prunedStates[pos] = new boolean[pertStates.get(pos).size()];
        }
        boolean changed = true;
        while (changed) {
            changed = false;
            for (int curPos = 0; curPos < numPos; ++curPos) {
                ArrayList<ArrayList<int[]>> curStates = pertStates.get(curPos);
                for (int curState = 0; curState < curStates.size(); ++curState) {
                    if (prunedStates[curPos][curState]) {
                        continue;
                    }
                    ArrayList<int[]> state1 = curStates.get(curState);
                    for (int altPos = 0; altPos < numPos; ++altPos) {
                        if (altPos != curPos && !this.hasCompatibleState(state1, pertStates.get(altPos), prunedStates[altPos])) {
                            prunedStates[curPos][curState] = true;
                            changed = true;
                            break;
                        }
                    }
                }
            }
        }
        return prunedStates;
    }

    /** Returns true if some unpruned state in {@code altStates} is compatible with {@code state}. */
    private boolean hasCompatibleState(ArrayList<int[]> state, ArrayList<ArrayList<int[]>> altStates, boolean[] altPruned) {
        for (int altState = 0; altState < altStates.size(); ++altState) {
            if (!altPruned[altState] && !this.arePertStatesIncompatible(state, altStates.get(altState))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Two states are incompatible if they assign different interval indices to the same
     * perturbation.
     */
    public boolean arePertStatesIncompatible(ArrayList<int[]> state1, ArrayList<int[]> state2) {
        for (int[] p1 : state1) {
            for (int[] p2 : state2) {
                if (p1[0] == p2[0] && p1[1] != p2[1]) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns a new state following {@code state} in enumeration order, or null when exhausted.
     * If the last feasibility test failed at {@link #failingPertIndex}, every remaining state that
     * shares the failing prefix is skipped. Otherwise the last entry is advanced.
     */
    private ArrayList<int[]> nextPossibleState(ArrayList<int[]> state) {
        if (state.isEmpty()) {
            return null;
        }
        ArrayList<int[]> ans = new ArrayList<>(state.size());
        for (int[] p : state) {
            ans.add(p.clone());
        }
        int pertIndex = (this.failingPertIndex == -1) ? ans.size() - 1 : this.failingPertIndex;
        return this.incrementPertState(ans, pertIndex);
    }

    /**
     * Advances the interval index of entry {@code pertIndex} like an odometer digit. An entry at
     * its last interval carries into the previous entry. Every entry after the one that was
     * advanced is reset to interval 0, so no state with the new prefix is skipped.
     *
     * @return {@code state} (modified in place), or null if no earlier entry can be advanced
     */
    private ArrayList<int[]> incrementPertState(ArrayList<int[]> state, int pertIndex) {
        for (int idx = pertIndex; idx >= 0; --idx) {
            int[] entry = state.get(idx);
            int maxInterval = this.ps.pertIntervals.get(entry[0]).size() - 1;
            if (entry[1] < maxInterval) {
                entry[1]++;
                for (int later = idx + 1; later < state.size(); ++later) {
                    state.get(later)[1] = 0;
                }
                return state;
            }
        }
        return null;
    }

    /**
     * Sets up the default intervals:
     * <ul>
     *   <li>LCA: 16 single-point intervals [k, k] for k = 0..15.</li>
     *   <li>Shear: one symmetric interval [-maxShearParam, maxShearParam].</li>
     *   <li>Backrub: one symmetric interval [-maxBackrubParam, maxBackrubParam].</li>
     * </ul>
     * The symmetric intervals have midpoint 0, so {@link #isStateReasonable} applies no motion
     * for them.
     */
    private void initDefaultIntervals() {
        this.defaultLCAIntervals = new ArrayList<>();
        for (int k = 0; k < 16; ++k) {
            this.defaultLCAIntervals.add(new double[]{k, k});
        }
        this.defaultShearIntervals = new ArrayList<>();
        this.defaultShearIntervals.add(new double[]{-this.maxShearParam, this.maxShearParam});
        this.defaultBackrubIntervals = new ArrayList<>();
        this.defaultBackrubIntervals.add(new double[]{-this.maxBackrubParam, this.maxBackrubParam});
    }

    /**
     * Appends autogenerated perturbations to {@link #ps}, in this order:
     * <ul>
     *   <li>loop closure adjustments over runs of 3 residues (only if {@link #selectLCAs});</li>
     *   <li>shears over runs of 4;</li>
     *   <li>backrubs over runs of 3.</li>
     * </ul>
     */
    private void autogeneratePerturbations() {
        this.initDefaultIntervals();
        if (this.selectLCAs) {
            this.addAutoPerts("LOOP CLOSURE ADJUSTMENT", this.consecutiveFlexibleRes(3, "LOOP CLOSURE ADJUSTMENT"), this.defaultLCAIntervals);
        }
        this.addAutoPerts("SHEAR", this.consecutiveFlexibleRes(4, "SHEAR"), this.defaultShearIntervals);
        this.addAutoPerts("BACKRUB", this.consecutiveFlexibleRes(3, "BACKRUB"), this.defaultBackrubIntervals);
    }

    /**
     * Appends one perturbation of {@code pertType} per residue list. Every appended perturbation
     * shares the same {@code intervals} list and gets null additional info.
     */
    private void addAutoPerts(String pertType, ArrayList<ArrayList<String>> resLists, ArrayList<double[]> intervals) {
        for (ArrayList<String> resNums : resLists) {
            this.ps.pertTypes.add(pertType);
            this.ps.resNums.add(resNums);
            this.ps.pertIntervals.add(intervals);
            this.ps.additionalInfo.add(null);
        }
    }

    /**
     * Finds runs of {@code resCount} consecutive entries of {@link #flexibleRes} that meet two
     * conditions:
     * <ul>
     *   <li>each residue's C atom is bonded to the next residue's N atom;</li>
     *   <li>the secondary structure suits {@code pertType}.</li>
     * </ul>
     * Results are grouped by start position modulo {@code resCount}: runs starting at
     * 0, resCount, 2*resCount, ... come first, and so on.
     * NOTE(review): presumably the grouping puts non-overlapping runs first; confirm intent.
     *
     * @return residue-number lists, one per qualifying run
     */
    ArrayList<ArrayList<String>> consecutiveFlexibleRes(int resCount, String pertType) {
        ArrayList<Integer> startPos = new ArrayList<>();
        for (int pos = 0; pos < this.flexibleRes.size() - resCount + 1; ++pos) {
            if (this.isBondedRun(pos, resCount) && this.secondaryStructureCorrect(this.residueRun(pos, resCount), pertType)) {
                startPos.add(pos);
            }
        }
        ArrayList<ArrayList<String>> ans = new ArrayList<>();
        for (int shift = 0; shift < resCount; ++shift) {
            for (int pos : startPos) {
                if (pos % resCount == shift) {
                    ans.add(new ArrayList<>(this.flexibleRes.subList(pos, pos + resCount)));
                }
            }
        }
        return ans;
    }

    /**
     * Returns true if, for each adjacent pair in the run starting at {@code pos}, the first
     * residue's "C" atom lists the second residue's "N" atom among its bonds.
     */
    private boolean isBondedRun(int pos, int resCount) {
        for (int offset = 1; offset < resCount; ++offset) {
            Residue res1 = this.strand.mol.getResByPDBResNumber(this.flexibleRes.get(pos + offset - 1));
            Residue res2 = this.strand.mol.getResByPDBResNumber(this.flexibleRes.get(pos + offset));
            int cIndex = res1.getAtomIndexByName("C");
            int nIndex = res2.getAtomIndexByName("N");
            if (cIndex == -1 || nIndex == -1) {
                return false;
            }
            Atom c = res1.atoms.get(cIndex);
            Atom n = res2.atoms.get(nIndex);
            if (!c.bonds.contains(n)) {
                return false;
            }
        }
        return true;
    }

    /** Looks up the residues for the {@code resCount} flexible residues starting at {@code pos}. */
    private ArrayList<Residue> residueRun(int pos, int resCount) {
        ArrayList<Residue> run = new ArrayList<>(resCount);
        for (int offset = 0; offset < resCount; ++offset) {
            run.add(this.strand.mol.getResByPDBResNumber(this.flexibleRes.get(pos + offset)));
        }
        return run;
    }

    /**
     * Checks the secondary-structure requirement for a perturbation type:
     * <ul>
     *   <li>SHEAR needs at least 3 HELIX residues.</li>
     *   <li>BACKRUB needs no HELIX residue.</li>
     *   <li>LOOP CLOSURE ADJUSTMENT needs at least 2 LOOP residues.</li>
     * </ul>
     *
     * @throws RuntimeException for any other perturbation type
     */
    private boolean secondaryStructureCorrect(ArrayList<Residue> resList, String pertType) {
        if (pertType.equalsIgnoreCase("SHEAR")) {
            return this.countSecondaryStruct(resList, Residue.SecondaryStructure.HELIX) >= 3;
        }
        if (pertType.equalsIgnoreCase("BACKRUB")) {
            return this.countSecondaryStruct(resList, Residue.SecondaryStructure.HELIX) == 0;
        }
        if (pertType.equalsIgnoreCase("LOOP CLOSURE ADJUSTMENT")) {
            return this.countSecondaryStruct(resList, Residue.SecondaryStructure.LOOP) >= 2;
        }
        throw new RuntimeException("ERROR: Unrecognized perturbation: " + pertType);
    }

    /** Counts residues in {@code resList} whose secondary structure is {@code ss}. */
    private int countSecondaryStruct(ArrayList<Residue> resList, Residue.SecondaryStructure ss) {
        int count = 0;
        for (Residue res : resList) {
            if (res.secondaryStruct == ss) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Tests whether a perturbation state is feasible for the residue at {@code pos}.
     *
     * <p>An all-zero state is always accepted. Otherwise each perturbation's motion is applied in
     * list order at its interval midpoint; midpoints of exactly 0 are skipped. If a motion fails,
     * its index is recorded in {@link #failingPertIndex} so enumeration can skip that prefix.
     * Otherwise the Ramachandran check is applied to the residue at {@code pos}. Coordinates of
     * all flexible residues are restored before returning, including on failure.
     */
    private boolean isStateReasonable(ArrayList<int[]> pertState, int pos) {
        boolean allUnperturbed = true;
        for (int[] p : pertState) {
            if (p[1] != 0) {
                allUnperturbed = false;
                break;
            }
        }
        if (allUnperturbed) {
            this.failingPertIndex = -1;
            return true;
        }
        double[][] backupCoords = this.backupFlexResCoords();
        try {
            for (int pertInd = 0; pertInd < pertState.size(); ++pertInd) {
                int[] p = pertState.get(pertInd);
                double[] interval = this.ps.pertIntervals.get(p[0]).get(p[1]);
                double midVal = 0.5 * (interval[0] + interval[1]);
                if (midVal != 0.0 && !this.perts.get(p[0]).doPerturbationMotion(midVal)) {
                    this.failingPertIndex = pertInd;
                    return false;
                }
            }
            this.failingPertIndex = -1;
            return this.ramaCheck(this.strand.mol.getResByPDBResNumber(this.flexibleRes.get(pos)));
        } finally {
            // Undo any motions already applied (even partial ones) so later tests start clean.
            this.restoreFlexResCoords(backupCoords);
        }
    }

    /** Returns a copy of each flexible residue's {@code coords} array, indexed by position. */
    private double[][] backupFlexResCoords() {
        int numPos = this.flexibleRes.size();
        double[][] backup = new double[numPos][];
        for (int flexRes = 0; flexRes < numPos; ++flexRes) {
            Residue res = this.strand.mol.getResByPDBResNumber(this.flexibleRes.get(flexRes));
            backup[flexRes] = res.coords.clone();
        }
        return backup;
    }

    /** Reinstalls the arrays from {@link #backupFlexResCoords} as each flexible residue's coords. */
    private void restoreFlexResCoords(double[][] backup) {
        for (int flexRes = 0; flexRes < this.flexibleRes.size(); ++flexRes) {
            Residue res = this.strand.mol.getResByPDBResNumber(this.flexibleRes.get(flexRes));
            res.coords = backup[flexRes];
        }
    }

    /**
     * Returns true when the Ramachandran check is disabled. Otherwise returns element 0 of
     * {@code RamachandranChecker.checkByAAType(res)}.
     * NOTE(review): element 0 is presumably the overall "allowed" flag; confirm against
     * RamachandranChecker.
     */
    boolean ramaCheck(Residue res) {
        if (!this.doRamaCheck) {
            return true;
        }
        boolean[] allowed = RamachandranChecker.getInstance().checkByAAType(res);
        return allowed[0];
    }
}

