/*
 * This file is part of RDC-ANALYTIC.
 *
 * RDC-ANALYTIC Protein Backbone Structure Determination Software Version 1.0
 * Copyright (C) 2001-2012 Bruce Donald Lab, Duke University
 *
 * RDC-ANALYTIC is free software; you can redistribute it and/or modify it under
 * the terms of the GNU Lesser General Public License as published by the Free
 * Software Foundation, either version 3 of the License, or (at your option) any
 * later version.
 *
 * RDC-ANALYTIC is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, see:
 *     <http://www.gnu.org/licenses/>.
 *
 * There are additional restrictions imposed on the use and distribution of this
 * open-source code, including: (A) this header must be included in any
 * modification or extension of the code; (B) you are required to cite our
 * papers in any publications that use this code. The citation for the various
 * different modules of our software, together with a complete list of
 * requirements and restrictions are found in the document license.pdf enclosed
 * with this distribution.
 *
 * Contact Info:
 *     Bruce R. Donald
 *     Duke University
 *     Department of Computer Science
 *     Levine Science Research Center (LSRC)
 *     Durham, NC 27708-0129
 *     USA
 *     email: www.cs.duke.edu/brd/
 *
 * <signature of Bruce Donald>, August 04, 2012
 * Bruce R. Donald, Professor of Computer Science and Biochemistry
 */

/**
 * @version       1.0.1, August 04, 2012
 * @author        Chittaranjan Tripathy (2007-2012)
 * @email         chittu@cs.duke.edu
 * @organization  Duke University
 */

/**
 * Package specification
 */
package utilities;

/**
 * Import statement(s)
 */
import java.util.*;

import analytic.myProtein;
import analytic.myAtom;
import analytic.myResidue;
import analytic.myVector3D;
import analytic.myPoint;
import analytic.myPair;
import analytic.myPdbParser;
import analytic.Matrix;
import analytic.myMatrix;
import analytic.myMiscUtilities;
import analytic.SingularValueDecomposition;

/**
 * Least-squares superposition of one protein structure onto another, together with the
 * root-mean-square deviation (RMSD) over a chosen set of atoms.
 *
 * <p>The atoms used for the fit are selected by inclusive residue-number intervals and by
 * atom name (for example residues 2-7 and 12-17, atoms N, CA and C). Both selected point sets
 * are moved so that their centroids lie at the origin, the best proper rotation is obtained
 * from the singular value decomposition of their 3x3 correlation matrix (Kabsch algorithm),
 * and the RMSD over the selected atoms is recorded. The rotation and translation that
 * superimpose the second protein onto the first are available through {@link #getRotation()}
 * and {@link #getTranslation()}.
 *
 * <p>Invalid input is reported on standard output and terminates the JVM with
 * {@code System.exit(1)}; the static {@link #getRmsd(myProtein, myProtein, int, int, String[])}
 * helper throws {@link IllegalArgumentException} instead.
 *
 * <p>Command-line usage:
 * <pre>
 * java utilities.StructureAligner -pdbfile file1.pdb -pdbfile file2.pdb -atomtypes N CA C -ranges 2 7 12 17 [-noalign]
 * </pre>
 */
public class StructureAligner {
    /** Reference (first) protein; a private copy. */
    private myProtein __p1 = null;
    /** Protein superimposed onto {@code __p1}; a private copy. */
    private myProtein __p2 = null;
    /** {@code __p2} after applying the fitted rotation and translation. */
    private myProtein __p2Aligned = null;
    /** Flattened inclusive residue intervals: {lb0, ub0, lb1, ub1, ...}. */
    private int[] __range = null;
    /** Names of the atoms used for the superposition. */
    private String[] __atomTypes = null;

    /** Rotation applied to the second protein (identity in no-align mode). */
    private Matrix __rotation = null;
    /** Translation applied after the rotation (zero in no-align mode). */
    private myVector3D __translation = null;
    /** Centroid of the atoms selected from the first protein. */
    private myPoint __centroid1 = null;
    /** Centroid of the atoms selected from the second protein. */
    private myPoint __centroid2 = null;
    /** RMSD over the selected atoms, in the units of the PDB coordinates. */
    private double __rmsd;

