/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.control;

import edu.duke.cs.osprey.confspace.SearchProblem;
import edu.duke.cs.osprey.control.ConfigFileParser;
import edu.duke.cs.osprey.control.ParamSet;
import edu.duke.cs.osprey.energy.forcefield.BigForcefieldEnergy;
import edu.duke.cs.osprey.energy.forcefield.ForcefieldParams;
import edu.duke.cs.osprey.gmec.MinimizingConfEnergyCalculator;
import edu.duke.cs.osprey.multistatekstar.InputValidation;
import edu.duke.cs.osprey.multistatekstar.KStarScore;
import edu.duke.cs.osprey.multistatekstar.LMB;
import edu.duke.cs.osprey.multistatekstar.MSConfigFileParser;
import edu.duke.cs.osprey.multistatekstar.MSKStarFactory;
import edu.duke.cs.osprey.multistatekstar.MSKStarTree;
import edu.duke.cs.osprey.multistatekstar.MSSearchProblem;
import edu.duke.cs.osprey.multistatekstar.MSSearchSettings;
import edu.duke.cs.osprey.parallelism.Parallelism;
import edu.duke.cs.osprey.pruning.PruningControl;
import edu.duke.cs.osprey.tools.ObjectIO;
import edu.duke.cs.osprey.tools.Stopwatch;
import edu.duke.cs.osprey.tools.StringParsing;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Locale;
import java.util.StringTokenizer;

public class MSKStarDoer {
    /** Search tree created by {@code subLinearMultiStateSearch()}. */
    MSKStarTree tree;
    /** Number of sequences requested (multistate parameter "NUMSEQS"). */
    int numSeqsWanted;
    /** Maximum number of mutations from wild type allowed when enumerating sequences ("NUMMAXMUT"). */
    int numMaxMut;
    /** Objective function, built from the "OBJFCN" parameter. */
    LMB objFcn;
    /** Multistate constraints, one per "STATECONSTR&lt;i&gt;" parameter. */
    LMB[] msConstr;
    /** Per-state constraints, indexed [state][constraint], built from each state's "UBCONSTR&lt;i&gt;" parameters. */
    LMB[][] sConstr;
    /** Number of states ("NUMSTATES"). */
    int numStates;
    /** Number of mutable positions ("NUMMUTRES"); also the length of every enumerated sequence. */
    int numMutRes;
    /** One wild-type sequence per state, taken over the last substate's residue numbers. */
    ArrayList<String[]> wtSeqs;
    /**
     * Allowed amino-acid types, indexed [state][subState][position] -> list of names.
     * Created empty in the constructor and handed to {@code InputValidation};
     * presumably filled by {@code handleAATypeOptions} -- TODO confirm.
     */
    ArrayList<ArrayList<ArrayList<ArrayList<String>>>> AATypeOptions;
    /**
     * Mutable residue numbers, indexed [state][subState]. Per {@code stateMutableRes},
     * the last substate holds the concatenation of all preceding substates.
     */
    ArrayList<ArrayList<ArrayList<Integer>>> state2MutableResNums;
    /** Argument array used to build each state's config parser, indexed by state. */
    String[][] stateArgs;
    /** Multistate parameter set; read by the per-state setup, scoring and search methods. */
    ParamSet msParams;
    /** One config file parser per state. */
    MSConfigFileParser[] cfps;
    /** Discrete (non-minimized) search problems, indexed [state][subState]. */
    SearchProblem[][] searchDisc;
    /** Continuous (minimized) search problems, indexed [state][subState]. */
    SearchProblem[][] searchCont;
    /** Energy calculators for the continuous search problems, indexed [state][subState]; created lazily by the search methods. */
    MinimizingConfEnergyCalculator[][] ecalcsCont;
    /** Energy calculators for the discrete search problems, indexed [state][subState]; created lazily by the search methods. */
    MinimizingConfEnergyCalculator[][] ecalcsDisc;

