Mercurial > repos > guerler > springsuite
diff spring_mcc.py @ 39:172398348efd draft
"planemo upload commit 26b4018c88041ee0ca7c2976e0a012015173d7b6-dirty"
author | guerler |
---|---|
date | Fri, 22 Jan 2021 15:50:27 +0000 |
parents | |
children | f316caf098a6 |
line wrap: on
line diff
#! /usr/bin/env python
"""Compare predicted protein interactions against BioGRID using the MCC.

The script loads a scored prediction file, derives a positive set from a
BioGRID export (optionally restricted to one experimental method), builds a
negative set (explicit file, subcellular-location based, or systematic
pairing) and plots the Matthews correlation coefficient of the prediction
next to the coefficients obtained for other BioGRID methods.
"""
import argparse
import math
from os.path import isfile
import re


def getIds(rawIds):
    """Split a "|"-separated identifier field into a list of identifiers."""
    return rawIds.split("|")


def getCenterId(rawId):
    """Return the second "|"-separated element of rawId, or rawId itself.

    For example "sp|P12345|NAME_HUMAN" yields "P12345", while an identifier
    without "|" is returned unchanged.
    """
    elements = rawId.split("|")
    if len(elements) > 1:
        return elements[1]
    return rawId


def getOrganism(rawId):
    """Return the text following the last "_" of rawId."""
    elements = rawId.split("_")
    return elements[-1]


def getKey(a, b):
    """Return an order-independent key for the pair (a, b).

    The greater identifier always comes first, so getKey(a, b) equals
    getKey(b, a).
    """
    if a > b:
        name = "%s_%s" % (a, b)
    else:
        name = "%s_%s" % (b, a)
    return name


def getPercentage(rate, denominator):
    """Return rate/denominator as a percentage, or 0.0 if denominator <= 0."""
    if denominator > 0:
        return 100.0 * rate / denominator
    return 0.0


def getFilter(filterName):
    """Collect the identifiers of the first two columns grouped by organism.

    Returns a dict mapping the organism suffix (see getOrganism) to the set
    of center identifiers (see getCenterId) found in columns 0 and 1 of the
    whitespace-separated file. Lines with fewer columns contribute whatever
    columns they have.
    """
    print("Loading target organism(s)...")
    filterSets = dict()
    with open(filterName) as filterFile:
        for line in filterFile:
            # only the first two columns carry identifiers
            for colEntry in line.split()[:2]:
                centerId = getCenterId(colEntry)
                organism = getOrganism(colEntry)
                filterSets.setdefault(organism, set()).add(centerId)
    print("Organism(s) in set: %s." % filterSets.keys())
    return filterSets


def _matchesFilterValues(columns, filterValues):
    """Return True if every [column, text] filter is satisfied by columns.

    A filter matches when the text occurs (case-insensitively) in the given
    column. Filters referring to a column the line does not have are ignored.
    """
    for colIndex, searchText in filterValues:
        if len(columns) > colIndex:
            if columns[colIndex].lower().find(searchText.lower()) == -1:
                return False
    return True


def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0,
                 bCol=1, scoreCol=-1, separator=None,
                 skipFirstLine=False, filterValues=None):
    """Load interaction pairs from a column-based text file.

    Parameters:
        fileName: path of the file to read.
        filterA, filterB: if both are given, a pair is only accepted when one
            identifier is in filterA and the other in filterB.
        minScore: if given, reading stops at the first accepted pair whose
            score is below this value (NOTE(review): this presumably relies
            on the file being sorted by descending score - confirm).
        aCol, bCol: column indices of the two interaction partners.
        scoreCol: column index of the score; if negative or missing on a
            line, the pair is stored with score 1.0.
        separator: column separator passed to str.split. If given, the
            partner columns may hold several "|"-separated identifiers and
            all combinations are considered; otherwise the center id of each
            column is used.
        skipFirstLine: ignore the first line (header).
        filterValues: list of [columnIndex, text] entries; a line is only
            used if each text occurs in the respective column.

    Returns:
        (index, count) where index maps pair keys (see getKey) to scores and
        count is the number of lines that contributed at least one new pair.
    """
    if filterValues is None:
        filterValues = []
    requiredColumns = max(aCol, bCol) + 1
    useFilterSets = filterA is not None and filterB is not None
    index = dict()
    count = 0
    with open(fileName) as fp:
        if skipFirstLine:
            next(fp, None)
        for line in fp:
            ls = line.split(separator)
            # blank or truncated lines cannot describe a pair
            if len(ls) < requiredColumns:
                continue
            if not _matchesFilterValues(ls, filterValues):
                continue
            if separator is not None:
                aList = getIds(ls[aCol])
                bList = getIds(ls[bCol])
            else:
                aList = [getCenterId(ls[aCol])]
                bList = [getCenterId(ls[bCol])]
            validEntry = False
            for a in aList:
                for b in bList:
                    # "-" marks a missing identifier
                    if a == "-" or b == "-":
                        continue
                    if useFilterSets:
                        forward = a in filterA and b in filterB
                        backward = a in filterB and b in filterA
                        if not (forward or backward):
                            continue
                    name = getKey(a, b)
                    if name in index:
                        continue
                    validEntry = True
                    if scoreCol >= 0 and len(ls) > scoreCol:
                        score = float(ls[scoreCol])
                        if minScore is not None and minScore > score:
                            # stop reading entirely, see docstring
                            return index, count
                        index[name] = score
                    else:
                        index[name] = 1.0
            if validEntry:
                count = count + 1
    return index, count