    /**
     * Creates an aligner without structures; run it through
     * {@link #parseCommandLineArgumentsAndInvokeStructureAligner(String[])}.
     */
    public StructureAligner() {
    }

    /**
     * Copies both proteins and immediately superimposes {@code p2} onto {@code p1}.
     *
     * @param p1 reference protein
     * @param p2 protein to be superimposed onto {@code p1}
     * @param range flattened inclusive residue intervals {lb0, ub0, lb1, ub1, ...}
     * @param atomTypes names of the atoms used for the fit
     */
    public StructureAligner(myProtein p1, myProtein p2, int[] range, String[] atomTypes) {
        __p1 = new myProtein(p1);
        __p2 = new myProtein(p2);
        __range = range;
        __atomTypes = atomTypes;

        alignAndComputeRmsd(__p1, __p2, __range, __atomTypes);
    }

    /**
     * Collects copies of the atoms of {@code p} whose residue number lies in one of the
     * intervals of {@code range} and whose name is one of {@code atomTypes}.
     *
     * <p>Atoms are ordered by interval (as given), then by residue number, then by the order
     * of {@code atomTypes}, so two proteins containing the same atoms yield lists that
     * correspond element by element. Missing residues and atoms are skipped.
     *
     * @param p protein to select atoms from
     * @param range flattened inclusive residue intervals {lb0, ub0, lb1, ub1, ...}
     * @param atomTypes atom names (surrounding whitespace is ignored)
     * @return copies of the selected atoms
     */
    private Vector<myAtom> getAtomsForSuperposition(myProtein p, int[] range, String[] atomTypes) {
        // Bounds come in pairs; an odd length would index past the end of the array below.
        if (range == null || range.length % 2 != 0) {
            System.out.println("Error: ranges must be given as pairs of lower and upper residue numbers");
            System.exit(1);
        }

        Vector<String> atomTypesChosen = new Vector<>();
        for (String s : atomTypes) {
            atomTypesChosen.add(s.trim());
        }

        List<myPair<Integer, Integer>> ranges = new ArrayList<>();
        for (int i = 0; i < range.length; i += 2) {
            ranges.add(new myPair<Integer, Integer>(range[i], range[i + 1]));
        }

        // Overlapping intervals would select the same atoms more than once.
        if (myMiscUtilities.existsOverlap(ranges)) {
            System.out.println("Error: there exists overlap in the intervals");
            System.exit(1);
        }

        Vector<myAtom> atomSet = new Vector<>();
        for (myPair<Integer, Integer> thisRange : ranges) {
            int lb = thisRange.first();
            int ub = thisRange.second();
            for (int i = lb; i <= ub; i++) {
                myResidue r = p.residueAt(i);
                if (r == null) {
                    continue;
                }
                for (String thisAtomType : atomTypesChosen) {
                    myAtom a = r.getAtom(thisAtomType);
                    if (a != null) {
                        atomSet.add(new myAtom(a));
                    }
                }
            }
        }

        return atomSet;
    }

    /**
     * Compute the geometric center (arithmetic mean) of a collection of points in 3D space.
     * Taken from myNewPacker.
     *
     * @param coll set of points
     * @return the geometric center of the set of points, or {@code null} if {@code coll} is
     *         {@code null} or empty
     */
    public static myPoint computeGeometricCenter(Collection<? extends myPoint> coll) {
        if (coll == null || coll.isEmpty()) {
            return null;
        }
        double sumX = 0.0;
        double sumY = 0.0;
        double sumZ = 0.0;
        for (myPoint p : coll) {
            sumX += p.getX();
            sumY += p.getY();
            sumZ += p.getZ();
        }
        int n = coll.size();
        return new myPoint(sumX / n, sumY / n, sumZ / n);
    }

