Mercurial > repos > fubar > lifelines_km_cph_tool
comparison lifelines_tool/plotlykm.py @ 0:dd49a7040643 draft
Initial commit
| author | fubar |
|---|---|
| date | Wed, 09 Aug 2023 11:12:16 +0000 |
| parents | |
| children | 232b874046a7 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:dd49a7040643 |
|---|---|
| 1 # script for a lifelines ToolFactory KM/CPH tool for Galaxy | |
| 2 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393 | |
| 3 # test as | |
| 4 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcols="prio,age,race,paro,mar,fin" | |
| 5 | |
| 6 import argparse | |
| 7 import os | |
| 8 import sys | |
| 9 | |
| 10 import lifelines | |
| 11 | |
| 12 from matplotlib import pyplot as plt | |
| 13 | |
| 14 import pandas as pd | |
| 15 | |
| 16 # Ross Lazarus July 2023 | |
| 17 | |
| 18 | |
# Shared lifelines estimators; they are fitted further down the script.
kmf = lifelines.KaplanMeierFitter()
cph = lifelines.CoxPHFitter()

# Command line interface. Every option is a plain string; --input_tab,
# --time and --status are mandatory, the rest fall back to the defaults
# listed here. Declaration order is kept so the --help output is unchanged.
parser = argparse.ArgumentParser()
a = parser.add_argument
OPTION_TABLE = (
    ('--input_tab', {'default': '', 'required': True}),
    ('--header', {'default': ''}),
    ('--htmlout', {'default': "test_run.html"}),
    ('--group', {'default': ''}),
    ('--time', {'default': '', 'required': True}),
    ('--status', {'default': '', 'required': True}),
    ('--cphcols', {'default': ''}),
    ('--title', {'default': 'Default plot title'}),
    ('--image_type', {'default': 'png'}),
    ('--image_dir', {'default': 'images'}),
    ('--readme', {'default': 'run_log.txt'}),
)
for option_flag, option_settings in OPTION_TABLE:
    a(option_flag, **option_settings)
args = parser.parse_args()
# Everything printed from here on (including the lifelines summaries) goes to
# the run log named by --readme. The handle is deliberately left open for the
# rest of the script and is released when the interpreter exits.
sys.stdout = open(args.readme, 'w')

# Load the tab separated input; pandas takes the first row as the header.
df = pd.read_csv(args.input_tab, sep='\t')
NCOLS = df.columns.size
NROWS = len(df.index)
# Fallback names col1..colN, used when the file has no usable header.
defaultcols = ['col%d' % (x+1) for x in range(NCOLS)]
testcols = df.columns

# Decide which column names to use. In order of preference:
#   1. the comma separated --header supplied by the user,
#   2. the header pandas read from the file,
#   3. the automatic col1..colN names.
# The time and status columns must be resolvable in whichever is chosen.
if args.header.strip():
    # Strip each name so that "week, arrest" matches as well as "week,arrest".
    newcols = [x.strip() for x in args.header.split(',')]
    if len(newcols) != NCOLS:
        sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): Supplied header %s has %d comma delimited header names - does not match the input tabular file %d columns' % (args.header, len(newcols), NCOLS))
        sys.exit(5)
    if not ((args.time in newcols) and (args.status in newcols)):
        sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and/or status %s not found in supplied header parameter %s' % (args.time, args.status, args.header))
        sys.exit(4)
    df.columns = newcols
elif (args.time in testcols) and (args.status in testcols):
    # The header found in the file already names both columns - keep it.
    df.columns = testcols
elif (args.time in defaultcols) and (args.status in defaultcols):
    # NOTE(review): pandas has already consumed the first row as a header, so
    # if that row was really data it is lost here. Whether the file carries a
    # genuine header cannot be told from this script alone - confirm against
    # the tool's expected inputs before changing this.
    sys.stderr.write('replacing first row of data derived header %s with %s' % (testcols, defaultcols))
    df.columns = defaultcols
else:
    sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols))
    # Stop here like the other usage errors do; without this the script ran on
    # and died later with a KeyError traceback when the columns were used.
    sys.exit(3)
