/*
 * Decompiled with CFR 0.152.
 */
package edu.duke.cs.osprey.restypes;

import edu.duke.cs.osprey.energy.forcefield.ForcefieldParams;
import edu.duke.cs.osprey.restypes.ResidueTemplate;
import edu.duke.cs.osprey.structure.Atom;
import edu.duke.cs.osprey.structure.Residue;
import edu.duke.cs.osprey.tools.VectorAlgebra;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

public class ResTemplateMatching {
    ResidueTemplate template;
    Residue res;
    ForcefieldParams ffParams;
    int[] matching;
    int[] partialMatching;
    public double score;
    int numAtoms;
    ArrayList<ArrayList<Integer>> templateBonds;
    double[][] residueDistanceMatrix;
    double[][] templateDistanceMatrix;
    int[] templateAtomOrdering;
    int[] templateAtomOrderingRev;
    double[] cumulativeScores;
    boolean use13Distances = false;
    int curAssignmentCount = 0;
    static double timeoutableScore = 0.01;
    int timeoutNumAssignments = 100000000;

    public ResTemplateMatching(Residue res, ResidueTemplate template, ForcefieldParams ffParams) {
        this.res = res;
        this.template = template;
        this.ffParams = ffParams;
        this.score = Double.POSITIVE_INFINITY;
        ArrayList<Atom> templateAtoms = template.templateRes.atoms;
        this.numAtoms = templateAtoms.size();
        if (res.atoms.size() != this.numAtoms) {
            return;
        }
        this.matching = new int[templateAtoms.size()];
        this.partialMatching = new int[templateAtoms.size()];
        Arrays.fill(this.partialMatching, -1);
        this.templateBonds = new ArrayList();
        for (Atom atom1 : template.templateRes.atoms) {
            ArrayList<Integer> atom1Bonds = new ArrayList<Integer>();
            block1: for (int atNum2 = 0; atNum2 < this.numAtoms; ++atNum2) {
                Atom atom2 = template.templateRes.atoms.get(atNum2);
                for (Atom bondedAtom : atom1.bonds) {
                    if (bondedAtom != atom2) continue;
                    atom1Bonds.add(atNum2);
                    continue block1;
                }
            }
            this.templateBonds.add(atom1Bonds);
        }
        if (template.templateRes.coords != null) {
            this.templateDistanceMatrix = template.templateRes.atomDistanceMatrix();
        } else {
            this.templateDistanceMatrix = template.templateRes.estBondDistanceMatrix(ffParams);
            this.use13Distances = false;
        }
        this.residueDistanceMatrix = res.atomDistanceMatrix();
        this.chooseTemplateAtomOrdering();
        this.cumulativeScores = new double[this.numAtoms];
        this.curAssignmentCount = 0;
        this.searchForMatching(0);
    }

    private void searchForMatching(int level) {
        if (this.curAssignmentCount > this.timeoutNumAssignments && this.score < timeoutableScore * (double)this.res.atoms.size()) {
            return;
        }
        double curMatchingScore = this.scoreCurPartialMatching(level);
        if (level == this.numAtoms) {
            if (curMatchingScore < this.score && this.checkStereochemistry()) {
                this.score = curMatchingScore;
                System.arraycopy(this.partialMatching, 0, this.matching, 0, this.numAtoms);
            }
        } else if (curMatchingScore < this.score) {
            String templateAtomName = this.template.templateRes.atoms.get((int)this.templateAtomOrdering[level]).name;
            int nameMatchAtom = this.res.getAtomIndexByName(templateAtomName);
            if (nameMatchAtom > -1) {
                this.tryNextAtomMatch(level, nameMatchAtom);
            }
            for (int resAtNum = 0; resAtNum < this.numAtoms; ++resAtNum) {
                if (resAtNum == nameMatchAtom) continue;
                this.tryNextAtomMatch(level, resAtNum);
            }
        }
    }

    private void tryNextAtomMatch(int level, int resAtNum) {
        int resEleType;
        int templateEleType;
        ++this.curAssignmentCount;
        if (!this.atomAlreadyInPartialMatching(resAtNum, level) && (templateEleType = this.template.templateRes.atoms.get((int)this.templateAtomOrdering[level]).elementNumber) == (resEleType = this.res.atoms.get((int)resAtNum).elementNumber)) {
            this.partialMatching[this.templateAtomOrdering[level]] = resAtNum;
            this.searchForMatching(level + 1);
        }
    }

    boolean atomAlreadyInPartialMatching(int resAtNum, int level) {
        for (int prevLevel = 0; prevLevel < level; ++prevLevel) {
            if (this.partialMatching[this.templateAtomOrdering[prevLevel]] != resAtNum) continue;
            return true;
        }
        return false;
    }

    Collection<Integer> getConstrDistAtoms(int atomNum1) {
        if (this.use13Distances) {
            HashSet<Integer> ans = new HashSet<Integer>();
            for (int bondedAtNum : this.templateBonds.get(atomNum1)) {
                ans.add(bondedAtNum);
                for (int atNum13 : this.templateBonds.get(bondedAtNum)) {
                    if (atNum13 == atomNum1) continue;
                    ans.add(atNum13);
                }
            }
            return ans;
        }
        return this.templateBonds.get(atomNum1);
    }