    /**
     * Translates every point of {@code coll} in place by {@code v}. Taken from myNewPacker.
     *
     * @param coll points to move
     * @param v translation vector
     */
    private void computeTranslation(Collection<? extends myPoint> coll, myVector3D v) {
        if (coll == null || coll.isEmpty() || v == null) {
            System.out.println("Error: computeTranslation failed as one of the input parameters is null or empty");
            System.exit(1);
        }
        for (myPoint p : coll) {
            p.translate(v);
        }
    }

    /**
     * Checks that the two atom lists correspond element by element: same length and, at
     * every index, the same residue number and the same atom name (case-insensitive).
     *
     * @return {@code true} if the lists correspond
     */
    private boolean checkCorrespondence(Vector<myAtom> atomSet1, Vector<myAtom> atomSet2) {
        if (atomSet1.size() != atomSet2.size()) {
            return false;
        }

        for (int i = 0; i < atomSet1.size(); i++) {
            myAtom atom1 = atomSet1.elementAt(i);
            myAtom atom2 = atomSet2.elementAt(i);
            // A mismatch in either the residue number or the atom name breaks the pairing.
            if (atom1.getResidueNumber() != atom2.getResidueNumber()
                    || !atom1.getAtomName().equalsIgnoreCase(atom2.getAtomName())) {
                return false;
            }
        }

        return true;
    }

    /**
     * Exits with an error unless the two atom lists correspond (see
     * {@link #checkCorrespondence(Vector, Vector)}) and are non-empty.
     */
    private void requireCorrespondingAtomSets(Vector<myAtom> atomSet1, Vector<myAtom> atomSet2) {
        if (!checkCorrespondence(atomSet1, atomSet2)) {
            System.out.println("Error: different number of atoms in the two atom sets chosen or the correspondence between atom sets cannot be established due to mismatch in atom types and/or residue numbers");
            System.exit(1);
        }
        // An empty selection has no centroid and no defined RMSD.
        if (atomSet1.isEmpty()) {
            System.out.println("Error: no atoms were selected for superposition; check the ranges and atom types");
            System.exit(1);
        }
        System.out.println("Good correspondence!");
    }

    /** Returns new points holding copies of the coordinates of the given atoms, in order. */
    private Vector<myPoint> getPointSet(Vector<myAtom> atomSet) {
        Vector<myPoint> vp = new Vector<myPoint>();
        for (myAtom a : atomSet) {
            vp.add(new myPoint(a.getCoordinates()));
        }
        return vp;
    }

    /**
     * Computes the proper rotation that best superimposes {@code pointSet2} onto
     * {@code pointSet1} in the least-squares sense (Kabsch algorithm), then records the
     * rotation, the translation, the RMSD and the transformed second protein.
     *
     * <p>Both point sets must already be centered at the origin and {@code __centroid1} /
     * {@code __centroid2} must hold their original centroids (see
     * {@link #superposePointSets(Vector, Vector)}). {@code pointSet2} is rotated in place.
     */
    private void rotationalFit(Vector<myPoint> pointSet1, Vector<myPoint> pointSet2) {
        if (pointSet1.size() != pointSet2.size()) {
            System.out.println("Error: The two point sets are of different sizes");
            System.exit(1);
        }
        int n = pointSet1.size();
        double[][] m1 = new double[n][3];
        double[][] m2 = new double[n][3];
        for (int i = 0; i < n; i++) {
            m1[i] = pointSet1.elementAt(i).getXYZ();
            m2[i] = pointSet2.elementAt(i).getXYZ();
        }

        // 3x3 correlation matrix A = M1^T * M2, with one point per row of M1 and M2.
        Matrix A = new Matrix(m1).transpose().times(new Matrix(m2));

        // A = U * S * V^T: getV() returns V itself, as the reconstruction below assumes.
        SingularValueDecomposition SVD = A.svd();
        Matrix U = SVD.getU();
        Matrix V = SVD.getV();

        try {
            check(A, U.times(SVD.getS().times(V.transpose())));
            System.out.println("Singular Value Decomposition (SVD): succeeds");
        } catch (java.lang.RuntimeException e) {
            System.out.println("Singular Value Decomposition (SVD): fails");
        }

        // The optimal rotation of set 2 onto set 1 is R = U * D * V^T with
        // D = diag(1, 1, d), d = sign(det(U * V^T)). d = -1 replaces the improper solution
        // (a reflection) by the best proper rotation, flipping the axis of the smallest
        // singular value -- the last column of V, assuming the singular values are sorted
        // in decreasing order as in JAMA.
        __rotation = U.times(V.transpose());
        if (Matrix.det(__rotation.getArray()) < 0) {
            Matrix flippedV = (Matrix) V.clone();
            for (int row = 0; row < 3; row++) {
                flippedV.set(row, 2, -flippedV.get(row, 2));
            }
            __rotation = U.times(flippedV.transpose());
        }

        // Translation that moves the rotated centroid of set 2 onto the centroid of set 1.
        myVector3D rotatedCentroid2 = myVector3D.rotate(new myVector3D(__centroid2.getXYZ()), __rotation);
        __translation = new myVector3D(rotatedCentroid2, __centroid1); // (tail, head)

        // Rotate the (already centered) second point set; no translation is needed here.
        Matrix rotMatrix = getRotation();
        for (myPoint p : pointSet2) {
            myVector3D v = myVector3D.rotate(new myVector3D(p.getXYZ()), rotMatrix);
            p.setX(v.getX());
            p.setY(v.getY());
            p.setZ(v.getZ());
        }

        __rmsd = computeRmsd(pointSet1, pointSet2);

        // Apply the full transformation to a copy of the second protein.
        myProtein p2Rot = new myProtein(__p2);
        p2Rot.rotate(new myMatrix(getRotation().getArrayCopy()));
        p2Rot.translate(getTranslation());
        __p2Aligned = p2Rot;
    }

