#-------------------------------------------------------------------------------
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
# USA
# 
# Contact Info:
# 	Bruce Donald
# 	Duke University
# 	Department of Computer Science
# 	Levine Science Research Center (LSRC)
# 	Durham
# 	NC 27708-0129 
# 	USA
# 	brd@cs.duke.edu
# 
# Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
# 
# <signature of Bruce Donald>, April 2011
# Bruce Donald, Professor of Computer Science
#-------------------------------------------------------------------------------


import jvm, share, jdshot, util, math
from Struct import Struct
from figure import Figure

# Start the JVM, passing the "-Xmx2g" option (the standard JVM flag for a 2 GB maximum heap).
jdshot.getJvm().start( "-Xmx2g" )

# Short aliases for classes reached through the share.f / jdshot.f package roots.
# NOTE(review): presumably these lookups need the JVM started above -- confirm before reordering.
SubunitOrder = share.f.protein.SubunitOrder
Vector3 = share.f.geom.Vector3
CnGridPoint = jdshot.f.grid.cn.CnGridPoint


def loadInputs( path ):
	"""
	Read axis positions and the frame transform from a tab-separated text file.

	The file is parsed as sections separated by blank lines. Lines starting
	with "#" are skipped. Every other line is split on tabs and each field is
	converted to a float:

	  section 0: one axis position per line (x, y); z is set to 0.0
	  section 1: rows of the frame rotation (x, y, z); the first three are used
	  section 2: frame translation (x, y, z); the first row is used

	Every blank line advances to the next section, so consecutive blank lines
	skip sections. Lines in later sections are still parsed but not stored.

	Returns a Struct with the fields axisPositions (list of Vector3),
	frameRotation (a share.f.math.Matrix3) and frameTranslation (Vector3).

	Raises ValueError if a field cannot be converted to a float, and
	IndexError if a line has too few fields, if there are fewer than three
	rotation rows, or if there is no translation row.
	"""

	print( "reading inputs from\n\t%s" % path )

	axisPositions = []
	frameRotation = []
	frameTranslation = []
	section = 0

	# try/finally guarantees the file is closed even if parsing fails part way
	inputFile = open( path )
	try:
		for line in inputFile:

			# skip comments
			line = line.strip()
			if line.startswith( "#" ):
				continue

			# a blank line ends the current section
			if line == "":
				section += 1
				continue

			# convert the line to a list of floats
			floats = [float( part ) for part in line.split( '\t' )]

			# store the values according to the section they appear in
			if section == 0:
				axisPositions.append( Vector3( floats[0], floats[1], 0.0 ) )
			elif section == 1:
				frameRotation.append( Vector3( floats[0], floats[1], floats[2] ) )
			elif section == 2:
				frameTranslation.append( Vector3( floats[0], floats[1], floats[2] ) )
	finally:
		inputFile.close()

	print( "\tread %i axis positions" % len( axisPositions ) )

	# assemble the rotation matrix from the first three rows read
	rotationMat = share.f.math.Matrix3()
	rotationMat.setRows( frameRotation[0], frameRotation[1], frameRotation[2] )
	return Struct(
		axisPositions = axisPositions,
		frameRotation = rotationMat,
		frameTranslation = frameTranslation[0]
	)


def analyzeDistance( referenceAxisPosition, computedPositions ):
	"""
	Print the smallest and largest distance between the reference axis
	position and the computed axis positions.

	referenceAxisPosition must provide getSquaredDistance( position ), which
	is called exactly once for each entry of computedPositions.

	Raises ValueError if computedPositions is empty.
	"""

	# compute every squared distance once instead of once for min and again for max
	distsSq = [referenceAxisPosition.getSquaredDistance( position ) for position in computedPositions]
	if not distsSq:
		raise ValueError( "no computed axis positions to compare against the reference axis" )

	# take the square root only of the two extremes
	print( "Min computed axis distance to reference Axis: min=%.2f, max=%.2f\n" % ( math.sqrt( min( distsSq ) ), math.sqrt( max( distsSq ) ) ) )


def analyzeRmsd( pathOut, referenceStructure, referenceAxisStructure, computedStructures ):
	"""
	Print RMSDs against the reference structure and save a histogram of them.

	The RMSD of referenceAxisStructure to referenceStructure is printed first,
	followed by the RMSD of every entry of computedStructures and their
	minimum and maximum. All RMSDs come from share.getRmsd(). A 10-bin
	histogram of the computed RMSDs is then saved to pathOut.

	Raises ValueError (from min()) if computedStructures is empty.
	"""

	print( "Reference Axis Backbone RMSD: %f" % share.getRmsd( referenceStructure, referenceAxisStructure ) )
	rmsds = [share.getRmsd( referenceStructure, structure ) for structure in computedStructures]

	# print out individual RMSDs
	for rmsd in rmsds:
		print( "Computed Backbone RMSD: %f" % rmsd )
	print( "Min: %f" % min( rmsds ) )
	print( "Max: %f" % max( rmsds ) )

	# render the plot
	# NOTE: the bin edges returned by hist() are deliberately not kept; the only
	# consumer was disabled tick/bound customisation that was never switched on
	fig = Figure()
	fig.plot.hist( rmsds, bins=10, facecolor="#004586" )
	fig.plot.set_xlabel( u"RMSD to reference structure in \u212B" )
	fig.plot.set_ylabel( "Number of structures" )

	fig.save( pathOut )
	print( "Write RMSDs figure to:\n\t%s" % pathOut )
	
	
	
def analyzeVariance( referenceStructures, computedStructures ):
	"""
	Print the average variance of the reference structures and of the
	computed structures.

	Each collection is converted with jvm.toArrayList() and handed to
	share.getAverageVariance().
	"""

	# reference ensemble first
	referenceVariance = share.getAverageVariance( jvm.toArrayList( referenceStructures ) )
	print( "avg backbone reference variance: %f" % referenceVariance )

	# then the computed structures
	computedVariance = share.getAverageVariance( jvm.toArrayList( computedStructures ) )
	print( "avg backbone computed variance: %f" % computedVariance )


def analyzeRestraintsForStructure( name, structure, noes ):
	"""
	Print restraint statistics for one structure.

	The noes are mapped onto the structure with
	share.mapDistanceRestraintsToProtein(). The printed line holds the
	restraint RMSD, the number of satisfied restraints out of the total, and
	the percentage satisfied, prefixed with name. When no restraints are
	mapped, the percentage is reported as 0.00 instead of dividing by zero.
	"""

	# presumably setSilence() suppresses console output during the mapping -- confirm in util;
	# the finally clause switches it back off even if the mapping raises
	util.setSilence( True )
	try:
		restraints = share.mapDistanceRestraintsToProtein( noes, structure )
	finally:
		util.setSilence( False )

	rmsd = share.getDistanceRestraintRmsd( structure, restraints )
	numSatisfied = share.getDistanceRestraintNumSatisfied( structure, restraints )
	numRestraints = len( restraints )

	# guard against ZeroDivisionError when nothing was mapped
	if numRestraints > 0:
		percentSatisfied = numSatisfied * 100.0 / numRestraints
	else:
		percentSatisfied = 0.0

	print( "%s NOE RMSD: %f\tNum Satisfied: %d/%d %.2f" % (
		name,
		rmsd,
		numSatisfied,
		numRestraints,
		percentSatisfied
	) )
	
	
	
def analyzeRestraints( referenceStructure, referenceAxisStructure, computedStructures, noes ):
	"""
	Print restraint statistics for the reference structure, the reference
	axis structure and then every computed structure, in that order.

	Each structure is reported by analyzeRestraintsForStructure() using the
	same noes.
	"""

	# the two single structures, each with its own label
	labelledStructures = (
		( "Reference", referenceStructure ),
		( "Reference Axis", referenceAxisStructure ),
	)
	for label, structure in labelledStructures:
		analyzeRestraintsForStructure( label, structure, noes )

	# all computed structures share one label
	for structure in computedStructures:
		analyzeRestraintsForStructure( "Computed", structure, noes )

	
def _buildOligomer( monomerStructure, subunitOrder, axisPosition ):
	"""
	Build the oligomer for one axis position.

	A CnGridPoint is created from the x and y of axisPosition (the remaining
	two constructor arguments are 0.0), jdshot.getOligomer() generates the
	oligomer from the monomer, and subunitOrder.convertComputedToReference()
	is applied to the result.

	Returns the tuple ( gridPoint, oligomer ).
	"""
	gridPoint = CnGridPoint( axisPosition.x, axisPosition.y, 0.0, 0.0 )
	oligomer = jdshot.getOligomer(
		monomerStructure,
		gridPoint,
		subunitOrder.getNumSubunits(),
		jdshot.Cn
	)
	subunitOrder.convertComputedToReference( oligomer )
	return gridPoint, oligomer


def analyze( pathOutMonomer, pathOutPoints, pathOutRmsdsPlot, pathOutViolationOligomer, pathIn, pathMonomer, pathOligomer, pathEnsemble, pathNoes, subunitOrder, referenceAxisPosition, violationAxisPosition, padPercent ):
	"""
	Run the whole post-processing analysis for one set of computed axis positions.

	Parameters:
	  pathOutMonomer            where the transformed monomer is written
	  pathOutPoints             where the grid points of the computed axes are written
	  pathOutRmsdsPlot          where the RMSD histogram is saved
	  pathOutViolationOligomer  where the violation oligomer is written (only
	                            used when violationAxisPosition is not None)
	  pathIn                    input file read by loadInputs()
	  pathMonomer               protein file; its subunit 0 is used as the monomer
	  pathOligomer              protein file of the reference oligomer
	  pathEnsemble              file with the reference ensemble of proteins
	  pathNoes                  distance restraints file
	  subunitOrder              provides getNumSubunits() and convertComputedToReference()
	  referenceAxisPosition     object with x and y; the reference axis position
	  violationAxisPosition     object with x and y, or None to skip that output
	  padPercent                if greater than 0, passed to share.padRestraints()

	Results are printed; nothing is returned.
	"""

	# read inputs
	inputs = loadInputs( pathIn )
	monomerStructure = share.loadProtein( pathMonomer ).getSubunit( 0 )
	oligomerStructure = share.loadProtein( pathOligomer )
	ensembleStructures = share.loadProteins( pathEnsemble )

	# transform the proteins to the new frame (translate first, then rotate)
	ProteinGeometry = share.f.protein.tools.ProteinGeometry
	ProteinGeometry.translate( monomerStructure, inputs.frameTranslation )
	ProteinGeometry.transform( monomerStructure, inputs.frameRotation )
	ProteinGeometry.translate( oligomerStructure, inputs.frameTranslation )
	ProteinGeometry.transform( oligomerStructure, inputs.frameRotation )

	# generate reference axis structure
	_, referenceAxisStructure = _buildOligomer( monomerStructure, subunitOrder, referenceAxisPosition )

	# generate computed structures, keeping the grid point of each one
	computedStructures = []
	gridPoints = []
	for axisPosition in inputs.axisPositions:
		gridPoint, computedStructure = _buildOligomer( monomerStructure, subunitOrder, axisPosition )
		gridPoints.append( gridPoint )
		computedStructures.append( computedStructure )
	print( "Generated %d oligomer structures." % len( computedStructures ) )

	# get backbones
	referenceAxisBackbone = referenceAxisStructure.getBackbone()
	oligomerBackbone = oligomerStructure.getBackbone()
	ensembleBackbones = [protein.getBackbone() for protein in ensembleStructures]
	computedBackbones = [protein.getBackbone() for protein in computedStructures]

	# load the restraints
	noes = share.loadDistanceRestraints( pathNoes )
	share.interpretDistanceRestraintsWithPseudoatoms( noes, oligomerStructure )
	restraints = share.mapDistanceRestraintsToProtein( noes, oligomerStructure )

	# pad the distances if needed
	if padPercent > 0:
		share.padRestraints( restraints, padPercent )

	# stats are printed before and after filtering so the two can be compared;
	# the filtered restraints are then unmapped again for the per-structure analysis
	share.printRestraintsStats( restraints )
	restraints = share.filterRestraintsIntersubunit( restraints )
	share.printRestraintsStats( restraints )
	noes = share.unmapDistanceRestraintsFromProtein( restraints, oligomerStructure )

	# perform all the backbone alignments
	share.alignOptimally( oligomerBackbone, referenceAxisBackbone )
	share.alignEnsembleOptimally( oligomerBackbone, computedBackbones )

	# TEMP
	#share.saveEnsembleKinemage( "/home/jeff/ensemble.kin", computedBackbones, oligomerBackbone )

	# do the analysis
	analyzeDistance( referenceAxisPosition, inputs.axisPositions )
	analyzeRmsd( pathOutRmsdsPlot, oligomerBackbone, referenceAxisBackbone, computedBackbones )
	analyzeVariance( ensembleBackbones, computedBackbones )
	analyzeRestraints( oligomerStructure, referenceAxisStructure, computedStructures, noes )

	# output the transformed version of the monomer and grid points
	share.writeProtein( pathOutMonomer, share.f.protein.Protein( monomerStructure ) )
	jdshot.writePoints( pathOutPoints, gridPoints )

	# if we have a violation position, generate and save the structure
	if violationAxisPosition is not None:
		_, violationStructure = _buildOligomer( monomerStructure, subunitOrder, violationAxisPosition )
		share.writeProtein( pathOutViolationOligomer, violationStructure )

	print( "Done!" )


# set up paths
# NOTE: both are absolute paths under /home/jeff, so they must be adjusted
# before the script can run on another machine.
dirData = "/home/jeff/duke/donaldLab/dshot/input"
dirResults = "/home/jeff/duke/donaldLab/noe+rdc/results"

# GB1
#dirResultsGB1 = dirResults + "/GB1/reassigned"
#analyze(
#	dirResultsGB1 + "/postProcessing/transformedMonomer.protein",
#	dirResultsGB1 + "/postProcessing/out.points",
#	dirResultsGB1 + "/postProcessing/backboneRmsds.pdf",
#	dirResultsGB1 + "/postProcessing/violation.protein",
#	dirResultsGB1 + "/matlabOut.txt",
#	dirData + "/1Q10.monomer.protein",
#	dirData + "/1Q10.oligomer.protein",
#	dirData + "/raw/1Q10.pdb",
#	dirData + "/1Q10.experimental.fixed.noe",
#	SubunitOrder( "AB" ),
#	Vector3( 1.7392, 2.2211, 0.0 ),
#	Vector3( 1.7552, 2.0590, 0.0 ),
#	0
#)

# DAGK
#dirResultsDAGK = dirResults + "/DAGK/annuli/padding=3%"
#analyze(
#	dirResultsDAGK + "/postProcessing/transformedMonomer.protein",
#	dirResultsDAGK + "/postProcessing/out.points",
#	dirResultsDAGK + "/postProcessing/backboneRmsds.pdf",
#	dirResultsDAGK + "/postProcessing/violation.protein",
#	dirResultsDAGK + "/stage2/matlabOut.txt",
#	dirData + "/2KDC.monomer.protein",
#	dirData + "/2KDC.oligomer.protein",
#	dirData + "/raw/2KDC.pdb",
#	dirData + "/2KDC.experimental.noe",
#	SubunitOrder( "ACB" ),
#	Vector3( -10.1733, -8.3759, 0.0 ),
#	Vector3( -10.0840, -8.2633, 0.0 ),
#	0.03
#)

# Active run: DAGK (2KDC input files), stage1 results, using the
# "onlyDisulfideBonds" restraint file and 3% restraint padding.
dirResultsDAGK = dirResults + "/DAGK/annuli/padding=3%"
dirStage1DAGK = dirResultsDAGK + "/stage1"
analyze(
	pathOutMonomer = dirStage1DAGK + "/transformedMonomer.protein",
	pathOutPoints = dirStage1DAGK + "/out.points",
	pathOutRmsdsPlot = dirStage1DAGK + "/backboneRmsds.pdf",
	pathOutViolationOligomer = dirStage1DAGK + "/violation.protein",
	pathIn = dirStage1DAGK + "/matlabOut.txt",
	pathMonomer = dirData + "/2KDC.monomer.protein",
	pathOligomer = dirData + "/2KDC.oligomer.protein",
	pathEnsemble = dirData + "/raw/2KDC.pdb",
	pathNoes = dirData + "/2KDC.experimental.onlyDisulfideBonds.noe",
	subunitOrder = SubunitOrder( "ACB" ),
	referenceAxisPosition = Vector3( -10.1733, -8.3759, 0.0 ),
	violationAxisPosition = Vector3( -10.0840, -8.2633, 0.0 ),
	padPercent = 0.03
)

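# DAGK with PREs
# NOTE(review): this disabled call looks stale -- it passes 10 arguments where
# analyze() takes 13, and it refers to Vector rather than Vector3. Update it
# before re-enabling.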
#analyze(
#	dirResultsDAGK + "/transformedMonomer.withPREs.protein",
#	dirResultsDAGK + "/out.withPREs.points",
#	dirResultsDAGK + "/backboneRmsds.withPREs.pdf",
#	dirResultsDAGK + "/matlabOut.withPREs.txt",
#	dirData + "/2KDC.monomer.protein",
#	dirData + "/2KDC.oligomer.protein",
#	dirData + "/raw/2KDC.pdb",
#	dirData + "/2KDC.experimental.noe",
#	SubunitOrder( "ACB" ),
#	Vector( -6.0843, -1.9681, 0.0 )
#)
