comparison result_heatmap/result_heatmap.py @ 0:e94586e24004 draft default tip

planemo upload for repository https://github.com/jaidevjoshi83/MicroBiomML commit 5ef78d4decc95ac107c468499328e7f086289ff9-dirty
author jay
date Tue, 17 Feb 2026 10:52:24 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e94586e24004
1 import pandas as pd #pandas==2.1.4
2 import plotly.graph_objects as go #plotly==5.20.0
3 import os
4 import argparse
5
6
def Analysis(values, thr=0.05):
    """Classify how ``values[4]`` compares with each of ``values[0:4]``.

    Each difference ``values[4] - values[i]`` (for i in 0..3) is rounded to two
    decimals and then compared with ``thr``.

    Args:
        values: indexable sequence with at least five numeric entries.
        thr: comparison threshold. It is now honoured; previously it was
            overwritten with 0.05 inside the function.

    Returns:
        ``(True, 'better_all')``     if values[4] exceeds every other value by more than ``thr``;
        ``(True, 'better_one')``     if it exceeds at least one of them by more than ``thr``;
        ``(True, 'Comp_with_all')``  if it is within ``thr`` of all of them;
        ``(True, 'Comp_with_one')``  if it is within ``thr`` of at least one;
        ``None`` otherwise, i.e. it is lower than all four by more than ``thr``.
    """
    reference = values[4]
    # Round before comparing so tiny float noise does not flip a result.
    diffs = [round(reference - other, 2) for other in values[0:4]]
    better = [diff > thr for diff in diffs]
    comparable = [abs(diff) <= thr for diff in diffs]

    if all(better):
        return (True, 'better_all')
    if any(better):
        return (True, 'better_one')
    if all(comparable):
        return (True, 'Comp_with_all')
    if any(comparable):
        return (True, 'Comp_with_one')
    return None
28
29
# Legacy heatmap colour scale as [normalised position, colour] pairs.
# Plot does not use it (it takes its colour scale from the ``color_labels``
# argument); it is kept for any external code that may reference it.
color_scale = [
    [0, 'green'],   # lowest end of the scale
    [0.5, 'red'],   # midpoint
    [1, 'yellow'],  # highest end of the scale
]

# Border colour drawn around each Analysis category in the heatmap.
# The key order also sets the left-to-right order in which Plot lays out
# the category groups.
COLOR_SCALE = {
    'Comp_with_all': 'blue',
    'better_all': 'violet',
    'Comp_with_one': 'black',
    'better_one': 'red',
}
def ResultSummary(file, threshold, column_list=None):
    """Run :func:`Analysis` on every study (column) of the transposed table.

    The TSV is read with ``name`` as its index and then transposed. After the
    transpose, each ``name`` value is a column (a study, as labelled on Plot's
    x axis). Each remaining TSV column becomes a row (a classifier, as
    labelled on Plot's y axis).

    Args:
        file: path to a tab-separated file containing a ``name`` column.
        threshold: comparison threshold forwarded to :func:`Analysis`.
        column_list: 0-based positions of the rows to keep, or ``None`` for all.

    Returns:
        dict mapping each study name to ``Analysis(...)`` for that study's
        values, i.e. a ``(True, category)`` tuple or ``None``.
    """
    # ``.T`` already swaps index and columns; no manual reassignment is needed.
    table = pd.read_csv(file, sep='\t').set_index('name').T
    if column_list is not None:
        table = table.iloc[column_list]
    return {study: Analysis(table[study].values, threshold) for study in table.columns}
64
def _parse_column_list(column_list):
    """Convert a command-line ``column_list`` string into 0-based row positions.

    The string is a comma-separated list of 1-based TSV column numbers. Each
    number is shifted by 2: one for 1-based to 0-based, and presumably one more
    because the first TSV column (``name``) becomes the index and is not a data
    row after transposition. TODO confirm against the Galaxy tool XML.

    Blank items (e.g. from a trailing comma) are skipped. A blank string means
    "no filter" and yields ``None``. So does the literal ``'None'``, which an
    unset optional templated parameter may render as. Non-string values
    (``None`` or an already-parsed list) are returned unchanged.
    """
    if not isinstance(column_list, str):
        return column_list
    items = [item.strip() for item in column_list.split(',') if item.strip()]
    if not items or items == ['None']:
        return None
    return [int(item) - 2 for item in items]


def _group_border_shapes(groups):
    """Build one outline rectangle per non-empty category group.

    ``groups`` maps each COLOR_SCALE category to a list of
    ``(study, column_position)`` pairs whose positions are contiguous. Each
    rectangle spans the group's columns and rows 0-4, which are the five rows
    that Analysis compares.
    """
    shapes = []
    for category, members in groups.items():
        if not members:
            continue
        shapes.append(
            go.layout.Shape(
                type='rect',
                x0=members[0][1] - 0.48,
                x1=members[-1][1] + 0.48,
                y0=-0.5,
                y1=4.5,
                line=dict(color=COLOR_SCALE[category], width=2.5),
                fillcolor='rgba(255, 255, 255, 0)',  # outline only, no fill
            )
        )
    return shapes