    /**
     * Computes the RMSD between two point sets, pairing points with the same index. No
     * superposition is performed here.
     *
     * @param pointSet1 first point set
     * @param pointSet2 second point set, same size as {@code pointSet1}
     * @return sqrt(mean squared distance between paired points); {@code NaN} for two empty sets
     */
    public double computeRmsd(Vector<myPoint> pointSet1, Vector<myPoint> pointSet2) {
        if (pointSet1.size() != pointSet2.size()) {
            System.out.println("Error: The two point sets must have the same size");
            System.exit(1);
        }

        double sumOfSquares = 0.0;
        for (int i = 0; i < pointSet1.size(); i++) {
            sumOfSquares += myPoint.squaredDist(pointSet1.elementAt(i), pointSet2.elementAt(i));
        }
        return Math.sqrt(sumOfSquares / pointSet1.size());
    }

    /**
     * Selects corresponding atoms from both proteins, superimposes the second selection onto
     * the first and records rotation, translation, RMSD and the aligned second protein.
     */
    private void alignAndComputeRmsd(myProtein p1, myProtein p2, int[] range, String[] atomTypes) {
        Vector<myAtom> atomSet1 = getAtomsForSuperposition(p1, range, atomTypes);
        Vector<myAtom> atomSet2 = getAtomsForSuperposition(p2, range, atomTypes);
        requireCorrespondingAtomSets(atomSet1, atomSet2);

        Vector<myPoint> pointSet1 = getPointSet(atomSet1);
        Vector<myPoint> pointSet2 = getPointSet(atomSet2);

        // Moves both centroids to the origin, which fixes the translational part of the fit.
        superposePointSets(pointSet1, pointSet2);
        rotationalFit(pointSet1, pointSet2);
    }

    /**
     * Computes the RMSD of the selected atoms in the proteins' current coordinate frames,
     * recording an identity rotation, a zero translation and {@code p2} unchanged.
     */
    private void noAlignAndComputeRmsd(myProtein p1, myProtein p2, int[] range, String[] atomTypes) {
        Vector<myAtom> atomSet1 = getAtomsForSuperposition(p1, range, atomTypes);
        Vector<myAtom> atomSet2 = getAtomsForSuperposition(p2, range, atomTypes);
        requireCorrespondingAtomSets(atomSet1, atomSet2);

        Vector<myPoint> pointSet1 = getPointSet(atomSet1);
        Vector<myPoint> pointSet2 = getPointSet(atomSet2);

        __centroid1 = computeGeometricCenter(pointSet1);
        __centroid2 = computeGeometricCenter(pointSet2);
        __translation = new myVector3D(0, 0, 0);
        __rotation = Matrix.identity(3, 3);
        __p2Aligned = p2;
        __rmsd = computeRmsd(pointSet1, pointSet2);
    }

