Mercurial > repos > jay > ml_tool
comparison result_heatmap/result_heatmap.py @ 0:e94586e24004 draft default tip
planemo upload for repository https://github.com/jaidevjoshi83/MicroBiomML commit 5ef78d4decc95ac107c468499328e7f086289ff9-dirty
| author | jay |
|---|---|
| date | Tue, 17 Feb 2026 10:52:24 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:e94586e24004 |
|---|---|
| 1 import pandas as pd #pandas==2.1.4 | |
| 2 import plotly.graph_objects as go #plotly==5.20.0 | |
| 3 import os | |
| 4 import argparse | |
| 5 | |
| 6 | |
def Analysis(values, thr=0.05):
    """Classify how the fifth score in ``values`` compares with the first four.

    ``values`` holds one study's scores, one per classifier. The score at
    index 4 is the reference; it is compared with each of ``values[0:4]``.
    Each difference (reference - other) is rounded to 2 decimals before
    being compared with ``thr``.

    Args:
        values: indexable sequence with at least 5 numeric scores.
        thr: margin used for both the "better" and "comparable" tests.

    Returns:
        (True, 'better_all')    if the reference beats all four by more than thr;
        (True, 'better_one')    if it beats at least one by more than thr;
        (True, 'Comp_with_all') if it is within thr of all four;
        (True, 'Comp_with_one') if it is within thr of at least one;
        None when none of these hold (every difference is below -thr, or NaN).
    """
    reference = values[4]
    # Round first so tiny float noise (e.g. 0.19999999999999996) does not
    # flip a comparison against the threshold.
    diffs = [round(reference - v, 2) for v in values[0:4]]
    better = [d > thr for d in diffs]
    comparable = [abs(d) <= thr for d in diffs]

    if all(better):
        return (True, 'better_all')
    if any(better):
        return (True, 'better_one')
    if all(comparable):
        return (True, 'Comp_with_all')
    if any(comparable):
        return (True, 'Comp_with_one')
    return None
| 28 | |
| 29 | |
# Three-stop colorscale mapping normalized position -> colour:
# 0 -> green, 0.5 -> red, 1 -> yellow.
# NOTE(review): appears unused in this file; Plot takes its heatmap
# colorscale from the color_labels argument instead.
color_scale = [
    [0, 'green'],
    [0.5, 'red'],
    [1, 'yellow'],
]

# Outline colour for each Analysis verdict label. Plot uses these colours
# for the rectangles drawn around each group of studies and for the legend
# swatches. The key order also fixes the left-to-right order of the groups.
COLOR_SCALE = {
    'Comp_with_all': 'blue',
    'better_all': 'violet',
    'Comp_with_one': 'black',
    'better_one': 'red',
}
def ResultSummary(file, threshold, column_list=None):
    """Run ``Analysis`` on every study in a tab-separated results file.

    The file must contain a ``name`` column, which becomes the index. After
    transposing, each remaining TSV column is one row of the working frame,
    and each ``name`` value is one column.

    Args:
        file: path to the TSV file.
        threshold: margin forwarded to ``Analysis`` as ``thr``.
        column_list: optional 0-based positions (after ``name`` is removed)
            of the TSV columns to analyse, in the order ``Analysis`` should
            see them. ``None`` or an empty selection means "use all".

    Returns:
        dict mapping each ``name`` value to its ``Analysis`` result.
    """
    print(file)
    # .T already swaps index and columns, so no reassignment is needed.
    df = pd.read_csv(file, sep='\t').set_index('name').T

    # An empty selection used to reach iloc (or Analysis) and crash; treat
    # it the same as "no selection".
    if column_list is not None and len(column_list) > 0:
        df = df.iloc[column_list]

    return {study: Analysis(df[study].values, threshold) for study in df.columns}
| 64 | |
def _group_by_verdict(summary):
    """Order studies by verdict category and record each category's x-span.

    ``summary`` maps a study name to its ``Analysis`` result: a
    ``(True, label)`` tuple, or ``None``. Studies with no verdict are
    dropped. The rest are grouped in ``COLOR_SCALE`` key order, keeping
    their original relative order inside each group.

    Returns:
        (arranged_columns, spans), where ``spans`` maps each non-empty label
        to the (first, last) x positions it occupies in ``arranged_columns``.
    """
    groups = {label: [] for label in COLOR_SCALE}
    for study, verdict in summary.items():
        if verdict and verdict[1] in groups:
            groups[verdict[1]].append(study)

    arranged_columns = []
    spans = {}
    for label, studies in groups.items():
        if studies:
            start = len(arranged_columns)
            arranged_columns.extend(studies)
            spans[label] = (start, len(arranged_columns) - 1)
    return arranged_columns, spans


def _legend_annotations():
    """Build the row of colour-key annotations drawn above the heatmap."""
    legend_labels = {
        'better_all': 'Better than all (≥threshold)',
        'better_one': 'Better than some',
        'Comp_with_all': 'Comparable with all',
        'Comp_with_one': 'Comparable with some'
    }
    return [
        dict(
            x=0.25 * pos,
            y=1.12,
            xref='paper',
            yref='paper',
            text=f'<b style="color:{COLOR_SCALE[color_key]};font-size:14px;">■</b> {label_text}',
            showarrow=False,
            xanchor='left',
            yanchor='bottom',
            font=dict(size=11)
        )
        for pos, (color_key, label_text) in enumerate(legend_labels.items())
    ]


