# Mercurial comparison view: galaxy_micropita/src/breadcrumbs/src/SVM.py @ 3:8fb4630ab314 (draft, default, tip)
# Uploaded by sagun98 on Thu, 03 Jun 2021 17:07:36 +0000; no parent or child revisions are listed.
# Compared revisions: 2:1c5736dc85ab -> 3:8fb4630ab314
"""
Author: Timothy Tickle
Description: Support Vector Machine (SVM) utilities. Provides the SVM class, which converts
abundance tables into LibSVM-formatted input files and holds the associated helper scripts.
"""
5
6 #####################################################################################
7 #Copyright (C) <2012>
8 #
9 #Permission is hereby granted, free of charge, to any person obtaining a copy of
10 #this software and associated documentation files (the "Software"), to deal in the
11 #Software without restriction, including without limitation the rights to use, copy,
12 #modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
13 #and to permit persons to whom the Software is furnished to do so, subject to
14 #the following conditions:
15 #
16 #The above copyright notice and this permission notice shall be included in all copies
17 #or substantial portions of the Software.
18 #
19 #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
20 #INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
21 #PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
22 #HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23 #OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 #SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 #####################################################################################
26
# Module metadata: authorship, licensing and development status.
__author__ = __maintainer__ = "Timothy Tickle"
__credits__ = [__author__]
__email__ = "ttickle@sph.harvard.edu"
__copyright__ = "Copyright 2012"
__license__ = "MIT"
__status__ = "Development"
34
35 #Libraries
36 from AbundanceTable import AbundanceTable
37 from ConstantsBreadCrumbs import ConstantsBreadCrumbs
38 import csv
39 import os
40 from random import shuffle
41 from ValidateData import ValidateData
42
class SVM:
    """
    Class which holds generic methods for SVM use.

    The methods convert abundance tables to LibSVM-formatted input files and
    provide related helpers: making and weighting labels, reading labels back
    from input/prediction files, scaling features and generating
    cross-validation folds.
    """

    @staticmethod
    def _funcOpenStream(xFile, sMode):
        """
        Returns an open stream for xFile.

        A string is treated as a file path and opened with sMode. A stream that is
        already closed is reopened by its name with sMode. An open stream is
        returned unchanged.

        :param xFile: File path or file stream.
        :type: String or FileStream
        :param sMode: Mode passed to open (for example "r" or "w").
        :type: String
        :return: An open file stream.
        """
        if isinstance(xFile, str):
            return open(xFile, sMode)
        if xFile.closed:
            return open(xFile.name, sMode)
        return xFile

    @staticmethod
    def _funcFormatFeatures(lxValues):
        """
        Formats one row of feature values as SVM entries.

        :param lxValues: Feature values for one sample, in feature order.
        :type: Iterable
        :return: "index<c_strColon>value" strings, with feature indices starting at 1.
        :rtype: List of strings
        """
        return [ConstantsBreadCrumbs.c_strColon.join([str(iFeature + 1), str(xValue)])
                for iFeature, xValue in enumerate(lxValues)]

    #1 Happy Path tested
    @staticmethod
    def funcConvertAbundanceTableToSVMFile(abndAbundanceTable, xOutputSVMFile, sMetadataLabel, lsOriginalLabels = None, lsSampleOrdering = None):
        """
        Converts an abundance table to an SVM input file.

        Each written row is a label followed by one "index<c_strColon>value" entry per
        feature, with feature indices starting at 1. Samples named in lsSampleOrdering
        that are not in the table are written as placeholder rows filled with
        ConstantsBreadCrumbs.c_strSVMNoSample. This allows partially known files to be
        created for stratification purposes.

        :param abndAbundanceTable: AbundanceTable object to turn into an SVM input file.
        :type: AbundanceTable
        :param xOutputSVMFile: File to save the SVM data to. A stream, including one passed in
                               by the caller, is closed when writing finishes. An already-closed
                               stream is reopened for writing by its name.
        :type: FileStream or string file path
        :param sMetadataLabel: Name of the metadata used to make labels when lsOriginalLabels is
                               not given (originally documented as the last metadata row).
        :type: String
        :param lsOriginalLabels: Labels to write, one per entry of the sample ordering.
                                 Placeholder rows consume a label too. If not given, labels are
                                 made from the sMetadataLabel metadata, one per table sample.
        :type: List of strings
        :param lsSampleOrdering: Order of samples in the output file. If None, the order in the
                                 abundance table is used.
        :type: List of strings
        :return: Set of the unique labels (of lsOriginalLabels, or of the labels made from metadata).
        :rtype: set
        """
        #Transpose the abundance data: element 0 holds the first field of every row
        #(presumably the feature ids) and element i (i >= 1) holds the values of the i-th sample.
        #list() keeps the result indexable under both Python 2 and Python 3.
        dataMatrix = list(zip(*abndAbundanceTable.funcGetAbundanceCopy()))

        #Labels either come from the caller or are made from the metadata (in table sample order)
        lsLabels = lsOriginalLabels if lsOriginalLabels else SVM.funcMakeLabels(abndAbundanceTable.funcGetMetadata(sMetadataLabel))

        ostm = SVM._funcOpenStream(xOutputSVMFile, "w")
        try:
            f = csv.writer(ostm, csv.excel_tab, delimiter = ConstantsBreadCrumbs.c_strBreadCrumbsSVMSpace)

            #Look samples up by name so each row gets its own sample's data, even when the
            #requested ordering differs from the table's order
            lsCurrentSamples = abndAbundanceTable.funcGetSampleNames()
            dictSampleIndex = dict((sSample, iIndex) for iIndex, sSample in enumerate(lsCurrentSamples))
            lsOrderingSamples = lsSampleOrdering if lsSampleOrdering else lsCurrentSamples[:]

            iSize = len(dataMatrix[0]) if dataMatrix else 0
            lsBlankRow = [ConstantsBreadCrumbs.c_strSVMNoSample] + SVM._funcFormatFeatures([ConstantsBreadCrumbs.c_strSVMNoSample] * iSize)

            for iPosition, sSample in enumerate(lsOrderingSamples):
                iTableIndex = dictSampleIndex.get(sSample)
                #Make blank entry for samples not in this table
                if iTableIndex is None:
                    f.writerow(lsBlankRow)
                    continue
                #Caller-given labels follow the output ordering; made labels follow the table order
                sLabel = lsLabels[iPosition] if lsOriginalLabels else lsLabels[iTableIndex]
                ltplSample = dataMatrix[iTableIndex + 1] if dataMatrix else ()
                f.writerow([sLabel] + SVM._funcFormatFeatures(ltplSample))
        finally:
            ostm.close()
        return set(lsLabels)

    @staticmethod
    def funcUpdateSVMFileWithAbundanceTable(abndAbundanceTable, xOutputSVMFile, lsOriginalLabels, lsSampleOrdering):
        """
        Takes an SVM input file and updates it with an abundance table.

        Rows are matched to samples by position in lsSampleOrdering. A row is rewritten
        from the abundance table when its sample is in the table; otherwise the old row
        is kept unchanged. lsOriginalLabels and lsSampleOrdering must be consistent with
        the input file and in the same order.

        :param abndAbundanceTable: AbundanceTable object holding the new data.
        :type: AbundanceTable
        :param xOutputSVMFile: SVM file to read and then overwrite. A stream is closed after
                               reading and reopened for writing by its name.
        :type: FileStream or string file path
        :param lsOriginalLabels: The original labels (as numerics 0,1,2,3,4... as in the file),
                                 one per row of the file.
        :type: List of strings
        :param lsSampleOrdering: Order of samples in the file.
        :type: List of strings
        :return: True if the file was rewritten. False (after printing a message) if the
                 number of rows in the existing file does not match lsSampleOrdering.
        :rtype: boolean
        """
        #Read in contents of the old file
        ostm = SVM._funcOpenStream(xOutputSVMFile, "r")
        try:
            fin = csv.reader(ostm, csv.excel_tab, delimiter = ConstantsBreadCrumbs.c_strBreadCrumbsSVMSpace)
            llsOldContents = list(fin)
        finally:
            ostm.close()

        #Check to make sure this ordering covers all positions in the old file
        if not len(llsOldContents) == len(lsSampleOrdering):
            print("The length of the original file ("+str(len(llsOldContents))+") does not match the length of the ordering given ("+str(len(lsSampleOrdering))+").")
            return False

        #Transpose the new data; element i (i >= 1) holds the values of the i-th table sample
        dataMatrix = list(zip(*abndAbundanceTable.funcGetAbundanceCopy()))

        #Write to file; a stream was closed above, so it is reopened for writing by its name
        ostm = SVM._funcOpenStream(xOutputSVMFile, "w")
        try:
            f = csv.writer(ostm, csv.excel_tab, delimiter = ConstantsBreadCrumbs.c_strBreadCrumbsSVMSpace)

            lsCurrentSamples = abndAbundanceTable.funcGetSampleNames()
            dictSampleIndex = dict((sSample, iIndex) for iIndex, sSample in enumerate(lsCurrentSamples))

            for iIndexOriginalOrder, sSample in enumerate(lsSampleOrdering):
                iTableIndex = dictSampleIndex.get(sSample)
                if iTableIndex is None:
                    #Sample not in the new table: keep its old row
                    f.writerow(llsOldContents[iIndexOriginalOrder])
                else:
                    ltplSample = dataMatrix[iTableIndex + 1] if dataMatrix else ()
                    f.writerow([lsOriginalLabels[iIndexOriginalOrder]] + SVM._funcFormatFeatures(ltplSample))
        finally:
            ostm.close()
        return True

    #Tested 5
    @staticmethod
    def funcMakeLabels(lsMetadata):
        """
        Given a list of metadata, labels are assigned. This function is the central place
        where labels are made so all labels are consistent.

        Each distinct metadata value is labelled with the string of its first-occurrence
        rank. The first value seen becomes "0", the next new value "1", and so on.

        :param lsMetadata: Metadata values (one per sample) to turn into labels.
        :type: List of hashable values
        :return: One string label per metadata value, in the same order.
        :rtype: List of strings
        """
        lxValues = list(lsMetadata)
        #Key by the value itself (not its str) so non-string metadata also works;
        #insertion rank preserves first-occurrence order, so the first label is "0"
        dictLabels = {}
        for xValue in lxValues:
            if xValue not in dictLabels:
                dictLabels[xValue] = str(len(dictLabels))
        return [dictLabels[xValue] for xValue in lxValues]

    #Tested
    @staticmethod
    def funcReadLabelsFromFile(xSVMFile, lsAllSampleNames, isPredictFile):
        """
        Reads the labels from the input file or prediction output file of a LibSVM
        formatted file and associates them in order with the given sample names.

        Prediction file expected format: labels declared in the first line with the labels keyword.
        Each following row is a sample with the first entry the predicted label.
        Prediction file example:
        labels 0 1
        0 0.3 0.4 0.6
        1 0.1 0.2 0.3
        1 0.2 0.2 0.2
        0 0.2 0.4 0.3

        Input file expected format:
        Each row is a sample with the first entry the label.
        Input file example:
        0 0.3 0.4 0.6
        1 0.1 0.2 0.3
        1 0.2 0.2 0.2
        0 0.2 0.4 0.3

        Empty rows and placeholder rows (first entry ConstantsBreadCrumbs.c_strSVMNoSample)
        are skipped. A file path is opened and closed here. A stream is read but left open.

        :param xSVMFile: File path or stream to read labels from.
        :type: String or FileStream
        :param lsAllSampleNames: List of sample ids in the order of the labels.
        :type: List of Strings
        :param isPredictFile: Indicates if the file is the input (False) or prediction (True) file.
        :type: boolean
        :return: Dictionary {label: set(["sampleName1", "sampleName2", ...]), ...} or False on error.
        """
        #Only close the handle here if this function opened it
        fOpenedHere = isinstance(xSVMFile, str)
        fhndlSVM = open(xSVMFile, 'r') if fOpenedHere else xSVMFile
        try:
            g = csv.reader(fhndlSVM, csv.excel_tab, delimiter = ConstantsBreadCrumbs.c_strBreadCrumbsSVMSpace)
            lsOriginalLabels = [lsLineElements[0] for lsLineElements in g
                                if lsLineElements and not lsLineElements[0] == ConstantsBreadCrumbs.c_strSVMNoSample]
        finally:
            if fOpenedHere:
                fhndlSVM.close()

        #Prediction files start with the "labels" declaration line
        if isPredictFile:
            lsOriginalLabels = lsOriginalLabels[1:]

        #Check sample name length
        if not len(lsAllSampleNames) == len(lsOriginalLabels):
            print("SVM::funcReadLabelsFromFile. Error, the length of sample names did not match the original labels length. Samples ("+str(len(lsAllSampleNames))+"):"+str(lsAllSampleNames)+" Labels ("+str(len(lsOriginalLabels))+"):"+str(lsOriginalLabels))
            return False

        #Group sample names by label: {label: set(sample names)}
        dictSampleLabelsRet = dict()
        for sSampleName, sLabel in zip(lsAllSampleNames, lsOriginalLabels):
            dictSampleLabelsRet.setdefault(sLabel, set()).add(sSampleName)
        return dictSampleLabelsRet

    #Tested
    @staticmethod
    def funcScaleFeature(npdData):
        """
        Scales a feature between 0 and 1 (min-max scaling).

        The range [0, 1] is used rather than [-1, 1] because it keeps the data sparse:
        zeros stay zero when the feature minimum is 0. A feature whose values are all
        identical, or which is empty, has no range and is returned unchanged.

        :param npdData: Feature data to scale. Assumed to be a 1-D numpy array; elementwise
                        subtraction of a scalar is required.
        :type: Numpy Array
        :return: The scaled feature as a numpy array of floats (or npdData itself when constant).
        """
        if len(set(npdData)) <= 1:
            return npdData
        npdShifted = npdData - min(npdData)
        return npdShifted / float(max(npdShifted))

    #Tested
    @staticmethod
    def funcWeightLabels(lLabels):
        """
        Returns weights for labels based on how balanced the labels are. The weights try to
        balance unbalanced results: each label's weight is the count of the most frequent
        label divided by that label's count.

        :param lLabels: Labels used to measure how balanced the comparison is.
        :type: List of hashable values
        :return: [dictWeights, lUniqueLabels]. dictWeights maps each label's index in
                 lUniqueLabels (0, 1, ...) to its weight. lUniqueLabels holds the unique
                 labels in first-occurrence order. Empty input gives [{}, []].
        :rtype: List
        """
        #Count occurrences and collect unique labels in one pass, preserving order
        #(first label is index 0)
        dictCounts = {}
        lUniqueLabels = []
        for xLabel in lLabels:
            if xLabel in dictCounts:
                dictCounts[xLabel] += 1
            else:
                dictCounts[xLabel] = 1
                lUniqueLabels.append(xLabel)

        if not lUniqueLabels:
            return [{}, []]

        #Divide the highest occurrence by each occurrence
        iMaxOccurence = max(dictCounts.values())
        dictWeights = dict((iIndex, iMaxOccurence / float(dictCounts[xLabel]))
                           for iIndex, xLabel in enumerate(lUniqueLabels))
        return [dictWeights, lUniqueLabels]

    #Tested 3/4 cases could add in test 12 with randomize True
    def func10FoldCrossvalidation(self, iTotalSampleCount, fRandomise = False):
        """
        Generator.
        Generates the indexes for a 10 fold cross validation given a sample count.
        If there are fewer than 10 samples, the sample count is used as K, giving a
        leave-one-out cross validation.

        Each iteration yields (lfTraining, lfValidation): two boolean masks of length
        iTotalSampleCount. Position p is in the validation fold of iteration i when
        liindices[p] % K == i; the masks are complements of each other.

        :param iTotalSampleCount: Total sample count.
        :type: Integer
        :param fRandomise: True shuffles the indices so fold membership is random (Default False).
        :type: Boolean
        """
        #Make indices and shuffle if needed (a list, so shuffle works under Python 3 too)
        liindices = list(range(iTotalSampleCount))
        if fRandomise:
            shuffle(liindices)

        #At most 10 folds; fewer samples gives leave-one-out
        iKFold = min(10, iTotalSampleCount)
        for iiteration in range(iKFold):
            lfTraining = [iindex % iKFold != iiteration for iindex in liindices]
            lfValidation = [not fTrain for fTrain in lfTraining]
            yield lfTraining, lfValidation
297 shuffle(liindices)
298
299 #For 10 times
300 iKFold = 10
301 if iTotalSampleCount < iKFold:
302 iKFold = iTotalSampleCount
303 for iiteration in xrange(iKFold):
304 lfTraining = [iindex % iKFold != iiteration for iindex in liindices]
305 lfValidation = [not iindex for iindex in lfTraining]
306 yield lfTraining, lfValidation