/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.markstar.visualizer;

import edu.duke.cs.osprey.markstar.visualizer.KStarTreeNode;
import edu.duke.cs.osprey.tools.JvmMem;
import edu.duke.cs.osprey.tools.Log;
import edu.duke.cs.osprey.tools.Stopwatch;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

/**
 * Command-line tool that reads a K* tree from a file, reorders its residue layers so the
 * residues appear in a requested order, and writes the reordered tree to another file.
 */
public class KStarTreeSplitter {

    /** Precision used when recomputing bounds of rebuilt nodes after a layer swap. */
    private static final MathContext MATH_CONTEXT = new MathContext(128, RoundingMode.HALF_UP);

    /**
     * Entry point.
     *
     * <p>Expected arguments: {@code inFile outFile resNums}, where {@code resNums} is a
     * comma-separated list of residue numbers giving the desired order of the first
     * {@code resNums.size()} tree layers. Whitespace around each entry is ignored. Every listed
     * residue must occur in the tree, and no residue may be listed twice; otherwise an error is
     * logged and nothing is written.
     *
     * @param args command-line arguments as described above
     * @throws UncheckedIOException if the output file cannot be written
     */
    public static void main(String[] args) {
        File inFile;
        File outFile;
        List<String> resNums;
        try {
            inFile = new File(args[0]);
            outFile = new File(args[1]);
            resNums = parseResNums(args[2]);
        } catch (RuntimeException ex) {
            Log.log("Invalid arguments, caused an exception:");
            ex.printStackTrace(System.out);
            Log.log("expected arguments: inFile outFile resNums");
            Log.log("\twhere resNums is a comma-separated list of residue numbers");
            return;
        }

        Stopwatch readsw = new Stopwatch().start();
        Log.log("reading tree from %s ...", inFile.getAbsolutePath());
        KStarTreeNode root = KStarTreeNode.parseTree(inFile, false, null);
        Log.log("\tdone in %s, used %s", readsw.stop().getTime(2), JvmMem.getOldPool());

        Log.log("target residue order:  %s", resNums);
        List<String> currentResNums = getResNums(root);
        Log.log("current residue order: %s", currentResNums);

        // Without these checks an unknown or repeated residue would be skipped silently
        // (indexOf() returns -1 / an already-placed index), yielding an unexpected order.
        if (new LinkedHashSet<>(resNums).size() != resNums.size()) {
            Log.log("target residue order contains duplicate residues: %s", resNums);
            return;
        }
        for (String resNum : resNums) {
            if (!currentResNums.contains(resNum)) {
                Log.log("residue %s is not in the tree; tree residues are: %s", resNum, currentResNums);
                return;
            }
        }

        Log.log("sorting ...");
        Stopwatch sortsw = new Stopwatch().start();
        sortTreeLayers(root, resNums, currentResNums);
        Log.log("sort finished in %s", sortsw.stop().getTime(2));

        Stopwatch writesw = new Stopwatch().start();
        Log.log("writing tree to %s ...", outFile.getAbsolutePath());
        try (FileWriter writer = new FileWriter(outFile)) {
            root.printTreeLikeMARKStar(writer);
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
        Log.log("\tfinished in %s", writesw.stop().getTime(2));
    }

    /**
     * Splits a comma-separated residue list into its entries, trimming surrounding whitespace.
     *
     * @param arg the raw command-line argument, e.g. {@code "A12,A15, A20"}
     * @return a mutable list of the trimmed entries, in input order
     */
    private static List<String> parseResNums(String arg) {
        List<String> resNums = new ArrayList<>();
        for (String resNum : arg.split(",")) {
            resNums.add(resNum.trim());
        }
        return resNums;
    }

    /**
     * Reorders the tree's layers so that position {@code i} holds {@code targetResNums.get(i)},
     * using adjacent layer swaps (a bubble sort): each target residue is moved up one layer at a
     * time from its current position to its target position.
     *
     * @param root the tree root; modified in place
     * @param targetResNums desired residue order for the first {@code targetResNums.size()} layers;
     *     every entry must be present in {@code currentResNums}
     * @param currentResNums the tree's current residue order; updated in place to track the swaps
     */
    private static void sortTreeLayers(KStarTreeNode root, List<String> targetResNums, List<String> currentResNums) {
        for (int i = 0; i < targetResNums.size(); i++) {
            for (int j = currentResNums.indexOf(targetResNums.get(i)); j > i; j--) {
                Collections.swap(currentResNums, j - 1, j);
                pushUpTreeLayer(root, j);
                Log.log("current residue order: %s, used %s", getResNums(root), JvmMem.getOldPool());
            }
        }
    }

    /**
     * Returns the residue order of the tree, read along the first-child path from the root.
     *
     * <p>For every node on that path that has an assignment index, the part of its assignment
     * string before the first {@code ':'} is recorded; repeats are dropped and first-seen order is
     * kept. NOTE(review): this presumes every branch of the tree uses the same residue order.
     *
     * @param root the tree root
     * @return the residue identifiers in layer order
     */
    private static List<String> getResNums(KStarTreeNode root) {
        Set<String> assignedResNums = new LinkedHashSet<>();
        for (KStarTreeNode node = root; ; node = node.children.get(0)) {
            Integer index = node.getAssignmentIndex();
            if (index != null) {
                assignedResNums.add(node.getAssignments()[index].split(":")[0]);
            }
            if (node.children.isEmpty()) {
                break;
            }
        }
        return new ArrayList<>(assignedResNums);
    }

    /**
     * Swaps residue positions {@code i - 1} and {@code i} in the tree, so the residue at position
     * {@code i} moves up one layer.
     *
     * <p>The swap is applied independently below every node at tree level {@code i - 1}.
     * NOTE(review): this assumes the root is at level 0, so the layer assigning residue position
     * {@code k} is at level {@code k + 1} — confirm against {@code KStarTreeNode}.
     *
     * @param root the tree root; modified in place
     * @param i residue position to push up; must satisfy {@code 1 <= i < root.getConfAssignments().length}
     * @throws IllegalArgumentException if {@code i} is out of range
     * @throws IllegalStateException if the affected layers are not fully expanded, or if the swap
     *     fails to reattach every deeper node
     */
    private static void pushUpTreeLayer(KStarTreeNode root, int i) {
        int depth = i + 1;
        if (depth > root.getConfAssignments().length) {
            throw new IllegalArgumentException("depth too deep");
        }
        if (depth < 2) {
            throw new IllegalArgumentException("depth too shallow");
        }
        for (KStarTreeNode topNode : collectNodesAt(root, depth - 2)) {
            swapChildLayers(topNode);
        }
    }

    /**
     * Swaps the assignment order of {@code topNode}'s children and grandchildren.
     *
     * <p>Old children assign position {@code posim1}, and old grandchildren assign {@code posi}.
     * After the swap, the new children assign {@code posi}, and the new grandchildren assign
     * {@code posim1} and take over the bounds of the old grandchild they replace. The old
     * great-grandchildren are reattached under the new grandchild with the matching pair of
     * assignments. The new children's bounds are then recomputed from their children.
     *
     * @param topNode the node whose two lower layers are swapped; modified in place
     * @throws IllegalStateException if some child of {@code topNode} has no children, or if some
     *     great-grandchild could not be reattached
     */
    private static void swapChildLayers(KStarTreeNode topNode) {
        // Snapshot every (child, grandchild) pair before the tree is mutated.
        List<Edge> edges = new ArrayList<>();
        for (KStarTreeNode nodeim1 : topNode.children) {
            for (KStarTreeNode nodei : nodeim1.children) {
                edges.add(new Edge(nodeim1, nodei));
            }
        }

        for (Edge edge : edges) {
            edge.nodeim1.removeFromParent();
        }
        // Any child still attached had no grandchildren, i.e. the layer is not fully expanded.
        if (!topNode.children.isEmpty()) {
            throw new IllegalStateException("tree levels are not fully expanded, can't sort");
        }

        // Detach the great-grandchildren so they can be re-parented under the rebuilt layers.
        List<KStarTreeNode> nodesip1 = new ArrayList<>();
        for (Edge edge : edges) {
            nodesip1.addAll(edge.nodei.children);
        }
        for (KStarTreeNode nodeip1 : nodesip1) {
            nodeip1.removeFromParent();
        }

        for (Edge edge : edges) {
            int posi = edge.posi;
            int rci = edge.nodei.getConfAssignments()[posi];
            String assignmenti = edge.nodei.getAssignments()[posi];
            // Reuse the new child for this (posi, rci) if an earlier edge already created it;
            // its placeholder bounds are replaced by updateBoundsFromChildren() below.
            KStarTreeNode newNodeim1 = topNode.children.stream()
                .filter(node -> node.getConfAssignments()[posi] == rci)
                .findFirst()
                .orElseGet(() -> topNode.assign(posi, rci, assignmenti,
                    BigDecimal.ZERO, BigDecimal.ZERO, Double.NaN, Double.NaN));

            int posim1 = edge.posim1;
            int rcim1 = edge.nodeim1.getConfAssignments()[posim1];
            String assignmentim1 = edge.nodeim1.getAssignments()[posim1];
            // The new grandchild covers the same pair of assignments as the old one,
            // so it inherits the old grandchild's bounds unchanged.
            KStarTreeNode newNodei = newNodeim1.assign(posim1, rcim1, assignmentim1,
                edge.nodei.getLowerBound(), edge.nodei.getUpperBound(),
                edge.nodei.getConfLowerBound(), edge.nodei.getConfUpperBound());

            Iterator<KStarTreeNode> iter = nodesip1.iterator();
            while (iter.hasNext()) {
                KStarTreeNode nodeip1 = iter.next();
                int[] conf = nodeip1.getConfAssignments();
                if (conf[posim1] == rcim1 && conf[posi] == rci) {
                    iter.remove();
                    newNodei.addChild(nodeip1);
                }
            }
        }

        for (KStarTreeNode nodeim1 : topNode.children) {
            nodeim1.updateBoundsFromChildren(MATH_CONTEXT);
        }

        if (!nodesip1.isEmpty()) {
            throw new IllegalStateException("tree layer swap orphaned some nodes. this is a bug.");
        }
    }

    /**
     * Collects every node whose {@code level} equals {@code depth}, in depth-first,
     * left-to-right order. The search does not descend below a matching node.
     *
     * @param root the node to start from
     * @param depth the level to collect
     * @return the matching nodes (possibly empty)
     */
    private static List<KStarTreeNode> collectNodesAt(KStarTreeNode root, int depth) {
        List<KStarTreeNode> nodes = new ArrayList<>();
        collectNodesAt(root, depth, nodes);
        return nodes;
    }

    /** Recursive worker for {@link #collectNodesAt(KStarTreeNode, int)}; appends matches to {@code out}. */
    private static void collectNodesAt(KStarTreeNode node, int depth, List<KStarTreeNode> out) {
        if (node.level == depth) {
            out.add(node);
            return;
        }
        for (KStarTreeNode child : node.children) {
            collectNodesAt(child, depth, out);
        }
    }

    /**
     * A (child, grandchild) pair captured before a layer swap, together with the assignment
     * positions each node reported at capture time.
     */
    private static final class Edge {
        final KStarTreeNode nodeim1;
        final int posim1;
        final KStarTreeNode nodei;
        final int posi;

        Edge(KStarTreeNode nodeim1, KStarTreeNode nodei) {
            this.nodeim1 = nodeim1;
            this.posim1 = nodeim1.getAssignmentIndex();
            this.nodei = nodei;
            this.posi = nodei.getAssignmentIndex();
        }
    }
}