def getMCC(prediction, positive, positiveCount, negative):
    """Return the highest Matthews correlation coefficient over all cutoffs.

    The predictions are ranked by descending score. After each ranked entry
    the confusion matrix is updated (entries found in positive count as true
    positives, entries found in negative as false positives) and the MCC is
    evaluated; the maximum, never below 0.0, is returned. A summary is
    printed to stdout.

    Parameters:
        prediction: dict mapping pair keys to scores.
        positive: container of positive pair keys.
        positiveCount: total number of positives used for the false
            negative count.
        negative: container of negative pair keys.
    """
    sortedPrediction = sorted(prediction.items(), key=lambda x: x[1],
                              reverse=True)
    positiveTotal = positiveCount
    negativeTotal = len(negative)
    topCount = 0
    topMCC = 0.0
    topFP = 0.0
    topTP = 0.0
    topScore = 0.0
    tp = 0
    fp = 0
    for count, (name, score) in enumerate(sortedPrediction):
        if name in positive:
            tp = tp + 1
        if name in negative:
            fp = fp + 1
        fn = positiveTotal - tp
        tn = negativeTotal - fp
        denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        # the coefficient is undefined if any marginal sum is zero
        if denom > 0.0:
            mcc = (tp * tn - fp * fn) / math.sqrt(denom)
            if mcc >= topMCC:
                topMCC = mcc
                topScore = score
                topCount = count
                topFP = getPercentage(fp, fp + tn)
                topTP = getPercentage(tp, tp + fn)
    if len(sortedPrediction) > 0:
        print("Top ranking prediction %s." % str(sortedPrediction[0]))
    print("Total count of prediction set: %s (tp=%1.2f, fp=%1.2f)." % (topCount, topTP, topFP))
    print("Total count of positive set: %s." % len(positive))
    print("Total count of negative set: %s." % len(negative))
    print("Matthews-Correlation-Coefficient: %s at Score >= %s." % (round(topMCC, 2), topScore))
    return topMCC


def _getRegionMembers(locationsName, regionA, regionB, filterA, filterB):
    """Return sorted ids located exclusively in regionA resp. regionB.

    Only lines containing "SUBCELLULAR LOCATION" are evaluated and only
    identifiers (first column) present in filterA or filterB are kept.
    """
    searchKey = "SUBCELLULAR LOCATION"
    locations = dict()
    locations[regionA] = list()
    locations[regionB] = list()
    with open(locationsName) as locFile:
        for line in locFile:
            searchPos = line.find(searchKey)
            if searchPos == -1:
                continue
            uniId = line.split()[0]
            if uniId not in filterA and uniId not in filterB:
                continue
            locStart = searchPos + len(searchKey) + 1
            locId = line[locStart:]
            # drop "{...}" annotations and normalize the separators
            locId = re.sub(r"\s*{.*}\s*", "", locId)
            locId = locId.replace(".", ",")
            locId = locId.strip().lower()
            # ignore everything following a note or a ";"
            for marker in ["note=", ";"]:
                markerPos = locId.find(marker)
                if markerPos >= 0:
                    locId = locId[:markerPos]
            if locId:
                locList = [entry.strip() for entry in locId.split(",")]
                if regionA in locList and regionB not in locList:
                    locations[regionA].append(uniId)
                elif regionA not in locList and regionB in locList:
                    locations[regionB].append(uniId)
    return sorted(locations[regionA]), sorted(locations[regionB])


def getNegativeSet(args, filterA, filterB, negativeRequired, jSize=5):
    """Return a set of pair keys assumed to be non-interacting.

    If args.negative names an existing file, the pairs of its first two
    whitespace-separated columns are used. Otherwise pairs are generated
    with randomPairs until negativeRequired distinct keys exist (or no more
    pairs are available): from proteins located exclusively in
    args.region_a resp. args.region_b (read from args.locations) if both
    regions are given, else from the sorted members of filterA and filterB.

    Parameters:
        args: parsed command line arguments (negative, region_a, region_b,
            locations are read).
        filterA, filterB: sets of identifiers of the two organisms.
        negativeRequired: number of pairs to generate; values <= 0 yield an
            empty set.
        jSize: window size handed to randomPairs.
    """
    # determine negative set
    print("Identifying non-interacting pairs...")
    negative = set()
    if args.negative and isfile(args.negative):
        # load from explicit file
        with open(args.negative) as negativeFile:
            for line in negativeFile:
                cols = line.split()
                # skip blank or incomplete lines
                if len(cols) < 2:
                    continue
                negative.add(getKey(cols[0], cols[1]))
        return negative
    if args.region_a and args.region_b:
        regionA = args.region_a.lower()
        regionB = args.region_b.lower()
        print("Filtering regions %s" % str([regionA, regionB]))
        filterAList, filterBList = _getRegionMembers(args.locations, regionA,
                                                     regionB, filterA, filterB)
    else:
        # sets of strings have no stable order across interpreter runs,
        # sorting keeps the generated negative set reproducible
        filterAList = sorted(filterA)
        filterBList = sorted(filterB)
    for i, j in randomPairs(len(filterAList), len(filterBList), jSize):
        if len(negative) >= negativeRequired:
            break
        negative.add(getKey(filterAList[i], filterBList[j]))
    return negative


