/*******************************************************************************
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 * 
 * Contact Info:
 * 	Bruce Donald
 * 	Duke University
 * 	Department of Computer Science
 * 	Levine Science Research Center (LSRC)
 * 	Durham
 * 	NC 27708-0129 
 * 	USA
 * 	brd@cs.duke.edu
 * 
 * Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
 * 
 * <signature of Bruce Donald>, April 2011
 * Bruce Donald, Professor of Computer Science
 ******************************************************************************/


package edu.duke.donaldLab.share.analysis;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import edu.duke.donaldLab.share.geom.Vector3;
import edu.duke.donaldLab.share.nmr.DistanceRestraint;
import edu.duke.donaldLab.share.nmr.Assignment;
import edu.duke.donaldLab.share.protein.Atom;
import edu.duke.donaldLab.share.protein.AtomAddressInternal;
import edu.duke.donaldLab.share.protein.Element;
import edu.duke.donaldLab.share.protein.Protein;
import edu.duke.donaldLab.share.protein.Residue;
import edu.duke.donaldLab.share.protein.Subunit;

/**
 * Evaluates distance restraints against the atom coordinates of a {@link Protein}
 * and can generate simulated restraints from a structure.
 * 
 * Only the upper bound of a restraint ({@code getMaxDistance()}) is used when
 * measuring violations. The lower bound is never checked.
 */
public class RestraintCalculator
{
	/**************************
	 *   Definitions
	 **************************/
	
	private static final boolean DefaultAddAmbiguityCorrections = false;
	
	// lower bound given to restraints built by the 4-argument getSimulatedRestraints overload
	private static final double DefaultSimulatedMinDistance = 1.8;
	
	
	/**************************
	 *   Data Members
	 **************************/
	
	// when true, each violation is reduced by the spread of the restraint's endpoint atom sets (see getCorrection)
	private final boolean m_addAmbiguityCorrections;
	
	
	/**************************
	 *   Constructors
	 **************************/
	
	/**
	 * Creates a calculator with ambiguity corrections disabled.
	 */
	public RestraintCalculator( )
	{
		this( DefaultAddAmbiguityCorrections );
	}
	
	/**
	 * @param addAmbiguityCorrections if true, violations are reduced by the output of
	 *        {@link #getCorrection(Protein, Set, AtomAddressInternal)} for both endpoints
	 */
	public RestraintCalculator( boolean addAmbiguityCorrections )
	{
		m_addAmbiguityCorrections = addAmbiguityCorrections;
	}
	

	/**************************
	 *   Methods
	 **************************/
	
	/**
	 * Counts the restraints that have at least one assignment with zero violation.
	 * 
	 * @param protein structure that supplies the atom positions
	 * @param restraints restraints to check
	 * @return number of satisfied restraints, in {@code [0, restraints.size()]}
	 */
	public int getNumSatisfied( Protein protein, List<DistanceRestraint<AtomAddressInternal>> restraints )
	{
		int count = 0;
		
		for( DistanceRestraint<AtomAddressInternal> restraint : restraints )
		{
			for( Assignment<AtomAddressInternal> assignment : restraint )
			{
				// getViolation clamps to exactly 0.0 when satisfied, so exact comparison is safe here
				if( getViolation( protein, restraint, assignment ) == 0.0 )
				{
					count++;
					break;
				}
			}
		}
		
		// just in case...
		assert( count <= restraints.size() );
		
		return count;
	}
	
	/**
	 * Computes the root-mean-square of the per-restraint violations. Each restraint
	 * contributes its smallest violation over all of its assignments.
	 * 
	 * @param protein structure that supplies the atom positions
	 * @param restraints restraints to evaluate
	 * @return RMS violation, or 0.0 if {@code restraints} is empty
	 */
	public double getRmsd( Protein protein, List<DistanceRestraint<AtomAddressInternal>> restraints )
	{
		// with no restraints there is nothing violated; avoid the 0/0 = NaN division below
		if( restraints.isEmpty() )
		{
			return 0.0;
		}
		
		double sumSqViolation = 0.0;
		for( DistanceRestraint<AtomAddressInternal> restraint : restraints )
		{
			double violation = getViolation( protein, restraint );
			sumSqViolation += violation * violation;
		}
		
		return Math.sqrt( sumSqViolation / (double)restraints.size() );
	}
	
