Mercurial > repos > timpalpant > java_genomics_toolkit
comparison src/edu/unc/genomics/visualization/KMeans.java @ 2:e16016635b2a
Uploaded
| author | timpalpant |
|---|---|
| date | Mon, 13 Feb 2012 22:12:06 -0500 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:a54db233ee3d | 2:e16016635b2a |
|---|---|
| 1 package edu.unc.genomics.visualization; | |
| 2 | |
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math.stat.clustering.Cluster;
import org.apache.commons.math.stat.clustering.KMeansPlusPlusClusterer;
import org.apache.log4j.Logger;

import com.beust.jcommander.Parameter;

import edu.unc.genomics.CommandLineTool;
import edu.unc.genomics.ReadablePathValidator;
import edu.unc.genomics.io.IntervalFileSnifferException;
import edu.unc.genomics.io.WigFileException;
| 27 | |
| 28 public class KMeans extends CommandLineTool { | |
| 29 | |
| 30 private static final Logger log = Logger.getLogger(KMeans.class); | |
| 31 | |
| 32 @Parameter(names = {"-i", "--input"}, description = "Input file (matrix2png format)", required = true, validateWith = ReadablePathValidator.class) | |
| 33 public Path inputFile; | |
| 34 @Parameter(names = {"-k", "--clusters"}, description = "Number of clusters") | |
| 35 public int k = 10; | |
| 36 @Parameter(names = {"-1", "--min"}, description = "Minimum column to use for clustering") | |
| 37 public int minCol = 1; | |
| 38 @Parameter(names = {"-2", "--max"}, description = "Maximum column to use for clustering") | |
| 39 public Integer maxCol; | |
| 40 @Parameter(names = {"-o", "--output"}, description = "Output file (clustered matrix2png format)", required = true) | |
| 41 public Path outputFile; | |
| 42 | |
| 43 private Map<String, String> rows = new HashMap<String, String>(); | |
| 44 private List<KMeansRow> data = new ArrayList<KMeansRow>(); | |
| 45 | |
| 46 @Override | |
| 47 public void run() throws IOException { | |
| 48 log.debug("Loading data from the input matrix"); | |
| 49 String headerLine = ""; | |
| 50 try (BufferedReader reader = Files.newBufferedReader(inputFile, Charset.defaultCharset())) { | |
| 51 // Header line | |
| 52 int lineNum = 1; | |
| 53 headerLine = reader.readLine(); | |
| 54 int numColsInMatrix = StringUtils.countMatches(headerLine, "\t"); | |
| 55 | |
| 56 // Validate the range info | |
| 57 if (maxCol != null) { | |
| 58 if (maxCol > numColsInMatrix) { | |
| 59 throw new RuntimeException("Invalid range of data specified for clustering ("+maxCol+" > "+numColsInMatrix+")"); | |
| 60 } | |
| 61 } else { | |
| 62 maxCol = numColsInMatrix; | |
| 63 } | |
| 64 | |
| 65 // Loop over the rows and load the data | |
| 66 String line; | |
| 67 while ((line = reader.readLine()) != null) { | |
| 68 lineNum++; | |
| 69 if (StringUtils.countMatches(line, "\t") != numColsInMatrix) { | |
| 70 throw new RuntimeException("Irregular input matrix does not have same number of columns on line " + lineNum); | |
| 71 } | |
| 72 | |
| 73 int delim = line.indexOf('\t'); | |
| 74 String id = line.substring(0, delim); | |
| 75 String[] row = line.substring(delim+1).split("\t"); | |
| 76 String[] subset = Arrays.copyOfRange(row, minCol, maxCol); | |
| 77 float[] rowData = new float[subset.length]; | |
| 78 for (int i = 0; i < subset.length; i++) { | |
| 79 try { | |
| 80 rowData[i] = Float.parseFloat(subset[i]); | |
| 81 } catch (NumberFormatException e) { | |
| 82 rowData[i] = Float.NaN; | |
| 83 } | |
| 84 } | |
| 85 data.add(new KMeansRow(id, rowData)); | |
| 86 rows.put(id, line); | |
| 87 } | |
| 88 } | |
| 89 | |
| 90 // Perform the clustering | |
| 91 log.debug("Clustering the data"); | |
| 92 Random rng = new Random(); | |
| 93 KMeansPlusPlusClusterer<KMeansRow> clusterer = new KMeansPlusPlusClusterer<KMeansRow>(rng); | |
| 94 List<Cluster<KMeansRow>> clusters = clusterer.cluster(data, k, 50); | |
| 95 | |
| 96 // Write to output | |
| 97 log.debug("Writing clustered data to output file"); | |
| 98 try (BufferedWriter writer = Files.newBufferedWriter(outputFile, Charset.defaultCharset())) { | |
| 99 writer.write(headerLine); | |
| 100 writer.newLine(); | |
| 101 int n = 1; | |
| 102 int count = 1; | |
| 103 for (Cluster<KMeansRow> cluster : clusters) { | |
| 104 int numRowsInCluster = cluster.getPoints().size(); | |
| 105 int stop = count + numRowsInCluster - 1; | |
| 106 log.info("Cluster "+(n++)+": rows "+count+"-"+stop); | |
| 107 count = stop+1; | |
| 108 for (KMeansRow row : cluster.getPoints()) { | |
| 109 writer.write(rows.get(row.getId())); | |
| 110 writer.newLine(); | |
| 111 } | |
| 112 } | |
| 113 } | |
| 114 } | |
| 115 | |
| 116 public static void main(String[] args) throws IOException, WigFileException, IntervalFileSnifferException { | |
| 117 new KMeans().instanceMain(args); | |
| 118 } | |
| 119 | |
| 120 } |