    /**
     * Records the centroids of both point sets and translates each set in place so that its
     * centroid lies at the origin.
     *
     * @param pointSet1 first point set (non-empty)
     * @param pointSet2 second point set (non-empty)
     */
    public void superposePointSets(Vector<myPoint> pointSet1, Vector<myPoint> pointSet2) {
        myPoint geomCenter1 = computeGeometricCenter(pointSet1);
        myPoint geomCenter2 = computeGeometricCenter(pointSet2);
        if (geomCenter1 == null || geomCenter2 == null) {
            System.out.println("Error: superposePointSets requires two non-empty point sets");
            System.exit(1);
        }
        __centroid1 = geomCenter1;
        __centroid2 = geomCenter2;

        computeTranslation(pointSet1, myVector3D.reverse(new myVector3D(geomCenter1.getXYZ())));
        computeTranslation(pointSet2, myVector3D.reverse(new myVector3D(geomCenter2.getXYZ())));
    }

    /**
     * Checks that {@code Y} reproduces {@code X} up to a small multiple of machine epsilon in
     * the 1-norm (used to validate the SVD). Taken from myAlignmentTensorEstimator.java.
     *
     * @param X first matrix
     * @param Y second matrix
     * @throws RuntimeException if the norm of {@code X - Y} is too large
     */
    private void check(Matrix X, Matrix Y) {
        double eps = Math.pow(2.0, -52.0);
        if (X.norm1() == 0. && Y.norm1() < 10 * eps) {
            return;
        }
        if (Y.norm1() == 0. && X.norm1() < 10 * eps) {
            return;
        }
        if (X.minus(Y).norm1() > 1000 * eps * Math.max(X.norm1(), Y.norm1())) {
            throw new RuntimeException("The norm of (X-Y) is too large: " + Double.toString(X.minus(Y).norm1()));
        }
    }

    /** @return a copy of the translation applied after the rotation */
    public myVector3D getTranslation() {
        return new myVector3D(__translation);
    }

    /** @return a copy of the rotation applied to the second protein */
    public Matrix getRotation() {
        return (Matrix) __rotation.clone();
    }

    /** @return a copy of the reference protein */
    public myProtein getFirstProtein() {
        return new myProtein(__p1);
    }

    /** @return a copy of the second protein before superposition */
    public myProtein getSecondProtein() {
        return new myProtein(__p2);
    }

    /** @return a copy of the second protein after superposition */
    public myProtein getSuperimposedSecondProtein() {
        return new myProtein(__p2Aligned);
    }

    /** @return the RMSD over the selected atoms from the last computation */
    public double getRmsd() {
        return __rmsd;
    }

    /**
     * Check for a valid option.
     *
     * @param s string that specifies an option
     * @return true if {@code s} starts with '-'
     */
    private static boolean validOption(String s) {
        return s.startsWith("-");
    }

    /**
     * Reads every protein (model) in {@code pdbFile} and returns the first one, exiting with
     * an error if the file contains none.
     */
    private static myProtein readFirstProtein(String pdbFile) {
        Vector<myProtein> models = new Vector<myProtein>();
        myPdbParser parser = new myPdbParser(pdbFile);
        while (parser.hasNextProtein()) {
            models.add(parser.nextProtein());
        }
        if (models.isEmpty()) {
            System.out.println("Error: no protein could be read from the pdb file " + pdbFile);
            System.exit(1);
        }
        return models.elementAt(0);
    }