	/**
	 * Returns the smallest violation of {@code restraint} over all of its assignments.
	 * 
	 * @param protein structure that supplies the atom positions
	 * @param restraint restraint to evaluate
	 * @return minimum non-negative violation, or {@code Double.POSITIVE_INFINITY} if
	 *         the restraint has no assignments
	 */
	public double getViolation( Protein protein, DistanceRestraint<AtomAddressInternal> restraint )
	{
		// the cast picks the filter overload rather than the Assignment overload
		return getViolation( protein, restraint, (AssignmentFilter<AtomAddressInternal>)null );
	}
	
	/**
	 * Returns the smallest violation of {@code restraint} over the assignments
	 * that {@code filter} does not block.
	 * 
	 * @param protein structure that supplies the atom positions
	 * @param restraint restraint to evaluate
	 * @param filter optional filter; null means every assignment is considered
	 * @return minimum non-negative violation. If every assignment is blocked (or none
	 *         exist), this is {@code Double.POSITIVE_INFINITY}. The assert below only
	 *         catches that case when assertions are enabled (-ea).
	 */
	public double getViolation( Protein protein, DistanceRestraint<AtomAddressInternal> restraint, AssignmentFilter<AtomAddressInternal> filter )
	{
		double minViolation = Double.POSITIVE_INFINITY;
		for( Assignment<AtomAddressInternal> assignment : restraint )
		{
			// check the assignment filter if needed
			if( filter != null && filter.filter( restraint, assignment ) == AssignmentFilter.Result.Block )
			{
				continue;
			}
			
			double violation = getViolation( protein, restraint, assignment );
			if( violation < minViolation )
			{
				minViolation = violation;
			}
		}
		assert( minViolation < Double.POSITIVE_INFINITY );
		return minViolation;
	}
	
	/**
	 * Computes how far the distance between the two atoms of {@code assignment}
	 * exceeds the restraint's upper bound. If ambiguity corrections are enabled,
	 * that excess is reduced by the endpoint corrections.
	 * 
	 * @param protein structure that supplies the atom positions
	 * @param restraint restraint that provides the upper bound and endpoint sets
	 * @param assignment the atom pair to measure
	 * @return the violation, clamped to be at least 0.0
	 */
	public double getViolation( Protein protein, DistanceRestraint<AtomAddressInternal> restraint, Assignment<AtomAddressInternal> assignment )
	{
		// get the two atom positions
		Vector3 leftPos = protein.getAtom( assignment.getLeft() ).getPosition();
		Vector3 rightPos = protein.getAtom( assignment.getRight() ).getPosition();
		
		// only the upper bound is checked; being closer than the minimum distance is not a violation here
		double violation = leftPos.getDistance( rightPos ) - restraint.getMaxDistance();
		
		// account for atom distances from the center here if needed
		if( m_addAmbiguityCorrections )
		{
			violation -= getCorrection( protein, restraint.getLefts(), assignment.getLeft() );
			violation -= getCorrection( protein, restraint.getRights(), assignment.getRight() );
		}
		
		return violation > 0.0 ? violation : 0.0;
	}
	
	/**
	 * Computes the spread of an ambiguous endpoint. Only addresses on the same
	 * subunit as {@code target} are considered. The result is the largest distance
	 * from their centroid to any one of them.
	 * 
	 * @param protein structure that supplies the atom positions
	 * @param addresses candidate atoms for one restraint endpoint
	 * @param target assigned atom; only its subunit id is used
	 * @return the max centroid distance, or 0.0 if no address shares the target's subunit
	 */
	public double getCorrection( Protein protein, Set<AtomAddressInternal> addresses, AtomAddressInternal target )
	{
		// filter the list to only the atoms on the same subunit as the target
		TreeSet<AtomAddressInternal> addressesInSameSubunit = new TreeSet<AtomAddressInternal>();
		for( AtomAddressInternal address : addresses )
		{
			if( address.getSubunitId() == target.getSubunitId() )
			{
				addressesInSameSubunit.add( address );
			}
		}
		
		// shortcut; this also avoids scaling the centroid by 1.0 / 0
		if( addressesInSameSubunit.isEmpty() )
		{
			return 0.0;
		}
		
		// find the centroid of the atoms in the endpoints
		Vector3 centroid = new Vector3();
		for( AtomAddressInternal address : addressesInSameSubunit )
		{
			centroid.add( protein.getAtom( address ).getPosition() );
		}
		centroid.scale( 1.0 / addressesInSameSubunit.size() );
		
		// find the max distance to the centroid
		double maxDistSq = 0.0;
		for( AtomAddressInternal address : addressesInSameSubunit )
		{
			double distSq = centroid.getSquaredDistance( protein.getAtom( address ).getPosition() );
			
			if( distSq > maxDistSq )
			{
				maxDistSq = distSq;
			}
		}
		
		return Math.sqrt( maxDistSq );
	}
	