    /**
     * Reads the multistate K* settings from {@code cfp}, validates them, loads one
     * config parser per state, and builds the continuous and discrete search
     * problems (energy matrices plus pruning) for every state and substate.
     *
     * @param cfp parser whose {@code params} hold the multistate settings
     * @throws RuntimeException if a state's mutable-residue lists do not match
     *         the "NUMMUTRES" setting (raised by {@code stateMutableRes})
     */
    public MSKStarDoer(ConfigFileParser cfp) {
        // Silence force-field warnings for the whole run.
        BigForcefieldEnergy.ParamInfo.printWarnings = false;
        ForcefieldParams.printWarnings = false;
        System.out.println();
        System.out.println("Performing multistate K*");
        System.out.println();

        // Store the parameters on the FIELD. A shadowing local here leaves
        // this.msParams null, and makeStateCfp() below (as well as the scoring
        // and search methods) dereference this.msParams.
        this.msParams = cfp.params;
        this.msParams.setVerbosity(false);

        this.numSeqsWanted = this.msParams.getInt("NUMSEQS");
        this.numStates = this.msParams.getInt("NUMSTATES");
        this.numMutRes = this.msParams.getInt("NUMMUTRES");
        this.numMaxMut = this.msParams.getInt("NUMMAXMUT");
        int numConstr = this.msParams.getInt("NUMSTATECONSTR");

        this.stateArgs = new String[this.numStates][];
        this.objFcn = new LMB(this.msParams.getValue("OBJFCN"), this.numStates);
        this.msConstr = new LMB[numConstr];
        for (int constr = 0; constr < numConstr; ++constr) {
            this.msConstr[constr] = new LMB(this.msParams.getValue("STATECONSTR" + constr), this.numStates);
        }

        // Per-state containers; the inner arrays are filled in below or lazily by the search methods.
        this.sConstr = new LMB[this.numStates][];
        this.cfps = new MSConfigFileParser[this.numStates];
        this.searchCont = new SearchProblem[this.numStates][];
        this.searchDisc = new SearchProblem[this.numStates][];
        this.ecalcsCont = new MinimizingConfEnergyCalculator[this.numStates][];
        this.ecalcsDisc = new MinimizingConfEnergyCalculator[this.numStates][];

        System.out.println();
        System.out.println("Checking multistate K* parameters for consistency");
        System.out.println();

        this.state2MutableResNums = new ArrayList<>();
        this.AATypeOptions = new ArrayList<>();
        this.wtSeqs = new ArrayList<>();

        // The validator receives the (still empty) lists by reference.
        InputValidation inputValidation = new InputValidation(this.AATypeOptions, this.state2MutableResNums);
        inputValidation.handleObjFcn(this.msParams, this.objFcn);
        inputValidation.handleConstraints(this.msParams, this.msConstr);

        for (int state = 0; state < this.numStates; ++state) {
            System.out.println();
            System.out.println("Checking state " + state + " parameters");
            System.out.println();

            this.cfps[state] = this.makeStateCfp(state);
            ParamSet sParams = this.cfps[state].params;
            inputValidation.handleStateParams(state, sParams, this.msParams);

            ArrayList<ArrayList<Integer>> stateResNums = this.stateMutableRes(state, this.cfps[state], this.numMutRes);
            this.state2MutableResNums.add(stateResNums);
            int lastSubState = stateResNums.size() - 1;
            for (int subState = 0; subState < stateResNums.size(); ++subState) {
                inputValidation.handleAATypeOptions(state, subState, this.cfps[state]);
                if (subState == lastSubState) {
                    // The wild-type sequence is taken over the last substate's residues.
                    this.wtSeqs.add(this.cfps[state].getWtSeq(stateResNums.get(subState)));
                }
            }

            int numUbConstr = sParams.getInt("NUMUBCONSTR");
            int numPartFuncs = sParams.getInt("NUMUBSTATES") + 1;
            this.sConstr[state] = new LMB[numUbConstr];
            for (int constr = 0; constr < numUbConstr; ++constr) {
                this.sConstr[state][constr] = new LMB(sParams.getValue("UBCONSTR" + constr), numPartFuncs);
            }

            System.out.println();
            System.out.println("State " + state + " parameters checked");
            System.out.println();
        }
        this.state2MutableResNums.trimToSize();
        this.wtSeqs.trimToSize();

        System.out.println();
        System.out.println("Preparing search problems and matrices for multistate K*");
        System.out.println();
        for (int state = 0; state < this.numStates; ++state) {
            this.searchCont[state] = this.makeStateSearchProblems(state, true, this.cfps[state]);
            this.searchDisc[state] = this.makeStateSearchProblems(state, false, this.cfps[state]);
            System.out.println();
            System.out.println("State " + state + " matrices ready");
            System.out.println();
        }
    }

