2
|
1 package edu.unc.genomics.visualization;
|
|
2
|
|
3 import java.io.BufferedReader;
|
|
4 import java.io.BufferedWriter;
|
|
5 import java.io.IOException;
|
|
6 import java.nio.charset.Charset;
|
|
7 import java.nio.file.Files;
|
|
8 import java.nio.file.Path;
|
|
9 import java.util.ArrayList;
|
|
10 import java.util.Arrays;
|
|
11 import java.util.HashMap;
|
|
12 import java.util.List;
|
|
13 import java.util.Map;
|
|
14 import java.util.Random;
|
|
15
|
|
16 import org.apache.commons.lang3.StringUtils;
|
|
17 import org.apache.commons.math.stat.clustering.Cluster;
|
|
18 import org.apache.commons.math.stat.clustering.KMeansPlusPlusClusterer;
|
|
19 import org.apache.log4j.Logger;
|
|
20
|
|
21 import com.beust.jcommander.Parameter;
|
|
22
|
|
23 import edu.unc.genomics.CommandLineTool;
|
|
24 import edu.unc.genomics.ReadablePathValidator;
|
|
25 import edu.unc.genomics.io.IntervalFileSnifferException;
|
|
26 import edu.unc.genomics.io.WigFileException;
|
|
27
|
|
28 public class KMeans extends CommandLineTool {
|
|
29
|
|
30 private static final Logger log = Logger.getLogger(KMeans.class);
|
|
31
|
|
32 @Parameter(names = {"-i", "--input"}, description = "Input file (matrix2png format)", required = true, validateWith = ReadablePathValidator.class)
|
|
33 public Path inputFile;
|
|
34 @Parameter(names = {"-k", "--clusters"}, description = "Number of clusters")
|
|
35 public int k = 10;
|
|
36 @Parameter(names = {"-1", "--min"}, description = "Minimum column to use for clustering")
|
|
37 public int minCol = 1;
|
|
38 @Parameter(names = {"-2", "--max"}, description = "Maximum column to use for clustering")
|
|
39 public Integer maxCol;
|
|
40 @Parameter(names = {"-o", "--output"}, description = "Output file (clustered matrix2png format)", required = true)
|
|
41 public Path outputFile;
|
|
42
|
|
43 private Map<String, String> rows = new HashMap<String, String>();
|
|
44 private List<KMeansRow> data = new ArrayList<KMeansRow>();
|
|
45
|
|
46 @Override
|
|
47 public void run() throws IOException {
|
|
48 log.debug("Loading data from the input matrix");
|
|
49 String headerLine = "";
|
|
50 try (BufferedReader reader = Files.newBufferedReader(inputFile, Charset.defaultCharset())) {
|
|
51 // Header line
|
|
52 int lineNum = 1;
|
|
53 headerLine = reader.readLine();
|
|
54 int numColsInMatrix = StringUtils.countMatches(headerLine, "\t");
|
|
55
|
|
56 // Validate the range info
|
|
57 if (maxCol != null) {
|
|
58 if (maxCol > numColsInMatrix) {
|
|
59 throw new RuntimeException("Invalid range of data specified for clustering ("+maxCol+" > "+numColsInMatrix+")");
|
|
60 }
|
|
61 } else {
|
|
62 maxCol = numColsInMatrix;
|
|
63 }
|
|
64
|
|
65 // Loop over the rows and load the data
|
|
66 String line;
|
|
67 while ((line = reader.readLine()) != null) {
|
|
68 lineNum++;
|
|
69 if (StringUtils.countMatches(line, "\t") != numColsInMatrix) {
|
|
70 throw new RuntimeException("Irregular input matrix does not have same number of columns on line " + lineNum);
|
|
71 }
|
|
72
|
|
73 int delim = line.indexOf('\t');
|
|
74 String id = line.substring(0, delim);
|
|
75 String[] row = line.substring(delim+1).split("\t");
|
|
76 String[] subset = Arrays.copyOfRange(row, minCol, maxCol);
|
|
77 float[] rowData = new float[subset.length];
|
|
78 for (int i = 0; i < subset.length; i++) {
|
|
79 try {
|
|
80 rowData[i] = Float.parseFloat(subset[i]);
|
|
81 } catch (NumberFormatException e) {
|
|
82 rowData[i] = Float.NaN;
|
|
83 }
|
|
84 }
|
|
85 data.add(new KMeansRow(id, rowData));
|
|
86 rows.put(id, line);
|
|
87 }
|
|
88 }
|
|
89
|
|
90 // Perform the clustering
|
|
91 log.debug("Clustering the data");
|
|
92 Random rng = new Random();
|
|
93 KMeansPlusPlusClusterer<KMeansRow> clusterer = new KMeansPlusPlusClusterer<KMeansRow>(rng);
|
|
94 List<Cluster<KMeansRow>> clusters = clusterer.cluster(data, k, 50);
|
|
95
|
|
96 // Write to output
|
|
97 log.debug("Writing clustered data to output file");
|
|
98 try (BufferedWriter writer = Files.newBufferedWriter(outputFile, Charset.defaultCharset())) {
|
|
99 writer.write(headerLine);
|
|
100 writer.newLine();
|
|
101 int n = 1;
|
|
102 int count = 1;
|
|
103 for (Cluster<KMeansRow> cluster : clusters) {
|
|
104 int numRowsInCluster = cluster.getPoints().size();
|
|
105 int stop = count + numRowsInCluster - 1;
|
|
106 log.info("Cluster "+(n++)+": rows "+count+"-"+stop);
|
|
107 count = stop+1;
|
|
108 for (KMeansRow row : cluster.getPoints()) {
|
|
109 writer.write(rows.get(row.getId()));
|
|
110 writer.newLine();
|
|
111 }
|
|
112 }
|
|
113 }
|
|
114 }
|
|
115
|
|
116 public static void main(String[] args) throws IOException, WigFileException, IntervalFileSnifferException {
|
|
117 new KMeans().instanceMain(args);
|
|
118 }
|
|
119
|
|
120 }
|