/*******************************************************************************
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 * 
 * Contact Info:
 * 	Bruce Donald
 * 	Duke University
 * 	Department of Computer Science
 * 	Levine Science Research Center (LSRC)
 * 	Durham
 * 	NC 27708-0129 
 * 	USA
 * 	brd@cs.duke.edu
 * 
 * Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
 * 
 * <signature of Bruce Donald>, April 2011
 * Bruce Donald, Professor of Computer Science
 ******************************************************************************/


package edu.duke.donaldLab.jdshot;

import java.io.File;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

import edu.duke.donaldLab.jdshot.cluster.DistanceBuilder;
import edu.duke.donaldLab.jdshot.cluster.DuplicateFilter;
import edu.duke.donaldLab.jdshot.grid.GridPoint;
import edu.duke.donaldLab.jdshot.grid.PointIterator;
import edu.duke.donaldLab.jdshot.grid.PointWriter;
import edu.duke.donaldLab.jdshot.grid.Symmetry;
import edu.duke.donaldLab.jdshot.search.SearchContext;
import edu.duke.donaldLab.share.clustering.Cluster;
import edu.duke.donaldLab.share.clustering.Clusterer;
import edu.duke.donaldLab.share.clustering.distance.DistanceCluster;
import edu.duke.donaldLab.share.clustering.distance.DistanceClusterer;
import edu.duke.donaldLab.share.clustering.distance.DistanceMatrix;
import edu.duke.donaldLab.share.clustering.distance.DistanceMatrixReader;
import edu.duke.donaldLab.share.clustering.distance.DistanceMatrixWriter;
import edu.duke.donaldLab.share.clustering.medianStrategy.MedianStrategyLongestDimension;
import edu.duke.donaldLab.share.clustering.stopCondition.StopConditionNumClusters;
import edu.duke.donaldLab.share.io.ArgumentType;
import edu.duke.donaldLab.share.io.ArgumentsProcessor;
import edu.duke.donaldLab.share.math.MultiVector;
import edu.duke.donaldLab.share.math.MultiVectorImpl;
import edu.duke.donaldLab.share.perf.Progress;
import edu.duke.donaldLab.share.perf.StreamMessageListener;
import edu.duke.donaldLab.share.perf.Timer;
import edu.duke.donaldLab.share.protein.Subunit;

/**
 * Command-line entry point for the clustering tools.
 * 
 * The "mode" argument selects one of the following operations:
 * euclidean (cluster points down to a fixed number of clusters),
 * rmsd (cluster points using a precomputed distance matrix),
 * matrix (build and write a distance matrix),
 * dupe (filter duplicate points using the monomer),
 * dupeMat (filter duplicate points using a precomputed distance matrix).
 */
public class ClusterMain
{
	/**************************
	 *   Static Methods
	 **************************/
	