    /**
     * Creates energy calculators for every state.
     *
     * @param cont {@code true} for the continuous search problems, {@code false} for the discrete ones
     * @return calculators indexed [state][subState]
     */
    private MinimizingConfEnergyCalculator[][] makeEnergyCalculators(boolean cont) {
        int stateCount = this.numStates;
        MinimizingConfEnergyCalculator[][] perState = new MinimizingConfEnergyCalculator[stateCount][];
        int s = 0;
        while (s < stateCount) {
            perState[s] = this.makeEnergyCalculators(s, cont);
            s++;
        }
        return perState;
    }

    /**
     * Creates one energy calculator per substate of {@code state}.
     * Continuous calculators take their parallelism from the state's config;
     * discrete calculators run on a single CPU.
     *
     * @param state index of the state
     * @param cont  {@code true} for the continuous search problems, {@code false} for the discrete ones
     * @return calculators indexed by substate
     */
    private MinimizingConfEnergyCalculator[] makeEnergyCalculators(int state, boolean cont) {
        SearchProblem[] problems;
        Parallelism parallelism;
        if (cont) {
            problems = this.searchCont[state];
            parallelism = Parallelism.makeFromConfig(this.cfps[state]);
        } else {
            problems = this.searchDisc[state];
            parallelism = Parallelism.makeCpu(1);
        }
        MinimizingConfEnergyCalculator[] calculators = new MinimizingConfEnergyCalculator[problems.length];
        for (int i = 0; i < calculators.length; i++) {
            calculators[i] = MSKStarFactory.makeEnergyCalculator(this.cfps[state], problems[i], parallelism);
        }
        return calculators;
    }

    /**
     * Cleans every non-null calculator of one state, then drops the state's array.
     * Does nothing if the state has no calculators.
     *
     * @param ecalcs calculators indexed [state][subState]
     * @param state  index of the state to release
     */
    private void cleanupEnergyCalculators(MinimizingConfEnergyCalculator[][] ecalcs, int state) {
        MinimizingConfEnergyCalculator[] row = ecalcs[state];
        if (row != null) {
            for (int i = 0; i < row.length; i++) {
                if (row[i] != null) {
                    row[i].clean();
                    row[i] = null;
                }
            }
            ecalcs[state] = null;
        }
    }

    /**
     * Cleans the calculators of every state in {@code ecalcs}.
     * A {@code null} argument is ignored.
     *
     * @param ecalcs calculators indexed [state][subState], or {@code null}
     */
    private void cleanupEnergyCalculators(MinimizingConfEnergyCalculator[][] ecalcs) {
        if (ecalcs != null) {
            int s = 0;
            while (s < ecalcs.length) {
                this.cleanupEnergyCalculators(ecalcs, s);
                s++;
            }
        }
    }

    /** Cleans all continuous energy calculators, then all discrete ones. */
    private void cleanup() {
        MinimizingConfEnergyCalculator[][][] groups = {this.ecalcsCont, this.ecalcsDisc};
        for (MinimizingConfEnergyCalculator[][] group : groups) {
            this.cleanupEnergyCalculators(group);
        }
    }