    /**
     * Parses the command-line arguments, reads the two PDB files and runs the superposition,
     * or with {@code -noalign} only the RMSD computation.
     *
     * <p>Example: {@code java utilities.StructureAligner -pdbfile file1.pdb -pdbfile file2.pdb
     * -atomtypes N CA C -ranges 2 7 12 17}
     *
     * @param args command-line arguments; every argument must be an option or a value of the
     *             option preceding it
     */
    public void parseCommandLineArgumentsAndInvokeStructureAligner(String[] args) {
        String className = StructureAligner.class.getName();
        List<String> pdbFiles = new ArrayList<String>();
        List<String> atomTypesForSuperposition = new ArrayList<String>();
        List<Integer> ranges = new ArrayList<Integer>();
        boolean align = true;

        if (args.length == 0) {
            System.out.println("To learn more about this command, please type: " + className + " -help");
        }

        int i = 0;
        while (i < args.length) {
            String thisArgument = args[i++];

            if (thisArgument.equals("-help") || thisArgument.equals("-verbose")) {
                String helpString = "Usage: " + className + " <options> \nwhere possible options include: \n"
                        + "-pdbfile <file name>                   Specify the name of the input PDB file (one model and must end with TER \\n END" + '\n'
                        + "-atomtypes                             Specify the atom types to be used for superposition" + '\n'
                        + "-ranges                                Specify the range of residues to be considered" + '\n'
                        + "-noalign                               Specify that the proteins are not to be aligned" + '\n'
                        + "-verbose                               Print a synopsis of standard options and return" + '\n'
                        + "-help                                  Print a synopsis of standard options and return" + '\n';
                System.out.println(helpString);
                System.exit(0);
                return;
            } else if (thisArgument.equals("-pdbfile")) {
                if (i >= args.length) {
                    System.out.println("Error: pdb file name is not supplied");
                    System.exit(1);
                }
                String thisPdbFile = args[i++];
                if (validOption(thisPdbFile)) {
                    System.out.println("Error: incorrect pdb file name or missing argument(s)");
                    System.exit(1);
                }
                pdbFiles.add(thisPdbFile);
            } else if (thisArgument.equals("-noalign")) {
                align = false;
            } else if (thisArgument.equals("-atomtypes")) {
                // Consume values up to the next option or the end of the arguments.
                while (i < args.length && !validOption(args[i])) {
                    atomTypesForSuperposition.add(args[i++]);
                }
                if (atomTypesForSuperposition.isEmpty()) {
                    System.out.println(i < args.length
                            ? "Error: incorrect atom types or missing argument(s)"
                            : "Error: atom types are not supplied");
                    System.exit(1);
                }
            } else if (thisArgument.equals("-ranges")) {
                // Consume residue numbers up to the next option or the end of the arguments.
                while (i < args.length && !validOption(args[i])) {
                    String thisLbOrUb = args[i++];
                    try {
                        ranges.add(Integer.parseInt(thisLbOrUb));
                    } catch (NumberFormatException e) {
                        System.out.println("Error: residue number in ranges is not an integer: " + thisLbOrUb);
                        System.exit(1);
                    }
                }
                if (ranges.isEmpty()) {
                    System.out.println(i < args.length
                            ? "Error: incorrect ranges or missing argument(s)"
                            : "Error: ranges are not supplied");
                    System.exit(1);
                }
            } else {
                // Unknown options, and values not consumed by a preceding option, end up here.
                System.out.println("Error: incorrect argument specification: " + thisArgument);
                System.exit(1);
            }
        }

        if (pdbFiles.size() != 2) {
            System.out.println("Error: exactly two pdb files must be supplied");
            System.exit(1);
        }

        if (ranges.isEmpty()) {
            System.out.println("Error: ranges are not supplied");
            System.exit(1);
        }
        // Ranges are (lower, upper) pairs; an odd count cannot be interpreted.
        if (ranges.size() % 2 != 0) {
            System.out.println("Error: incorrect ranges or missing argument(s)");
            System.exit(1);
        }

        if (atomTypesForSuperposition.isEmpty()) {
            System.out.println("Error: atom types not supplied");
            System.exit(1);
        }

        // The command line arguments are parsed. Now read the structures and do the alignment.
        __p1 = new myProtein(readFirstProtein(pdbFiles.get(0)));
        __p2 = new myProtein(readFirstProtein(pdbFiles.get(1)));

        __range = new int[ranges.size()];
        for (int k = 0; k < ranges.size(); k++) {
            __range[k] = ranges.get(k);
        }
        __atomTypes = atomTypesForSuperposition.toArray(new String[atomTypesForSuperposition.size()]);

        if (align) {
            alignAndComputeRmsd(__p1, __p2, __range, __atomTypes);
        } else {
            noAlignAndComputeRmsd(__p1, __p2, __range, __atomTypes);
        }
    }

