diff src/edu/unc/genomics/visualization/KMeans.java @ 2:e16016635b2a

Uploaded
author timpalpant
date Mon, 13 Feb 2012 22:12:06 -0500
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/edu/unc/genomics/visualization/KMeans.java	Mon Feb 13 22:12:06 2012 -0500
@@ -0,0 +1,120 @@
+package edu.unc.genomics.visualization;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.math.stat.clustering.Cluster;
+import org.apache.commons.math.stat.clustering.KMeansPlusPlusClusterer;
+import org.apache.log4j.Logger;
+
+import com.beust.jcommander.Parameter;
+
+import edu.unc.genomics.CommandLineTool;
+import edu.unc.genomics.ReadablePathValidator;
+import edu.unc.genomics.io.IntervalFileSnifferException;
+import edu.unc.genomics.io.WigFileException;
+
+public class KMeans extends CommandLineTool {
+
+	private static final Logger log = Logger.getLogger(KMeans.class);
+
+	@Parameter(names = {"-i", "--input"}, description = "Input file (matrix2png format)", required = true, validateWith = ReadablePathValidator.class)
+	public Path inputFile;
+	@Parameter(names = {"-k", "--clusters"}, description = "Number of clusters")
+	public int k = 10;
+	@Parameter(names = {"-1", "--min"}, description = "Minimum column to use for clustering")
+	public int minCol = 1;
+	@Parameter(names = {"-2", "--max"}, description = "Maximum column to use for clustering")
+	public Integer maxCol;
+	@Parameter(names = {"-o", "--output"}, description = "Output file (clustered matrix2png format)", required = true)
+	public Path outputFile;
+	
+	private Map<String, String> rows = new HashMap<String, String>();
+	private List<KMeansRow> data = new ArrayList<KMeansRow>();
+	
+	@Override
+	public void run() throws IOException {
+		log.debug("Loading data from the input matrix");
+		String headerLine = "";
+		try (BufferedReader reader = Files.newBufferedReader(inputFile, Charset.defaultCharset())) {
+			// Header line
+			int lineNum = 1;
+			headerLine = reader.readLine();
+			int numColsInMatrix = StringUtils.countMatches(headerLine, "\t");
+			
+			// Validate the range info
+			if (maxCol != null) {
+				if (maxCol > numColsInMatrix) {
+					throw new RuntimeException("Invalid range of data specified for clustering ("+maxCol+" > "+numColsInMatrix+")");
+				}
+			} else {
+				maxCol = numColsInMatrix;
+			}
+			
+			// Loop over the rows and load the data
+			String line;
+			while ((line = reader.readLine()) != null) {
+				lineNum++;
+				if (StringUtils.countMatches(line, "\t") != numColsInMatrix) {
+					throw new RuntimeException("Irregular input matrix does not have same number of columns on line " + lineNum);
+				}
+				
+				int delim = line.indexOf('\t');
+				String id = line.substring(0, delim);
+				String[] row = line.substring(delim+1).split("\t");
+				String[] subset = Arrays.copyOfRange(row, minCol, maxCol);
+				float[] rowData = new float[subset.length];
+				for (int i = 0; i < subset.length; i++) {
+					try {
+						rowData[i] = Float.parseFloat(subset[i]);
+					} catch (NumberFormatException e) {
+						rowData[i] = Float.NaN;
+					}
+				}
+				data.add(new KMeansRow(id, rowData));
+				rows.put(id, line);
+			}
+		}
+		
+		// Perform the clustering
+		log.debug("Clustering the data");
+		Random rng = new Random();
+		KMeansPlusPlusClusterer<KMeansRow> clusterer = new KMeansPlusPlusClusterer<KMeansRow>(rng);
+		List<Cluster<KMeansRow>> clusters = clusterer.cluster(data, k, 50);
+		
+		// Write to output
+		log.debug("Writing clustered data to output file");
+		try (BufferedWriter writer = Files.newBufferedWriter(outputFile, Charset.defaultCharset())) {
+			writer.write(headerLine);
+			writer.newLine();
+			int n = 1;
+			int count = 1;
+			for (Cluster<KMeansRow> cluster : clusters) {
+				int numRowsInCluster = cluster.getPoints().size();
+				int stop = count + numRowsInCluster - 1;
+				log.info("Cluster "+(n++)+": rows "+count+"-"+stop);
+				count = stop+1;
+				for (KMeansRow row : cluster.getPoints()) {
+					writer.write(rows.get(row.getId()));
+					writer.newLine();
+				}
+			}
+		}
+	}
+	
+	public static void main(String[] args) throws IOException, WigFileException, IntervalFileSnifferException {
+		new KMeans().instanceMain(args);
+	}
+
+}