    /**
     * Computes the K* score of a single sequence in one state and returns its
     * string form.
     *
     * @param state             index of the state
     * @param boundStateAATypes amino-acid type at each position of the last (bound)
     *                          substate, in that substate's residue order
     * @return {@code toString()} of the computed score
     */
    private String getStateKStarScore(int state, ArrayList<String> boundStateAATypes) {
        ArrayList<ArrayList<Integer>> resNumsBySubState = this.state2MutableResNums.get(state);

        // Project the bound-state sequence onto every substate: each position gets
        // a single-element list holding the amino acid assigned to that residue.
        ArrayList<ArrayList<ArrayList<String>>> aaTypesBySubState = new ArrayList<ArrayList<ArrayList<String>>>();
        for (int sub = 0; sub < resNumsBySubState.size(); sub++) {
            ArrayList<ArrayList<String>> perPosition = new ArrayList<ArrayList<String>>();
            aaTypesBySubState.add(perPosition);
            ArrayList<Integer> ownResNums = resNumsBySubState.get(sub);
            ArrayList<Integer> boundResNums = resNumsBySubState.get(resNumsBySubState.size() - 1);
            for (Integer resNum : ownResNums) {
                int posInBound = boundResNums.indexOf(resNum);
                ArrayList<String> single = new ArrayList<String>();
                single.add(boundStateAATypes.get(posInBound));
                perPosition.add(single);
            }
        }

        ParamSet sParams = this.cfps[state].params;
        int numPartFuncs = sParams.getInt("NUMUBSTATES") + 1;
        boolean doMinimize = sParams.getBool("DOMINIMIZE");

        // Only the array matching the minimization mode is allocated; the other stays null.
        MSSearchProblem[] singleSeqSearchCont = null;
        MSSearchProblem[] singleSeqSearchDisc = null;
        if (doMinimize) {
            singleSeqSearchCont = new MSSearchProblem[numPartFuncs];
        } else {
            singleSeqSearchDisc = new MSSearchProblem[numPartFuncs];
        }

        for (int sub = 0; sub < numPartFuncs; sub++) {
            MSSearchSettings settings = new MSSearchSettings();
            settings.AATypeOptions = aaTypesBySubState.get(sub);
            ArrayList<String> resNames = new ArrayList<String>();
            for (int resNum : resNumsBySubState.get(sub)) {
                resNames.add(String.valueOf(resNum));
            }
            settings.mutRes = resNames;
            settings.stericThreshold = sParams.getDouble("STERICTHRESH");
            settings.pruningWindow = sParams.getDouble("IVAL") + sParams.getDouble("EW");
            if (doMinimize) {
                singleSeqSearchCont[sub] = new MSSearchProblem(this.searchCont[state][sub], settings);
                singleSeqSearchCont[sub].setPruningMatrix();
            } else {
                singleSeqSearchDisc[sub] = new MSSearchProblem(this.searchDisc[state][sub], settings);
                singleSeqSearchDisc[sub].setPruningMatrix();
            }
        }

        KStarScore.KStarScoreType scoreType;
        if (doMinimize) {
            scoreType = KStarScore.KStarScoreType.Minimized;
        } else {
            scoreType = KStarScore.KStarScoreType.Discrete;
        }
        KStarScore score = MSKStarFactory.makeKStarScore(this.msParams, state, this.cfps[state], this.sConstr[state], singleSeqSearchCont, singleSeqSearchDisc, this.ecalcsCont[state], this.ecalcsDisc[state], scoreType);
        score.compute(Integer.MAX_VALUE);
        return score.toString();
    }