	/**
	 * Parses the command-line arguments and dispatches to the operation named by "mode".
	 * The mode name is matched case-insensitively.
	 * 
	 * @param args the raw command-line arguments
	 * @throws IllegalArgumentException if the mode is not one of the supported names
	 * @throws Exception if argument processing or the selected operation fails
	 */
	public static void main( String[] args )
	throws Exception
	{
		// process the arguments
		ArgumentsProcessor argproc = new ArgumentsProcessor();
		argproc.add( "symmetry", ArgumentType.String, "symmetry type (e.g., Cn, Dn)" );
		argproc.add( "mode", ArgumentType.String, "euclidean, rmsd, matrix, dupe, dupeMat" );
		argproc.add( "inPath", ArgumentType.InFile, "path to input points/cells" );
		argproc.add( "outPointsPath", ArgumentType.OutFile, "p", null, "path to the output clutered points" );
		argproc.add( "inMatrixPath", ArgumentType.InFile, "d", null, "path to distance matrix" );
		argproc.add( "outMatrixPath", ArgumentType.OutFile, "D", null, "path to distance matrix" );
		argproc.add( "numClusters", ArgumentType.Integer, "c", "1", "number of points to \"cluster down to\"" );
		argproc.add( "monomerPath", ArgumentType.InFile, "m", null, "path to the monomer protein" );
		argproc.add( "numSubunits", ArgumentType.Integer, "n", "2", "number of subunits" );
		argproc.add( "minRmsd", ArgumentType.Double, "r", "1.0", "minimum RMSD between clusters" );
		// the thread count is read back with getInteger() below, so it must be registered as an Integer
		argproc.add( "numThreads", ArgumentType.Integer, "t", "1", "number of threads to use for computing distance matrix" );
		argproc.process( args );
		
		// read the argument values; which ones are actually required depends on the mode
		Symmetry symmetry = Symmetry.valueOf( argproc.getString( "symmetry" ) );
		String mode = argproc.getString( "mode" );
		File inFile = argproc.getFile( "inPath" );
		File outPointsFile = argproc.getFile( "outPointsPath" );
		File inMatrix = argproc.getFile( "inMatrixPath" );
		File outMatrix = argproc.getFile( "outMatrixPath" );
		Integer numClusters = argproc.getInteger( "numClusters" );
		File monomerFile = argproc.getFile( "monomerPath" );
		Integer numSubunits = argproc.getInteger( "numSubunits" );
		Double minRmsd = argproc.getDouble( "minRmsd" );
		Integer numThreads = argproc.getInteger( "numThreads" );
		
		// start the clustering
		System.out.println( "Clustering starting..." );
		if( mode.equalsIgnoreCase( "euclidean" ) )
		{
			argproc.modeRequire( mode, "outPointsPath" );
			argproc.modeRequire( mode, "numClusters" );
			clusterPointsEuclidean( symmetry, outPointsFile, inFile, numClusters );
		}
		else if( mode.equalsIgnoreCase( "rmsd" ) )
		{
			argproc.modeRequire( mode, "outPointsPath" );
			argproc.modeRequire( mode, "inMatrixPath" );
			argproc.modeRequire( mode, "minRmsd" );
			clusterPointsRmsd( symmetry, outPointsFile, inFile, inMatrix, minRmsd );
		}
		else if( mode.equalsIgnoreCase( "matrix" ) )
		{
			argproc.modeRequire( mode, "outMatrixPath" );
			argproc.modeRequire( mode, "monomerPath" );
			argproc.modeRequire( mode, "numSubunits" );
			buildDistanceMatrix( symmetry, outMatrix, inFile, getMonomer( symmetry, monomerFile ), numSubunits, numThreads );
		}
		else if( mode.equalsIgnoreCase( "dupe" ) )
		{
			argproc.modeRequire( mode, "outPointsPath" );
			argproc.modeRequire( mode, "monomerPath" );
			argproc.modeRequire( mode, "numSubunits" );
			// no distance matrix in this mode: the filter works from the monomer instead
			filterDuplicates( symmetry, outPointsFile, inFile, null, getMonomer( symmetry, monomerFile ), numSubunits );
		}
		else if( mode.equalsIgnoreCase( "dupeMat" ) )
		{
			argproc.modeRequire( mode, "outPointsPath" );
			argproc.modeRequire( mode, "inMatrixPath" );
			// the monomer and subunit count are unused when a distance matrix is supplied
			filterDuplicates( symmetry, outPointsFile, inFile, inMatrix, null, -1 );
		}
		else
		{
			throw new IllegalArgumentException( "Invalid mode: " + mode );
		}
		System.out.println( "Clustering complete!" );
	}
	
	/**
	 * Loads a search context from the given file and returns its monomer.
	 * The last two arguments of SearchContext.load() are passed as null.
	 * 
	 * @param symmetry the symmetry passed to the search context
	 * @param file the file passed to the search context
	 * @return the monomer reported by the loaded search context
	 * @throws Exception if loading the search context fails
	 */
	private static Subunit getMonomer( Symmetry symmetry, File file )
	throws Exception
	{
		// create a search context
		System.out.println( "Loading search context..." );
		SearchContext searchContext = new SearchContext();
		searchContext.load( symmetry, file, null, null );
		System.out.println( "Loaded search context!" );
		return searchContext.getMonomer();
	}
	
	/**
	 * Narrows a point count to an int, failing loudly instead of silently wrapping around.
	 * The points are held in int-indexed lists, so larger counts cannot be handled.
	 * 
	 * @param numPoints the point count reported by a point iterator
	 * @return the same count as an int
	 * @throws IllegalArgumentException if the count is negative or exceeds Integer.MAX_VALUE
	 */
	private static int toPointCount( long numPoints )
	{
		if( numPoints < 0 || numPoints > Integer.MAX_VALUE )
		{
			throw new IllegalArgumentException( "Unsupported number of points: " + numPoints );
		}
		return (int)numPoints;
	}
	