print('## Lifelines tool starting.\nUsing data header =', df.columns, 'time column =', args.time, 'status column =', args.status)
# Kaplan-Meier section: draw the survival curve(s) and save them as
# KM_<title>.<image_type> in the image directory. With --group set, one curve
# is drawn per group; exactly two groups also get a logrank test written to
# the run log, and more than two groups get a per-group RMST figure.
os.makedirs(args.image_dir, exist_ok=True)
fig, ax = plt.subplots()
if args.group:
    if args.group not in df.columns:
        # Fail with a usage message rather than a KeyError traceback.
        sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): group column %s not found in column header %s' % (args.group, df.columns))
        sys.exit(7)
    # Common upper time limit so RMST values are comparable between groups.
    horizon = df[args.time].max()
    names = []
    times = []
    events = []
    rmst = []
    group_fits = []
    for name, grouped_df in df.groupby(args.group):
        T = grouped_df[args.time]
        E = grouped_df[args.status]
        # A fresh fitter per group: refitting one shared instance would leave
        # every stored reference pointing at the last group's model.
        gfit = lifelines.KaplanMeierFitter().fit(T, E, label=name)
        gfit.plot_survival_function(ax=ax)
        rst = lifelines.utils.restricted_mean_survival_time(gfit, t=horizon)
        print('Restricted mean survival time to t=%s for group %s = %s' % (horizon, name, rst))
        rmst.append(rst)
        group_fits.append(gfit)
        names.append(str(name))
        times.append(T)
        events.append(E)
    ngroup = len(names)
    if ngroup == 2:  # run logrank test if 2 groups
        results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99)
        print(' vs '.join(names), results)
        results.print_summary()
    elif ngroup > 2:
        # One RMST panel per group, on a separate figure so the KM figure
        # saved below is not clobbered. rmst_plot needs the fitted model (not
        # the RMST number) and a single axis per call.
        rmst_fig, rmst_axes = plt.subplots(nrows=ngroup, ncols=1, sharex=True)
        for gfit, rmst_ax in zip(group_fits, rmst_axes):
            lifelines.plotting.rmst_plot(gfit, t=horizon, ax=rmst_ax)
        rmst_fig.savefig(os.path.join(args.image_dir, 'RMST_%s.%s' % (args.title, args.image_type)))
else:
    # No grouping column: a single curve over the whole data set.
    kmf.fit(df[args.time], df[args.status])
    kmf.plot_survival_function(ax=ax)
# Honour --image_type here as the Cox PH section does (the default is png).
fig.savefig(os.path.join(args.image_dir, 'KM_%s.%s' % (args.title, args.image_type)))
# Cox proportional hazards section: runs only when --cphcols names at least
# one covariate. Fits the model, writes its summary to the run log, then saves
# each diagnostic figure produced by check_assumptions into the image
# directory as CPH<figure title>.<image_type>.
# Empty names (stray or trailing commas) are dropped.
cphcols = [x.strip() for x in args.cphcols.split(',') if x.strip()]
if cphcols:
    missing = [x for x in cphcols if x not in df.columns]
    if missing:
        sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns))
        sys.exit(6)
    print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title))
    # The model frame needs the duration and event columns as well.
    # dict.fromkeys removes repeats while keeping order, so a time or status
    # column also listed as a covariate does not appear twice in the frame.
    cphcols = list(dict.fromkeys(cphcols + [args.time, args.status]))
    cphdf = df[cphcols]
    cph.fit(cphdf, duration_col=args.time, event_col=args.status)
    cph.print_summary()
    cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True)
    # Each entry is treated as the list of axes of one diagnostic figure;
    # "or []" guards against a None return when there is nothing to plot.
    for i, figaxes in enumerate(cphaxes or []):
        figr = figaxes[0].get_figure()
        # The figure title is read through a private matplotlib attribute, so
        # fall back to an index based name if it is absent or empty.
        suptitle = getattr(figr, '_suptitle', None)
        if suptitle is not None and suptitle.get_text():
            titl = suptitle.get_text().replace(' ', '_').replace("'", "")
        else:
            titl = '_assumption_plot_%d' % i
        oname = os.path.join(args.image_dir, 'CPH%s.%s' % (titl, args.image_type))
        figr.savefig(oname)
| 115 | |
| 116 | |
| 117 | |
| 118 |