    /**
     * Builds the search problem for every substate of {@code state}, loads its
     * energy matrix and, unless "UsePoissonBoltzmann" is set, prunes it.
     *
     * @param state    index of the state
     * @param cont     {@code true} for continuous flexibility, {@code false} for discrete
     * @param stateCfp config parser of the state
     * @return search problems indexed by substate ("NUMUBSTATES" + 1 entries)
     */
    private SearchProblem[] makeStateSearchProblems(int state, boolean cont, MSConfigFileParser stateCfp) {
        ParamSet sParams = stateCfp.params;
        int subStateCount = sParams.getInt("NUMUBSTATES") + 1;
        String flexibility;
        if (cont) {
            flexibility = "continuous";
        } else {
            flexibility = "discrete";
        }
        SearchProblem[] problems = new SearchProblem[subStateCount];
        for (int sub = 0; sub < problems.length; sub++) {
            ArrayList<Integer> resNums = this.state2MutableResNums.get(state).get(sub);
            problems[sub] = stateCfp.getSearchProblem(state, sub, resNums, cont);
            problems[sub].loadEnergyMatrix();
            boolean usePoissonBoltzmann = sParams.getBool("UsePoissonBoltzmann");
            if (!usePoissonBoltzmann) {
                double pruningInterval = sParams.getDouble("Ival") + sParams.getDouble("Ew");
                PruningControl pruner = stateCfp.setupPruning(problems[sub], pruningInterval, sParams.getBool("UseEpic"), sParams.getBool("UseTupExp"));
                pruner.setReportMode(null);
                pruner.prune();
            }
            System.out.println();
            System.out.println("State " + state + "." + sub + " " + flexibility + " matrix ready");
            System.out.println();
        }
        return problems;
    }

    /**
     * Reads the mutable residue numbers of each unbound substate from the
     * "UBSTATEMUT&lt;i&gt;" parameters and appends them, in order, to one extra
     * final entry representing the bound state.
     *
     * @param state         index of the state (used only in the error message)
     * @param stateCfp      config parser of the state
     * @param numTreeLevels expected number of mutable positions in the bound state
     * @return residue numbers indexed by substate; the last entry is the bound state
     * @throws RuntimeException if the bound state does not list exactly
     *         {@code numTreeLevels} residues
     */
    private ArrayList<ArrayList<Integer>> stateMutableRes(int state, MSConfigFileParser stateCfp, int numTreeLevels) {
        ParamSet sParams = stateCfp.params;
        int numUbStates = sParams.getInt("NUMUBSTATES");

        ArrayList<ArrayList<Integer>> m2s = new ArrayList<ArrayList<Integer>>();
        for (int ubState = 0; ubState <= numUbStates; ++ubState) {
            m2s.add(new ArrayList<Integer>());
        }

        ArrayList<Integer> boundResNums = m2s.get(numUbStates);
        for (int ubState = 0; ubState < numUbStates; ++ubState) {
            ArrayList<Integer> ubResNums = m2s.get(ubState);
            StringTokenizer st = new StringTokenizer(sParams.getValue("UBSTATEMUT" + ubState));
            while (st.hasMoreTokens()) {
                ubResNums.add(Integer.valueOf(st.nextToken()));
            }
            boundResNums.addAll(ubResNums);
        }
        boundResNums.trimToSize();

        if (boundResNums.size() != numTreeLevels) {
            // Report the number of residues actually listed, not the number of substates.
            throw new RuntimeException("ERROR: SeqTree has " + numTreeLevels + " mutable positions  but " + boundResNums.size() + " are listed for state " + state);
        }
        m2s.trimToSize();
        return m2s;
    }

    /**
     * Builds and loads the config parser for one state from the file names in
     * the multistate parameter "STATECFGFILES&lt;state&gt;" (tokens 1 to 3).
     * The argument array is also recorded in {@code stateArgs[state]}.
     *
     * @param state index of the state
     * @return the loaded parser
     */
    private MSConfigFileParser makeStateCfp(int state) {
        String cfgFiles = this.msParams.getValue("STATECFGFILES" + state);
        String kstFile = StringParsing.getToken(cfgFiles, 1);
        String sysFile = StringParsing.getToken(cfgFiles, 2);
        String deeFile = StringParsing.getToken(cfgFiles, 3);

        String[] args = {"-c", kstFile, "n/a", sysFile, deeFile};
        this.stateArgs[state] = args;

        MSConfigFileParser parser = new MSConfigFileParser(args, false);
        parser.loadData();
        return parser;
    }