def Plot(input_file, width=2460, height=800, color_labels='Greens', font_size=22, tick_font=26, tick_angle=-80, threshold=0.05, column_list=None, outfile='out.html'):
    """Render the classifier-by-study heatmap from a TSV file to HTML.

    Args:
        input_file: TSV file with a ``name`` column (one row per study) and
            one score column per classifier.
        width, height: figure size in pixels.
        color_labels: Plotly colorscale used for the heatmap cells.
        font_size: font size of the value printed in each cell.
        tick_font: font size of the x/y axis tick labels.
        tick_angle: rotation of the x-axis tick labels, in degrees.
        threshold: margin passed through ResultSummary to ``Analysis``.
        column_list: classifiers to plot. Either a list of 0-based positions
            or a comma-separated string of 1-based TSV column numbers.
            ``None`` or an empty value means all classifiers.
        outfile: path of the HTML file to write.
    """
    # From the command line column_list is a string of 1-based TSV column
    # numbers. Column 1 is 'name', which becomes the index, so subtract 2 to
    # get a 0-based position among the remaining columns. Empty tokens (e.g.
    # from a trailing comma) are skipped instead of crashing int().
    if isinstance(column_list, str):
        column_list = [int(tok) - 2 for tok in column_list.split(',') if tok.strip()]
    if not column_list:
        # "", [] and None all mean "use every classifier".
        column_list = None

    print(column_list)

    summary = ResultSummary(input_file, threshold, column_list)
    arranged_columns, spans = _group_by_verdict(summary)

    # Same preprocessing as ResultSummary: after the transpose, rows are
    # classifiers (TSV columns) and columns are studies (TSV 'name' values).
    df = pd.read_csv(input_file, sep='\t').set_index('name').T
    if column_list is not None:
        df = df.iloc[column_list]

    print(df)

    # Keep only studies that received a verdict, grouped by category. If none
    # did, fall back to plotting every study in file order.
    if arranged_columns:
        df = df[arranged_columns]

    df.index.name = 'name'

    heatmap = go.Heatmap(
        z=df.values,
        x=df.columns,
        y=df.index,
        zmin=0,
        zmax=1,
        text=df.values,  # show each cell's value
        texttemplate="%{text}",
        colorscale=color_labels,
        textfont=dict(size=font_size, color='white')
    )

    n_rows, n_cols = df.shape

    # Thin white separator line after every fifth classifier row.
    shapes = [
        go.layout.Shape(
            type='line',
            x0=-0.5,
            x1=n_cols - 0.5,
            y0=i - 0.5,
            y1=i - 0.5,
            line=dict(color='white', width=1),
        )
        for i in range(5, n_rows, 5)
    ]

    # Outline each verdict group over rows 0-4 (y from -0.5 to 4.5), the five
    # classifiers Analysis compares, in that group's COLOR_SCALE colour.
    row_idx = 4
    for label, (first, last) in spans.items():
        shapes.append(go.layout.Shape(
            type='rect',
            x0=first - 0.48,
            x1=last + 0.48,
            y0=row_idx - 4.5,
            y1=row_idx + 0.5,
            line=dict(color=COLOR_SCALE[label], width=2.5),
            fillcolor='rgba(255, 255, 255, 0)',  # transparent fill
        ))

    fig = go.Figure(data=[heatmap])

    # Print the file name up to its first dot, for the job log.
    print(os.path.basename(input_file).split('.')[0])

    fig.update_layout(
        width=width,
        height=height,
        shapes=shapes,
        title='',
        # tick_font was previously ignored (ticks were fixed at size 24).
        xaxis=dict(title='Study', tickfont=dict(size=tick_font), tickangle=tick_angle),
        yaxis=dict(title='Classifier', tickfont=dict(size=tick_font)),
        yaxis_autorange='reversed',
        autosize=False,
        annotations=_legend_annotations(),
        margin=dict(t=200)  # room at the top for the legend row
    )

    fig.write_html(outfile)
| 220 | |
if __name__ == "__main__":
    # Command-line entry point. The parser used to be built twice, with
    # --input_file registered on both; the first copy was dead code.
    parser = argparse.ArgumentParser(description="Plot heatmap from TSV data with classification results.")
    parser.add_argument("--input_file", type=str, default="test_data_age_category.tsv", help="Path to input TSV file (default: test_data_age_category.tsv)")
    parser.add_argument("--column_list", type=str, default=None, help="Comma-separated column indices to plot (default: None - plots all data)")
    parser.add_argument("--width", type=int, default=2460, help="Figure width in pixels (default: 2460)")
    parser.add_argument("--height", type=int, default=800, help="Figure height in pixels (default: 800)")
    parser.add_argument("--color_labels", type=str, default="Greens", help="Color scheme for heatmap (default: Greens)")
    parser.add_argument("--font_size", type=int, default=22, help="Font size for cell text (default: 22)")
    parser.add_argument("--tick_font", type=int, default=26, help="Font size for tick labels (default: 26)")
    parser.add_argument("--tick_angle", type=int, default=-80, help="Angle of x-axis tick labels in degrees (default: -80)")
    parser.add_argument("--threshold", type=float, default=0.05, help="Threshold for comparison analysis (default: 0.05)")
    parser.add_argument("--output", type=str, default="out.html", help="Output file path (default: out.html)")

    args = parser.parse_args()

    # argparse already converts values via type=, so no extra casts are needed.
    Plot(
        input_file=args.input_file,
        width=args.width,
        height=args.height,
        color_labels=args.color_labels,
        font_size=args.font_size,
        tick_font=args.tick_font,
        tick_angle=args.tick_angle,
        threshold=args.threshold,
        column_list=args.column_list,
        outfile=args.output
    )