def _legend_annotations():
    """Return the annotations shown above the plot that explain the border colours."""
    # NOTE(review): Analysis counts as "better" only when the difference is
    # strictly greater than the threshold, although the label reads "≥".
    legend_labels = {
        'better_all': 'Better than all (≥threshold)',
        'better_one': 'Better than some',
        'Comp_with_all': 'Comparable with all',
        'Comp_with_one': 'Comparable with some'
    }
    return [
        dict(
            x=0.25 * slot,
            y=1.12,
            xref='paper',
            yref='paper',
            text=f'<b style="color:{COLOR_SCALE[color_key]};font-size:14px;">■</b> {label_text}',
            showarrow=False,
            xanchor='left',
            yanchor='bottom',
            font=dict(size=11),
        )
        for slot, (color_key, label_text) in enumerate(legend_labels.items())
    ]


def Plot(input_file, width=2460, height=800, color_labels='Greens', font_size=22, tick_font=26, tick_angle=-80, threshold=0.05, column_list=None, outfile='out.html'):
    """Draw the classifier-by-study heatmap and save it as an HTML file.

    Studies (x axis) are grouped by their Analysis category, in COLOR_SCALE
    key order, and each group is outlined in its category colour. Studies for
    which Analysis returns ``None`` are dropped. If no study qualifies, every
    study is shown without outlines.

    Args:
        input_file: tab-separated file with a ``name`` column (see ResultSummary).
        width, height: figure size in pixels.
        color_labels: plotly colour scale used for the cells.
        font_size: font size of the in-cell values.
        tick_font: font size of both axes' tick labels (previously ignored).
        tick_angle: rotation of the x-axis tick labels, in degrees.
        threshold: comparison threshold passed to Analysis.
        column_list: rows to keep. Either a command-line string (see
            _parse_column_list) or a list of 0-based positions; ``None`` keeps all.
        outfile: path of the HTML file to write.
    """
    column_list = _parse_column_list(column_list)

    summary = ResultSummary(input_file, threshold, column_list)

    # Lay the studies out category by category so each group is contiguous,
    # and record each study's column position for the outline rectangles.
    groups = {category: [] for category in COLOR_SCALE}
    arranged_columns = []
    for category in COLOR_SCALE:
        for study, result in summary.items():
            if result and result[1] == category:
                groups[category].append((study, len(arranged_columns)))
                arranged_columns.append(study)

    # Same loading/filtering as ResultSummary: classifiers become rows and
    # studies become columns.
    table = pd.read_csv(input_file, sep='\t').set_index('name').T
    if column_list is not None:
        table = table.iloc[column_list]
    if arranged_columns:
        table = table[arranged_columns]

    heatmap = go.Heatmap(
        z=table.values,
        x=table.columns,
        zmin=0,
        zmax=1,
        y=table.index,
        text=table.values,  # show each cell's value inside the cell
        texttemplate="%{text}",
        colorscale=color_labels,
        textfont=dict(size=font_size, color='white'),
    )

    # Thin white separator line after every fifth row.
    shapes = [
        go.layout.Shape(
            type='line',
            x0=-0.5,
            x1=len(table.columns) - 0.5,
            y0=row - 0.5,
            y1=row - 0.5,
            line=dict(color='white', width=1),
        )
        for row in range(5, len(table), 5)
    ]
    shapes.extend(_group_border_shapes(groups))

    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        width=width,
        height=height,
        shapes=shapes,
        title='',
        xaxis=dict(title='Study', tickfont=dict(size=tick_font), tickangle=tick_angle),
        yaxis=dict(title='Classifier', tickfont=dict(size=tick_font), autorange='reversed'),
        autosize=False,
        annotations=_legend_annotations(),
        margin=dict(t=200),  # top margin leaves room for the legend
    )

    fig.write_html(outfile)
220
def build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="Plot heatmap from TSV data with classification results.")
    parser.add_argument("--input_file", type=str, default="test_data_age_category.tsv", help="Path to input TSV file (default: test_data_age_category.tsv)")
    parser.add_argument("--column_list", type=str, default=None, help="Comma-separated column indices to plot (default: None - plots all data)")
    parser.add_argument("--width", type=int, default=2460, help="Figure width in pixels (default: 2460)")
    parser.add_argument("--height", type=int, default=800, help="Figure height in pixels (default: 800)")
    parser.add_argument("--color_labels", type=str, default="Greens", help="Color scheme for heatmap (default: Greens)")
    parser.add_argument("--font_size", type=int, default=22, help="Font size for cell text (default: 22)")
    parser.add_argument("--tick_font", type=int, default=26, help="Font size for tick labels (default: 26)")
    parser.add_argument("--tick_angle", type=int, default=-80, help="Angle of x-axis tick labels in degrees (default: -80)")
    parser.add_argument("--threshold", type=float, default=0.05, help="Threshold for comparison analysis (default: 0.05)")
    parser.add_argument("--output", type=str, default="out.html", help="Output file path (default: out.html)")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # argparse has already converted every value to its declared type.
    Plot(
        input_file=args.input_file,
        width=args.width,
        height=args.height,
        color_labels=args.color_labels,
        font_size=args.font_size,
        tick_font=args.tick_font,
        tick_angle=args.tick_angle,
        threshold=args.threshold,
        column_list=args.column_list,
        outfile=args.output
    )