    /**
     * Prints, for every state, the number of enumerated sequences followed by
     * one sequence per line (amino acids separated by single spaces).
     */
    protected void printAllSeqs() {
        ArrayList<ArrayList<ArrayList<String>>> seqsByState = this.listAllSeqs();
        int state = 0;
        while (state < seqsByState.size()) {
            ArrayList<ArrayList<String>> seqs = seqsByState.get(state);
            System.out.println();
            System.out.println("State" + state + ": " + seqs.size() + " sequences with <= " + this.numMaxMut + " mutation(s) from wild-type");
            System.out.println();
            for (ArrayList<String> seq : seqs) {
                for (int pos = 0; pos < seq.size(); pos++) {
                    System.out.print(seq.get(pos) + " ");
                }
                System.out.println();
            }
            state++;
        }
    }

    /**
     * Enumerates, for every state, all sequences over the last substate's
     * amino-acid options that are within {@code numMaxMut} mutations of the
     * state's wild-type sequence.
     *
     * @return sequences indexed [state][sequence][position]
     */
    private ArrayList<ArrayList<ArrayList<String>>> listAllSeqs() {
        ArrayList<ArrayList<ArrayList<String>>> seqsByState = new ArrayList<ArrayList<ArrayList<String>>>();
        // Scratch buffer shared by all states; the helper copies it for each finished sequence.
        String[] scratch = new String[this.numMutRes];
        for (int s = 0; s < this.numStates; s++) {
            ArrayList<ArrayList<ArrayList<String>>> optionsBySubState = this.AATypeOptions.get(s);
            ArrayList<ArrayList<String>> lastSubStateOptions = optionsBySubState.get(optionsBySubState.size() - 1);
            ArrayList<ArrayList<String>> seqs = new ArrayList<ArrayList<String>>();
            this.listAllSeqsHelper(lastSubStateOptions, seqs, this.wtSeqs.get(s), scratch, 0, 0);
            seqs.trimToSize();
            seqsByState.add(seqs);
        }
        seqsByState.trimToSize();
        return seqsByState;
    }

    /**
     * Recursive depth-first enumeration used by {@code listAllSeqs()}.
     *
     * @param subStateAATypeOptions allowed amino acids per position
     * @param stateOutput           receives a copy of each completed sequence
     * @param wt                    wild-type amino acid per position
     * @param buf                   working buffer holding the partial sequence
     * @param depth                 position being assigned
     * @param dist                  number of non-wild-type choices made so far
     */
    private void listAllSeqsHelper(ArrayList<ArrayList<String>> subStateAATypeOptions, ArrayList<ArrayList<String>> stateOutput, String[] wt, String[] buf, int depth, int dist) {
        if (depth == this.numMutRes) {
            // All positions assigned: snapshot the buffer.
            ArrayList<String> finished = new ArrayList<String>(Arrays.asList(buf));
            finished.trimToSize();
            stateOutput.add(finished);
            return;
        }
        for (String candidate : subStateAATypeOptions.get(depth)) {
            buf[depth] = candidate;
            int newDist = dist;
            if (!candidate.equalsIgnoreCase(wt[depth])) {
                newDist = dist + 1;
            }
            // Recurse only while the mutation budget is respected.
            if (newDist <= this.numMaxMut) {
                this.listAllSeqsHelper(subStateAATypeOptions, stateOutput, wt, buf, depth + 1, newDist);
            }
        }
    }

    /**
     * Runs the search selected by the multistate parameter "MultStateAlgOption"
     * (case-insensitive): "exhaustive" or "sublinear".
     *
     * @throws UnsupportedOperationException if the option names any other algorithm
     */
    public void calcBestSequences() {
        String algOption = this.msParams.getValue("MultStateAlgOption");
        // Locale.ROOT keeps the comparison independent of the default locale;
        // under e.g. a Turkish locale "EXHAUSTIVE" would not lower-case to "exhaustive".
        switch (algOption.toLowerCase(Locale.ROOT)) {
            case "exhaustive":
                this.exhaustiveMultistateSearch();
                return;
            case "sublinear":
                this.subLinearMultiStateSearch();
                return;
            default:
                throw new UnsupportedOperationException("ERROR: " + algOption + " is not supported for MULTISTATEALGOPTION");
        }
    }