    double scoreCurPartialMatching(int level) {
        if (level > 0) {
            this.cumulativeScores[level - 1] = 0.0;
            int justAssignedTemplateAtom = this.templateAtomOrdering[level - 1];
            int justAssignedResAtom = this.partialMatching[justAssignedTemplateAtom];
            for (int bondedTemplateAtNum : this.getConstrDistAtoms(justAssignedTemplateAtom)) {
                if (this.templateAtomOrderingRev[bondedTemplateAtNum] >= level) continue;
                int bondedResAtNum = this.partialMatching[bondedTemplateAtNum];
                double resDist = this.residueDistanceMatrix[justAssignedResAtom][bondedResAtNum];
                double templateDist = this.templateDistanceMatrix[justAssignedTemplateAtom][bondedTemplateAtNum];
                if (resDist >= 1.5 * templateDist) {
                    this.cumulativeScores[level - 1] = Double.POSITIVE_INFINITY;
                    break;
                }
                double distDiff = templateDist - resDist;
                int n = level - 1;
                this.cumulativeScores[n] = this.cumulativeScores[n] + distDiff * distDiff;
            }
        }
        double ans = 0.0;
        for (int assignedLevel = 0; assignedLevel < level; ++assignedLevel) {
            ans += this.cumulativeScores[assignedLevel];
        }
        return ans;
    }

    public void assign() {
        this.res.template = this.template;
        ArrayList<Atom> newAtoms = new ArrayList<Atom>();
        for (int atNum = 0; atNum < this.numAtoms; ++atNum) {
            Atom newAtom = this.template.templateRes.atoms.get(atNum).copy();
            newAtom.res = this.res;
            newAtoms.add(newAtom);
        }
        this.res.atoms = newAtoms;
        this.res.markIntraResBondsByTemplate();
        double[] newCoords = new double[3 * this.numAtoms];
        for (int atNum = 0; atNum < this.numAtoms; ++atNum) {
            System.arraycopy(this.res.coords, 3 * this.matching[atNum], newCoords, 3 * atNum, 3);
        }
        this.res.coords = newCoords;
    }

    private void chooseTemplateAtomOrdering() {
        HashMap<Integer, Integer> elementFreqs = new HashMap<Integer, Integer>();
        for (Atom at : this.template.templateRes.atoms) {
            if (elementFreqs.containsKey(at.elementNumber)) {
                elementFreqs.put(at.elementNumber, (Integer)elementFreqs.get(at.elementNumber) + 1);
                continue;
            }
            elementFreqs.put(at.elementNumber, 1);
        }
        ArrayList elementFreqsSorted = new ArrayList(new HashSet(elementFreqs.values()));
        Collections.sort(elementFreqsSorted);
        int count = 0;
        this.templateAtomOrdering = new int[this.numAtoms];
        this.templateAtomOrderingRev = new int[this.numAtoms];
        Iterator iterator2 = elementFreqsSorted.iterator();
        while (iterator2.hasNext()) {
            int elementFreq = (Integer)iterator2.next();
            for (int templateAtNum = 0; templateAtNum < this.numAtoms; ++templateAtNum) {
                int templAtElement = this.template.templateRes.atoms.get((int)templateAtNum).elementNumber;
                if ((Integer)elementFreqs.get(templAtElement) != elementFreq) continue;
                this.templateAtomOrdering[count] = templateAtNum;
                this.templateAtomOrderingRev[templateAtNum] = count++;
            }
        }
        if (count != this.numAtoms) {
            throw new RuntimeException("ERROR: Bug in ResTemplateMatching, lost " + (this.numAtoms - count) + " atoms");
        }
    }

    private boolean checkStereochemistry() {
        if (this.template.templateRes.coords == null) {
            return true;
        }
        for (int atNum = 0; atNum < this.numAtoms; ++atNum) {
            List<Atom> resBondedAtoms;
            boolean resStereoSign;
            Atom templateAtom = this.template.templateRes.atoms.get(atNum);
            if (templateAtom.bonds.size() != 4) continue;
            Atom resAtom = this.res.atoms.get(this.partialMatching[atNum]);
            boolean templateStereoSign = ResTemplateMatching.stereoSign(templateAtom, templateAtom.bonds);
            if (templateStereoSign == (resStereoSign = ResTemplateMatching.stereoSign(resAtom, resBondedAtoms = templateAtom.bonds.stream().map(at -> this.res.atoms.get(this.partialMatching[at.indexInRes])).collect(Collectors.toList())))) continue;
            return false;
        }
        return true;
    }

    private static boolean stereoSign(Atom at, List<Atom> subst) {
        if (subst.size() != 4) {
            throw new RuntimeException("ERROR: don't know how to check stereochemistry with " + subst.size() + " substituents");
        }
        double[] center = at.getCoords();
        List bonds2 = subst.stream().map(bat -> VectorAlgebra.subtract(bat.getCoords(), center)).collect(Collectors.toList());
        double[] normal = VectorAlgebra.cross((double[])bonds2.get(0), (double[])bonds2.get(1));
        return VectorAlgebra.dot(normal, (double[])bonds2.get(2)) > VectorAlgebra.dot(normal, (double[])bonds2.get(3));
    }
}