	/**
	 * Builds simulated restraints between hydrogen pairs, using the default lower bound of 1.8.
	 * 
	 * @see #getSimulatedRestraints(int, int, Protein, double, double)
	 */
	public ArrayList<DistanceRestraint<AtomAddressInternal>> getSimulatedRestraints( int leftSubunitId, int rightSubunitId, Protein protein, double maxDistance )
	{
		return getSimulatedRestraints( leftSubunitId, rightSubunitId, protein, maxDistance, DefaultSimulatedMinDistance );
	}
	
	/**
	 * Builds one restraint for every pair of distinct hydrogen atoms (left from
	 * {@code leftSubunitId}, right from {@code rightSubunitId}) whose distance is
	 * strictly less than {@code maxDistance}. Each restraint's upper bound is that
	 * distance rounded up with {@code Math.ceil}.
	 * 
	 * NOTE(review): when both subunit ids are the same, each pair is emitted in both
	 * orders (a,b) and (b,a); only self-pairs are skipped.
	 * 
	 * @param leftSubunitId subunit supplying the left atoms
	 * @param rightSubunitId subunit supplying the right atoms
	 * @param protein structure that supplies the atoms and positions
	 * @param maxDistance distance cutoff for creating a restraint
	 * @param minDistance lower bound assigned to every created restraint
	 * @return the new restraints (mutable list)
	 */
	public ArrayList<DistanceRestraint<AtomAddressInternal>> getSimulatedRestraints( int leftSubunitId, int rightSubunitId, Protein protein, double maxDistance, double minDistance )
	{
		ArrayList<DistanceRestraint<AtomAddressInternal>> restraints = new ArrayList<DistanceRestraint<AtomAddressInternal>>();
		
		double thresholdDistanceSq = maxDistance * maxDistance;
		
		// get our subunits
		Subunit leftSubunit = protein.getSubunit( leftSubunitId );
		Subunit rightSubunit = protein.getSubunit( rightSubunitId );
		
		// for every pair of atoms, create a restraint if their distance is less than some distance
		// this is a simple brute force implementation. Proteins aren't that big
		for( Residue leftResidue : leftSubunit.getResidues() )
		{
			for( Atom leftAtom : leftResidue.getAtoms() )
			{
				// only use hydrogens
				if( leftAtom.getElement() != Element.Hydrogen )
				{
					continue;
				}
				
				// invariant across the inner loops
				Vector3 leftPos = leftAtom.getPosition();
				
				for( Residue rightResidue : rightSubunit.getResidues() )
				{
					for( Atom rightAtom : rightResidue.getAtoms() )
					{
						// only use hydrogens
						if( rightAtom.getElement() != Element.Hydrogen )
						{
							continue;
						}
						
						// skip pairing an atom with itself (only possible when both subunits are the same)
						if( rightAtom == leftAtom ) // yes, compare references
						{
							continue;
						}
						
						// add a new distance restraint if needed
						double distSq = rightAtom.getPosition().getSquaredDistance( leftPos );
						if( distSq < thresholdDistanceSq )
						{
							DistanceRestraint<AtomAddressInternal> restraint = new DistanceRestraint<AtomAddressInternal>();
							restraint.setLefts( new AtomAddressInternal(
								leftSubunit.getId(),
								leftResidue.getId(),
								leftAtom.getId()
							) );
							restraint.setRights( new AtomAddressInternal(
								rightSubunit.getId(),
								rightResidue.getId(),
								rightAtom.getId()
							) );
							restraint.setMinDistance( minDistance );
							restraint.setMaxDistance( Math.ceil( Math.sqrt( distSq ) ) );
							
							restraints.add( restraint );
						}
					}
				}
			}
		}
		
		return restraints;
	}
}