    /**
     * Runs the tree-based multistate K* search: creates the energy calculators,
     * builds the {@code MSKStarTree} and pulls up to {@code numSeqsWanted}
     * sequences from it, stopping early when the tree returns {@code null}.
     * The energy calculators are cleaned up even if the search fails.
     */
    public void subLinearMultiStateSearch() {
        System.out.println();
        System.out.println("Performing sub-linear multistate K*");
        System.out.println();

        Stopwatch stopwatch;
        try {
            for (int state = 0; state < this.numStates; ++state) {
                boolean doMinimize = this.cfps[state].params.getBool("DOMINIMIZE");
                if (doMinimize) {
                    this.ecalcsCont[state] = this.makeEnergyCalculators(state, true);
                }
                // Discrete calculators are created for every state, minimized or not.
                this.ecalcsDisc[state] = this.makeEnergyCalculators(state, false);
            }
            this.tree = new MSKStarTree(this.numMutRes, this.numStates, this.numMaxMut, this.numSeqsWanted, this.objFcn, this.msConstr, this.sConstr, this.state2MutableResNums, this.AATypeOptions, this.wtSeqs, this.searchCont, this.searchDisc, this.ecalcsCont, this.ecalcsDisc, this.msParams, this.cfps);
            stopwatch = new Stopwatch().start();
            for (int seqNum = 0; seqNum < this.numSeqsWanted; ++seqNum) {
                if (this.tree.nextSeq() == null) {
                    break;
                }
            }
        } finally {
            // Release the calculators on both the success and the failure path.
            this.cleanup();
        }

        System.out.println();
        System.out.println("Finished sub-linear multistate K* in " + stopwatch.getTime(2));
        System.out.println();
    }

    /**
     * Scores every enumerated sequence of every state, appending each score to
     * "sequences-exhaustive.txt" and finally printing all scores to stdout.
     * When the multistate parameter "RESUME" is set, sequences already present
     * in the output file are skipped; otherwise the file is deleted first.
     * The output stream is closed and the energy calculators are cleaned up
     * even if scoring fails.
     */
    private void exhaustiveMultistateSearch() {
        System.out.println();
        System.out.println("Checking MultiStateKStar by exhaustive search");
        System.out.println();
        Stopwatch stopwatch = new Stopwatch().start();

        ArrayList<ArrayList<ArrayList<String>>> seqList = this.listAllSeqs();
        String fname = "sequences-exhaustive.txt";
        boolean resume = this.msParams.getBool("RESUME");
        if (resume) {
            // Drop sequences that a previous run already scored.
            ArrayList<ArrayList<ArrayList<String>>> completed = this.getCompletedSeqs(fname);
            for (int state = 0; state < this.numStates; ++state) {
                for (ArrayList<String> seq : completed.get(state)) {
                    seqList.get(state).remove(seq);
                }
            }
        }

        String[][] stateKSS = new String[this.numStates][];
        for (int state = 0; state < this.numStates; ++state) {
            stateKSS[state] = new String[seqList.get(state).size()];
        }

        try {
            if (!resume) {
                ObjectIO.delete(fname);
            }
            // Append mode so a resumed run extends the existing file.
            try (PrintStream fout = new PrintStream(new FileOutputStream(new File(fname), true))) {
                for (int state = 0; state < this.numStates; ++state) {
                    boolean doMinimize = this.cfps[state].params.getBool("DOMINIMIZE");
                    if (stateKSS[state].length > 0) {
                        fout.println();
                        fout.println("State" + state + ": ");
                        fout.println();
                    }
                    if (doMinimize) {
                        this.ecalcsCont[state] = this.makeEnergyCalculators(state, true);
                    } else {
                        this.ecalcsDisc[state] = this.makeEnergyCalculators(state, false);
                    }
                    for (int seqNum = 0; seqNum < stateKSS[state].length; ++seqNum) {
                        stateKSS[state][seqNum] = this.getStateKStarScore(state, seqList.get(state).get(seqNum));
                        fout.println(stateKSS[state][seqNum]);
                    }
                    // This state is finished: release its search problems and calculators.
                    this.searchCont[state] = null;
                    this.searchDisc[state] = null;
                    this.cleanupEnergyCalculators(this.ecalcsCont, state);
                    this.cleanupEnergyCalculators(this.ecalcsDisc, state);
                }
            }
        }
        catch (FileNotFoundException e) {
            // Best effort: report the failure and still print whatever was computed.
            e.printStackTrace();
        }
        finally {
            this.cleanup();
        }

        this.printAllKStarScores(stateKSS);
        System.out.println();
        System.out.println("Finished checking MultiStateKStar by exhaustive search in " + stopwatch.getTime(2));
        System.out.println();
    }