	/**
	 * Drains the given iterator into an in-memory list.
	 * 
	 * @param iterPoint the iterator to read from
	 * @param numPoints the expected number of points, used only to presize the list
	 * @return the points in iteration order
	 * @throws Exception if reading from the iterator fails
	 */
	private static ArrayList<GridPoint> readPoints( PointIterator iterPoint, int numPoints )
	throws Exception
	{
		ArrayList<GridPoint> points = new ArrayList<GridPoint>( numPoints );
		while( iterPoint.hasNext() )
		{
			points.add( iterPoint.next() );
		}
		return points;
	}
	
	/**
	 * Reads a distance matrix from the given file, reporting progress and the elapsed time to stdout.
	 * 
	 * @param inMatrix the distance matrix file
	 * @return the distance matrix that was read
	 * @throws Exception if reading the matrix fails
	 */
	private static DistanceMatrix readDistanceMatrix( File inMatrix )
	throws Exception
	{
		// ALERT
		System.out.println( "Reading distance matrix..." );
		Timer distanceTimer = new Timer( "distance" );
		distanceTimer.start();
		
		DistanceMatrixReader.setProgressListener( new StreamMessageListener() );
		DistanceMatrix distances = DistanceMatrixReader.read( inMatrix );
		
		// ALERT
		distanceTimer.stop();
		System.out.println( distanceTimer.toString() );
		return distances;
	}
	
	/**
	 * Reads all input points into memory, builds their distance matrix and writes it out.
	 * 
	 * @param symmetry the symmetry used to read the points and build the distances
	 * @param outMatrix the file the distance matrix is written to
	 * @param inFile the input points file
	 * @param monomer the monomer passed to the distance builder
	 * @param numSubunits the number of subunits passed to the distance builder
	 * @param numThreads the number of threads passed to the distance builder
	 * @throws IllegalArgumentException if the input has more points than fit in an int
	 * @throws Exception if reading, building or writing fails
	 */
	private static void buildDistanceMatrix( Symmetry symmetry, File outMatrix, File inFile, Subunit monomer, int numSubunits, int numThreads )
	throws Exception
	{
		PointIterator iterPoint = PointIterator.newFromPath( inFile, symmetry );
		int numPoints = toPointCount( iterPoint.getNumPoints() );
		
		// ALERT
		System.out.println( "Reading " + numPoints + " points into memory..." );
		
		// read the points into memory
		ArrayList<GridPoint> points = readPoints( iterPoint, numPoints );
		
		// ALERT
		long numEntries = DistanceMatrix.getSize( numPoints );
		System.out.println( "Allocating distance matrix with " + numEntries + " entries..." );
		// the estimate uses 8 bytes per matrix entry
		double numMib = (double)numEntries * 8.0 / 1024.0 / 1024.0;
		System.out.println( "Estimated memory needed: " + String.format( "%.2f", numMib ) + " MiB" );
		System.out.println( "building distance matrix..." );
		Timer distanceTimer = new Timer( "distance matrix" );
		distanceTimer.start();
		
		DistanceMatrix distances = DistanceBuilder.build( symmetry, points, monomer, numSubunits, numThreads );
		
		// ALERT
		distanceTimer.stop();
		System.out.println( distanceTimer );
		System.out.println( "Entires per second: " + (int)( numEntries / distanceTimer.getElapsedSeconds() ) );
		
		// write out the distance matrix
		DistanceMatrixWriter.write( outMatrix, distances );
	}
	