    /**
     * Computes the RMSD between two proteins over the given atom types of residues
     * {@code beginResidueNumber..endResidueNumber} (inclusive), in their current coordinate
     * frames (no superposition).
     *
     * @param p1 first protein
     * @param p2 second protein
     * @param beginResidueNumber first residue number (inclusive)
     * @param endResidueNumber last residue number (inclusive)
     * @param atomTypes names of the atoms to compare
     * @return the RMSD over all requested atoms
     * @throws IllegalArgumentException if the residue range or the atom-type list is empty,
     *         or a requested residue or atom is missing from either protein
     */
    public static double getRmsd(myProtein p1, myProtein p2, int beginResidueNumber, int endResidueNumber, String[] atomTypes) {
        if (endResidueNumber < beginResidueNumber || atomTypes.length == 0) {
            throw new IllegalArgumentException("empty residue range [" + beginResidueNumber + ", "
                    + endResidueNumber + "] or no atom types given");
        }

        double sumOfSquares = 0.0;
        for (int i = beginResidueNumber; i <= endResidueNumber; i++) {
            myResidue r1 = p1.residueAt(i);
            myResidue r2 = p2.residueAt(i);
            if (r1 == null || r2 == null) {
                throw new IllegalArgumentException("residue " + i + " is missing from one of the proteins");
            }

            for (String thisAtomType : atomTypes) {
                myAtom a1 = r1.getAtom(thisAtomType);
                myAtom a2 = r2.getAtom(thisAtomType);
                if (a1 == null || a2 == null) {
                    throw new IllegalArgumentException("atom " + thisAtomType + " of residue " + i
                            + " is missing from one of the proteins");
                }
                myPoint point1 = a1.getCoordinates();
                myPoint point2 = a2.getCoordinates();
                sumOfSquares += myPoint.squaredDist(point1, point2);
            }
        }
        int numberOfResidues = endResidueNumber - beginResidueNumber + 1;
        return Math.sqrt(sumOfSquares / (atomTypes.length * numberOfResidues));
    }

    /**
     * Command-line entry point: aligns the two structures given in {@code args} and prints
     * both input proteins, the superimposed second protein, the rotation matrix, the
     * translation, the RMSD and the elapsed wall-clock time.
     *
     * @param args see {@link #parseCommandLineArgumentsAndInvokeStructureAligner(String[])}
     */
    public static void main(String... args) {
        long startTime = System.currentTimeMillis(); // Log the start time

        StructureAligner thisAligner = new StructureAligner();
        thisAligner.parseCommandLineArgumentsAndInvokeStructureAligner(args);

        System.out.println("Printing the first (reference) protein");
        thisAligner.getFirstProtein().print();

        System.out.println("Printing the second (to be aligned) protein before alignment");
        thisAligner.getSecondProtein().print();

        System.out.println("Printing the second protein after alignment");
        thisAligner.getSuperimposedSecondProtein().print();

        System.out.println("Printing the rotation matrix of alignment");
        Matrix rotMatrix = thisAligner.getRotation();
        rotMatrix.print(8, 5);

        System.out.println("Printing the magnitude of translation during the alignment");
        myVector3D translation = thisAligner.getTranslation();
        System.out.println("Magnitude of translation: " + translation.norm() + "    and the translation vector is : " + translation.toString());

        System.out.println("The rmsd is: " + thisAligner.getRmsd());

        long endTime = System.currentTimeMillis(); // Log the end time
        double totalTime = (endTime - startTime) / 60000.0; // in minutes
        System.out.println("Time elapsed: " + totalTime + " minutes");
    }

}