    /**
     * Prints the score strings of every state to stdout; states with at least
     * one score get a "State&lt;n&gt;: " header surrounded by blank lines.
     *
     * @param stateKSS score strings indexed [state][sequence]
     */
    private void printAllKStarScores(String[][] stateKSS) {
        for (int s = 0; s < this.numStates; s++) {
            String[] scores = stateKSS[s];
            if (scores.length > 0) {
                System.out.println();
                System.out.println("State" + s + ": ");
                System.out.println();
            }
            for (String score : scores) {
                System.out.println(score);
            }
        }
    }

    /**
     * Parses a previously written results file and collects the sequences
     * that already appear in it, grouped by state.
     * <p>
     * Lines starting with "State" select the current state; in lines starting
     * with "Seq", every ':'-separated token containing "score" is stripped of
     * "score" and commas and split on whitespace, keeping the part of each
     * word before the first '-'.
     *
     * @param fname path of the results file
     * @return sequences indexed [state][sequence][position]; exactly
     *         {@code numStates} (possibly empty) entries. If the file is
     *         missing, all entries are empty; if reading fails, the entries
     *         parsed so far are returned.
     */
    private ArrayList<ArrayList<ArrayList<String>>> getCompletedSeqs(String fname) {
        ArrayList<ArrayList<ArrayList<String>>> ans = new ArrayList<ArrayList<ArrayList<String>>>();
        // One (initially empty) bucket per state, added exactly once.
        for (int state = 0; state < this.numStates; ++state) {
            ans.add(new ArrayList<ArrayList<String>>());
        }
        if (!new File(fname).exists()) {
            return ans;
        }
        try (BufferedReader br = new BufferedReader(new FileReader(fname))) {
            String line;
            int currentState = 0;
            while ((line = br.readLine()) != null) {
                if (line.length() == 0) continue;
                if (line.startsWith("State")) {
                    // Header of the form "State<n>: " -> n
                    line = line.toLowerCase();
                    line = line.replace("state", "");
                    line = line.replace(":", "").trim();
                    currentState = Integer.valueOf(line);
                    continue;
                }
                if (!line.startsWith("Seq")) continue;
                StringTokenizer st = new StringTokenizer(line, ":");
                while (st.hasMoreTokens()) {
                    String token = st.nextToken();
                    if (!token.contains("score")) continue;
                    token = token.replace("score", "");
                    token = token.replace(",", "");
                    token = token.trim();
                    StringTokenizer words = new StringTokenizer(token);
                    ArrayList<String> seq = new ArrayList<String>();
                    while (words.hasMoreTokens()) {
                        // Keep only the part before the first '-'.
                        seq.add(words.nextToken().split("-")[0]);
                    }
                    seq.trimToSize();
                    ans.get(currentState).add(seq);
                }
            }
        }
        catch (IOException e) {
            // Best effort: report the failure and return what was read so far.
            e.printStackTrace();
        }
        return ans;
    }
}