	/**
	 * Clusters the input points using a precomputed distance matrix and writes
	 * one representative point per cluster.
	 * 
	 * @param symmetry the symmetry used to read the points
	 * @param outPointsFile the file the representative points are written to
	 * @param inFile the input points file
	 * @param inMatrix the distance matrix file; must describe the same number of points as inFile
	 * @param minRmsd the threshold passed to the distance clusterer
	 * @throws IllegalArgumentException if the point and matrix sizes disagree, or the input has more points than fit in an int
	 * @throws Exception if reading, clustering or writing fails
	 */
	private static void clusterPointsRmsd( Symmetry symmetry, File outPointsFile, File inFile, File inMatrix, double minRmsd )
	throws Exception
	{
		// make sure the matrix was built from this set of points before doing any expensive work
		// NOTE: this validates user-supplied files, so it must not depend on assertions being enabled
		int numInPoints = toPointCount( PointIterator.newFromPath( inFile, symmetry ).getNumPoints() );
		int numMatrixPoints = DistanceMatrixReader.getNumPoints( inMatrix );
		if( numInPoints != numMatrixPoints )
		{
			throw new IllegalArgumentException( "in(" + numInPoints + ") != matrix(" + numMatrixPoints + ")" );
		}
		
		DistanceMatrix distances = readDistanceMatrix( inMatrix );
		
		// ALERT
		System.out.println( "Clustering..." );
		Timer clusteringTimer = new Timer( "clustering" );
		clusteringTimer.start();
		
		// perform the clustering
		DistanceClusterer clusterer = new DistanceClusterer( distances, minRmsd );
		clusterer.setProgressListener( new StreamMessageListener() );
		List<DistanceCluster> clusters = clusterer.cluster();
		
		// ALERT
		clusteringTimer.stop();
		System.out.println( clusteringTimer.toString() );
		System.out.println( "before clustering: " + distances.getNumPoints() + " points" );
		System.out.println( "after clustering: " + clusters.size() + " points" );
		System.out.println( "reading in points..." );
		
		// read the points into memory
		PointIterator iterPoint = PointIterator.newFromPath( inFile, symmetry );
		ArrayList<GridPoint> points = readPoints( iterPoint, toPointCount( iterPoint.getNumPoints() ) );
		
		// ALERT
		System.out.println( "writing out points..." );
		Progress writeProgress = new Progress( clusters.size(), 1000 );
		
		// write the representative grid point of each cluster
		PointWriter pointWriter = new PointWriter( outPointsFile );
		try
		{
			for( DistanceCluster cluster : clusters )
			{
				pointWriter.writePoint( points.get( cluster.getReprepresentativeIndex( distances ) ) );
				writeProgress.incrementProgress();
			}
		}
		finally
		{
			// always release the writer, even if a write fails part-way through
			pointWriter.close();
		}
		
		// ALERT
		System.out.println( "DONE!!!" );
	}
	
	/**
	 * Clusters the input points down to a fixed number of clusters and writes
	 * one representative point per cluster.
	 * 
	 * @param symmetry the symmetry used to read the points and create the output grid points
	 * @param outPointsFile the file the representative points are written to
	 * @param inFile the input points file
	 * @param numClusters the number of clusters at which clustering stops
	 * @throws IllegalArgumentException if the input has more points than fit in an int
	 * @throws Exception if reading, clustering or writing fails
	 */
	private static void clusterPointsEuclidean( Symmetry symmetry, File outPointsFile, File inFile, int numClusters )
	throws Exception
	{
		// ALERT
		Timer overallTimer = new Timer( "overall" );
		overallTimer.start();
		Timer prepTimer = new Timer( "prep" );
		prepTimer.start();
		
		PointIterator iterPoint = PointIterator.newFromPath( inFile, symmetry );
		
		// the range check happens before narrowing, so too many points (i.e. > 2bil) fails loudly
		int numPoints = toPointCount( iterPoint.getNumPoints() );
		int numDimensions = symmetry.getNumDimensions();
		
		// ALERT
		System.out.println( "Before clustering: " + numPoints + " points" );
		System.out.println( "Reading in points..." );
		
		// read all the cells into memory and convert them into points
		List<MultiVector> points = new ArrayList<MultiVector>( numPoints );
		Progress readProgress = new Progress( numPoints, 5000 );
		readProgress.setMessageListener( new StreamMessageListener() );
		while( iterPoint.hasNext() )
		{
			GridPoint gridPoint = iterPoint.next();
			MultiVector point = new MultiVectorImpl( numDimensions );
			for( int i=0; i<numDimensions; i++ )
			{
				point.set( i, gridPoint.get( i ) );
			}
			points.add( point );
			
			readProgress.incrementProgress();
		}
		
		System.out.println( "points read!" );
		
		// init the clusterer
		Clusterer clusterer = new Clusterer();
		clusterer.setStopCondition( new StopConditionNumClusters( numClusters ) );
		clusterer.setProgressListener( new StreamMessageListener() );
		
		// ALERT
		prepTimer.stop();
		System.out.println( "prep complete!" );
		System.out.println( prepTimer.toString() );
		System.out.println( "Clustering..." );
		Timer clusteringTimer = new Timer( "clustering" );
		clusteringTimer.start();
		
		// perform the clustering
		List<Cluster> clusters = clusterer.cluster( points, new MedianStrategyLongestDimension() );
		
		// ALERT
		clusteringTimer.stop();
		System.out.println( clusteringTimer.toString() );
		System.out.println( "after clustering: " + clusters.size() + " points" );
		System.out.println( "writing out points..." );
		Timer writeTimer = new Timer( "write" );
		writeTimer.start();
		
		// convert each cluster's representative back into a grid point and write it
		PointWriter pointWriter = new PointWriter( outPointsFile );
		try
		{
			for( Cluster cluster : clusters )
			{
				MultiVector representativePoint = cluster.getRepresentativePoint();
				GridPoint point = symmetry.newGridPoint();
				for( int i=0; i<numDimensions; i++ )
				{
					point.set( i, representativePoint.get( i ) );
				}
				pointWriter.writePoint( point );
			}
		}
		finally
		{
			// always release the writer, even if a write fails part-way through
			pointWriter.close();
		}
		
		// ALERT
		writeTimer.stop();
		System.out.println( "Done writing points!" );
		System.out.println( writeTimer.toString() );
		overallTimer.stop();
		System.out.println( "DONE!!! " );
		System.out.println( overallTimer.toString() );
	}
	
