comparison src/edu/unc/genomics/visualization/KMeans.java @ 2:e16016635b2a

Uploaded
author timpalpant
date Mon, 13 Feb 2012 22:12:06 -0500
parents
children
comparison
equal deleted inserted replaced
1:a54db233ee3d 2:e16016635b2a
1 package edu.unc.genomics.visualization;
2
3 import java.io.BufferedReader;
4 import java.io.BufferedWriter;
5 import java.io.IOException;
6 import java.nio.charset.Charset;
7 import java.nio.file.Files;
8 import java.nio.file.Path;
9 import java.util.ArrayList;
10 import java.util.Arrays;
11 import java.util.HashMap;
12 import java.util.List;
13 import java.util.Map;
14 import java.util.Random;
15
16 import org.apache.commons.lang3.StringUtils;
17 import org.apache.commons.math.stat.clustering.Cluster;
18 import org.apache.commons.math.stat.clustering.KMeansPlusPlusClusterer;
19 import org.apache.log4j.Logger;
20
21 import com.beust.jcommander.Parameter;
22
23 import edu.unc.genomics.CommandLineTool;
24 import edu.unc.genomics.ReadablePathValidator;
25 import edu.unc.genomics.io.IntervalFileSnifferException;
26 import edu.unc.genomics.io.WigFileException;
27
28 public class KMeans extends CommandLineTool {
29
30 private static final Logger log = Logger.getLogger(KMeans.class);
31
32 @Parameter(names = {"-i", "--input"}, description = "Input file (matrix2png format)", required = true, validateWith = ReadablePathValidator.class)
33 public Path inputFile;
34 @Parameter(names = {"-k", "--clusters"}, description = "Number of clusters")
35 public int k = 10;
36 @Parameter(names = {"-1", "--min"}, description = "Minimum column to use for clustering")
37 public int minCol = 1;
38 @Parameter(names = {"-2", "--max"}, description = "Maximum column to use for clustering")
39 public Integer maxCol;
40 @Parameter(names = {"-o", "--output"}, description = "Output file (clustered matrix2png format)", required = true)
41 public Path outputFile;
42
43 private Map<String, String> rows = new HashMap<String, String>();
44 private List<KMeansRow> data = new ArrayList<KMeansRow>();
45
46 @Override
47 public void run() throws IOException {
48 log.debug("Loading data from the input matrix");
49 String headerLine = "";
50 try (BufferedReader reader = Files.newBufferedReader(inputFile, Charset.defaultCharset())) {
51 // Header line
52 int lineNum = 1;
53 headerLine = reader.readLine();
54 int numColsInMatrix = StringUtils.countMatches(headerLine, "\t");
55
56 // Validate the range info
57 if (maxCol != null) {
58 if (maxCol > numColsInMatrix) {
59 throw new RuntimeException("Invalid range of data specified for clustering ("+maxCol+" > "+numColsInMatrix+")");
60 }
61 } else {
62 maxCol = numColsInMatrix;
63 }
64
65 // Loop over the rows and load the data
66 String line;
67 while ((line = reader.readLine()) != null) {
68 lineNum++;
69 if (StringUtils.countMatches(line, "\t") != numColsInMatrix) {
70 throw new RuntimeException("Irregular input matrix does not have same number of columns on line " + lineNum);
71 }
72
73 int delim = line.indexOf('\t');
74 String id = line.substring(0, delim);
75 String[] row = line.substring(delim+1).split("\t");
76 String[] subset = Arrays.copyOfRange(row, minCol, maxCol);
77 float[] rowData = new float[subset.length];
78 for (int i = 0; i < subset.length; i++) {
79 try {
80 rowData[i] = Float.parseFloat(subset[i]);
81 } catch (NumberFormatException e) {
82 rowData[i] = Float.NaN;
83 }
84 }
85 data.add(new KMeansRow(id, rowData));
86 rows.put(id, line);
87 }
88 }
89
90 // Perform the clustering
91 log.debug("Clustering the data");
92 Random rng = new Random();
93 KMeansPlusPlusClusterer<KMeansRow> clusterer = new KMeansPlusPlusClusterer<KMeansRow>(rng);
94 List<Cluster<KMeansRow>> clusters = clusterer.cluster(data, k, 50);
95
96 // Write to output
97 log.debug("Writing clustered data to output file");
98 try (BufferedWriter writer = Files.newBufferedWriter(outputFile, Charset.defaultCharset())) {
99 writer.write(headerLine);
100 writer.newLine();
101 int n = 1;
102 int count = 1;
103 for (Cluster<KMeansRow> cluster : clusters) {
104 int numRowsInCluster = cluster.getPoints().size();
105 int stop = count + numRowsInCluster - 1;
106 log.info("Cluster "+(n++)+": rows "+count+"-"+stop);
107 count = stop+1;
108 for (KMeansRow row : cluster.getPoints()) {
109 writer.write(rows.get(row.getId()));
110 writer.newLine();
111 }
112 }
113 }
114 }
115
116 public static void main(String[] args) throws IOException, WigFileException, IntervalFileSnifferException {
117 new KMeans().instanceMain(args);
118 }
119
120 }