def randomPairs(iLen, jLen, jSize):
    """Yield index pairs (i, j) covering range(iLen) x range(jLen).

    Despite its name the order is deterministic: the j range is split into
    consecutive windows of jSize entries and each i is combined with the
    current window before moving on to the next window. This spreads early
    pairs over many i instead of exhausting all j of the first i. Nothing is
    yielded if jSize is not positive.
    """
    if jSize <= 0:
        return
    for jStart in range(0, jLen, jSize):
        jMax = min(jStart + jSize, jLen)
        for i in range(iLen):
            for j in range(jStart, jMax):
                yield i, j


def main(args):
    """Evaluate the prediction against BioGRID and write the MCC bar plot.

    Raises ValueError if the input file does not contain any identifier.
    """
    # imported first so a missing plotting library is reported before any
    # processing starts, while the helpers above remain importable without it
    from matplotlib import pyplot as plt

    # load source files
    filterSets = getFilter(args.input)
    filterKeys = list(filterSets.keys())
    if not filterKeys:
        raise ValueError("No identifiers found in %s." % args.input)
    filterA = filterSets[filterKeys[0]]
    if len(filterKeys) > 1:
        filterB = filterSets[filterKeys[1]]
    else:
        filterB = filterA

    # identify biogrid filter options, the method is optional
    # NOTE(review): column 11 is presumably the BioGRID experimental system
    # and columns 23/26 the SWISS-PROT accessions - confirm against the
    # BioGRID format in use
    filterValues = list()
    if args.method:
        filterValues.append([11, args.method])

    # process biogrid database
    print("Loading positive set from BioGRID file...")
    positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26,
                                           separator="\t", filterA=filterA,
                                           filterB=filterB, skipFirstLine=True,
                                           filterValues=filterValues)

    # estimate negative set
    negative = getNegativeSet(args, filterA, filterB, positiveCount)

    # get prediction results
    print("Loading prediction file...")
    prediction, _ = getReference(args.input, scoreCol=2, minScore=0.8)
    mcc = getMCC(prediction, positive, positiveCount, negative)
    yValues = [mcc]
    yTicks = ["SPRING"]

    # evaluate the remaining biogrid methods against the same sets
    for method in ["Affinity Capture-MS",
                   "Biochemical Activity",
                   "Co-crystal Structure",
                   "Co-fractionation",
                   "Co-localization",
                   "Co-purification",
                   "Far Western",
                   "FRET",
                   "PCA",
                   "Reconstituted Complex",
                   "Two-hybrid"]:
        if args.method != method:
            print("Method: %s" % method)
            methodSet, _ = getReference(args.biogrid, aCol=23, bCol=26,
                                        separator="\t", filterA=filterA,
                                        filterB=filterB, skipFirstLine=True,
                                        filterValues=[[11, method]])
            mcc = getMCC(methodSet, positive, positiveCount, negative)
            yValues.append(mcc)
            yTicks.append(method)

    # create plot
    print("Producing plot data...")
    print("Total count in prediction file: %d." % len(prediction))
    print("Total count in positive file: %d." % len(positive))
    plt.xlabel("Matthews-Correlation Coefficient (MCC)")
    plt.title("Positive set: %s" % (args.method or "all methods"))
    plt.barh(yTicks, yValues)
    plt.tight_layout()
    plt.savefig(args.output, format="png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Create ROC plot.')
    parser.add_argument('-i', '--input', help='Input prediction file (2-columns).', required=True)
    parser.add_argument('-b', '--biogrid', help='BioGRID interaction database file', required=True)
    parser.add_argument('-l', '--locations', help='UniProt export table with subcellular locations', required=False)
    parser.add_argument('-ra', '--region_a', help='First subcellular location', required=False)
    parser.add_argument('-rb', '--region_b', help='Second subcellular location', required=False)
    parser.add_argument('-n', '--negative', help='Negative set (2-columns)', required=False)
    parser.add_argument('-t', '--throughput', help='Throughput (low/high)', required=False)
    parser.add_argument('-m', '--method', help='Method e.g. Two-hybrid', required=False)
    parser.add_argument('-o', '--output', help='Output (png)', required=True)
    args = parser.parse_args()
    main(args)