	/**
	 * Filters duplicate points out of the input and writes the remaining points.
	 * 
	 * When inMatrix is non-null the filter is built from that distance matrix and
	 * monomer/numSubunits are not used; otherwise the filter is built from
	 * monomer and numSubunits.
	 * 
	 * @param symmetry the symmetry used to read and filter the points
	 * @param outPointsFile the file the filtered points are written to
	 * @param inFile the input points file
	 * @param inMatrix the distance matrix file, or null to filter without one
	 * @param monomer the monomer used when no distance matrix is given
	 * @param numSubunits the number of subunits used when no distance matrix is given
	 * @throws Exception if reading, filtering or writing fails
	 */
	private static void filterDuplicates( Symmetry symmetry, File outPointsFile, File inFile, File inMatrix, Subunit monomer, int numSubunits )
	throws Exception
	{
		// ALERT
		Timer overallTimer = new Timer( "overall" );
		overallTimer.start();
		
		// read the distance matrix if one is provided
		DistanceMatrix distances = null;
		if( inMatrix != null )
		{
			distances = readDistanceMatrix( inMatrix );
		}
		
		// ALERT
		System.out.println( "reading points..." );
		
		// read the input points
		PointIterator iterPoint = PointIterator.newFromPath( inFile, symmetry );
		LinkedList<GridPoint> points = new LinkedList<GridPoint>();
		while( iterPoint.hasNext() )
		{
			points.add( iterPoint.next() );
		}
		
		// ALERT
		System.out.println( "filtering..." );
		
		// perform the filtering
		DuplicateFilter duplicateFilter = null;
		if( distances != null )
		{
			// with distance matrix
			duplicateFilter = new DuplicateFilter( distances );
		}
		else
		{
			// without distance matrix
			duplicateFilter = new DuplicateFilter( monomer, numSubunits );
		}
		duplicateFilter.setMessageListener( new StreamMessageListener() );
		points = duplicateFilter.filter( symmetry, points );
		
		// ALERT
		System.out.println( "before filtering: " + iterPoint.getNumPoints() + " points" );
		System.out.println( "after filtering: " + points.size() + " points" );
		System.out.println( "writing out points..." );
		
		// write out the results
		PointWriter pointWriter = new PointWriter( outPointsFile );
		try
		{
			for( GridPoint point : points )
			{
				pointWriter.writePoint( point );
			}
		}
		finally
		{
			// always release the writer, even if a write fails part-way through
			pointWriter.close();
		}
		
		// ALERT
		overallTimer.stop();
		System.out.println( "DONE!!! " );
		System.out.println( overallTimer.toString() );
	}
}
