Mercurial > repos > timpalpant > java_genomics_toolkit
comparison src/edu/unc/genomics/visualization/KMeans.java @ 2:e16016635b2a
Uploaded
author | timpalpant |
---|---|
date | Mon, 13 Feb 2012 22:12:06 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
1:a54db233ee3d | 2:e16016635b2a |
---|---|
1 package edu.unc.genomics.visualization; | |
2 | |
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math.stat.clustering.Cluster;
import org.apache.commons.math.stat.clustering.KMeansPlusPlusClusterer;
import org.apache.log4j.Logger;

import com.beust.jcommander.Parameter;

import edu.unc.genomics.CommandLineTool;
import edu.unc.genomics.ReadablePathValidator;
import edu.unc.genomics.io.IntervalFileSnifferException;
import edu.unc.genomics.io.WigFileException;
27 | |
28 public class KMeans extends CommandLineTool { | |
29 | |
30 private static final Logger log = Logger.getLogger(KMeans.class); | |
31 | |
32 @Parameter(names = {"-i", "--input"}, description = "Input file (matrix2png format)", required = true, validateWith = ReadablePathValidator.class) | |
33 public Path inputFile; | |
34 @Parameter(names = {"-k", "--clusters"}, description = "Number of clusters") | |
35 public int k = 10; | |
36 @Parameter(names = {"-1", "--min"}, description = "Minimum column to use for clustering") | |
37 public int minCol = 1; | |
38 @Parameter(names = {"-2", "--max"}, description = "Maximum column to use for clustering") | |
39 public Integer maxCol; | |
40 @Parameter(names = {"-o", "--output"}, description = "Output file (clustered matrix2png format)", required = true) | |
41 public Path outputFile; | |
42 | |
43 private Map<String, String> rows = new HashMap<String, String>(); | |
44 private List<KMeansRow> data = new ArrayList<KMeansRow>(); | |
45 | |
46 @Override | |
47 public void run() throws IOException { | |
48 log.debug("Loading data from the input matrix"); | |
49 String headerLine = ""; | |
50 try (BufferedReader reader = Files.newBufferedReader(inputFile, Charset.defaultCharset())) { | |
51 // Header line | |
52 int lineNum = 1; | |
53 headerLine = reader.readLine(); | |
54 int numColsInMatrix = StringUtils.countMatches(headerLine, "\t"); | |
55 | |
56 // Validate the range info | |
57 if (maxCol != null) { | |
58 if (maxCol > numColsInMatrix) { | |
59 throw new RuntimeException("Invalid range of data specified for clustering ("+maxCol+" > "+numColsInMatrix+")"); | |
60 } | |
61 } else { | |
62 maxCol = numColsInMatrix; | |
63 } | |
64 | |
65 // Loop over the rows and load the data | |
66 String line; | |
67 while ((line = reader.readLine()) != null) { | |
68 lineNum++; | |
69 if (StringUtils.countMatches(line, "\t") != numColsInMatrix) { | |
70 throw new RuntimeException("Irregular input matrix does not have same number of columns on line " + lineNum); | |
71 } | |
72 | |
73 int delim = line.indexOf('\t'); | |
74 String id = line.substring(0, delim); | |
75 String[] row = line.substring(delim+1).split("\t"); | |
76 String[] subset = Arrays.copyOfRange(row, minCol, maxCol); | |
77 float[] rowData = new float[subset.length]; | |
78 for (int i = 0; i < subset.length; i++) { | |
79 try { | |
80 rowData[i] = Float.parseFloat(subset[i]); | |
81 } catch (NumberFormatException e) { | |
82 rowData[i] = Float.NaN; | |
83 } | |
84 } | |
85 data.add(new KMeansRow(id, rowData)); | |
86 rows.put(id, line); | |
87 } | |
88 } | |
89 | |
90 // Perform the clustering | |
91 log.debug("Clustering the data"); | |
92 Random rng = new Random(); | |
93 KMeansPlusPlusClusterer<KMeansRow> clusterer = new KMeansPlusPlusClusterer<KMeansRow>(rng); | |
94 List<Cluster<KMeansRow>> clusters = clusterer.cluster(data, k, 50); | |
95 | |
96 // Write to output | |
97 log.debug("Writing clustered data to output file"); | |
98 try (BufferedWriter writer = Files.newBufferedWriter(outputFile, Charset.defaultCharset())) { | |
99 writer.write(headerLine); | |
100 writer.newLine(); | |
101 int n = 1; | |
102 int count = 1; | |
103 for (Cluster<KMeansRow> cluster : clusters) { | |
104 int numRowsInCluster = cluster.getPoints().size(); | |
105 int stop = count + numRowsInCluster - 1; | |
106 log.info("Cluster "+(n++)+": rows "+count+"-"+stop); | |
107 count = stop+1; | |
108 for (KMeansRow row : cluster.getPoints()) { | |
109 writer.write(rows.get(row.getId())); | |
110 writer.newLine(); | |
111 } | |
112 } | |
113 } | |
114 } | |
115 | |
116 public static void main(String[] args) throws IOException, WigFileException, IntervalFileSnifferException { | |
117 new KMeans().instanceMain(args); | |
118 } | |
119 | |
120 } |