Mercurial > repos > iuc > maplot
comparison maplot.py @ 0:e9212adafd7a draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc commit d5065f0bdf2d38c2344d96d68537223c1096daab
| author | iuc |
|---|---|
| date | Thu, 15 May 2025 12:55:13 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:e9212adafd7a |
|---|---|
| 1 import argparse | |
| 2 from typing import Dict, List, Tuple | |
| 3 | |
| 4 import matplotlib.pyplot as plt | |
| 5 import numpy as np | |
| 6 import pandas as pd | |
| 7 import plotly.graph_objects as go | |
| 8 import plotly.io as pio | |
| 9 import plotly.subplots as sp | |
| 10 import statsmodels.api as sm # to build a LOWESS model | |
| 11 from scipy.stats import gaussian_kde | |
| 12 | |
| 13 | |
def make_subplot_titles(sample_names: List[str]) -> List[str]:
    """Build one title per cell of the square sample-by-sample grid.

    Titles are emitted in row-major order (row sample ``i``, column sample
    ``j``). A diagonal cell is titled with the sample name alone. An
    off-diagonal cell is titled ``"<name_i> vs. <name_j>"``.

    Args:
        sample_names (list): Sample names, one per grid row/column.

    Returns:
        list: ``len(sample_names) ** 2`` subplot titles.
    """
    return [
        f"{row_name}" if row_idx == col_idx else f"{row_name} vs. {col_name}"
        for row_idx, row_name in enumerate(sample_names)
        for col_idx, col_name in enumerate(sample_names)
    ]
| 33 | |
| 34 | |
def densities(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Estimate the point density at every (x, y) pair of a scatter plot.

    A 2-D Gaussian kernel density estimate is fitted on the points and then
    evaluated at those same points. The result is used to colour scatter
    markers.

    Args:
        x (array-like): X-axis values.
        y (array-like): Y-axis values, same length as ``x``.

    Returns:
        array: One density value per point. If the KDE cannot be fitted
        because the points are degenerate (singular covariance, e.g. every
        ``x`` or every ``y`` identical), an array of ones is returned so all
        markers get the same colour instead of the whole plot failing.
    """
    values = np.vstack([x, y])
    try:
        kde = gaussian_kde(values)
    except np.linalg.LinAlgError:
        # Singular covariance matrix: fall back to a uniform density.
        return np.ones(values.shape[1])
    return kde(values)
| 47 | |
| 48 | |
def movingaverage(data: np.ndarray, window_width: int) -> np.ndarray:
    """Compute the simple (unweighted) moving average of ``data``.

    A cumulative sum is used so each window average costs O(1). Only full
    windows are returned.

    Args:
        data (array-like): 1-D input values.
        window_width (int): Number of consecutive values per window; must be
            at least 1.

    Returns:
        array: ``len(data) - window_width + 1`` averages, or an empty array
        when the window is wider than the data.

    Raises:
        ValueError: If ``window_width`` is smaller than 1.
    """
    if window_width < 1:
        # A negative width used to yield silently wrong values and zero an
        # opaque broadcasting error; reject both explicitly.
        raise ValueError(f"window_width must be >= 1, got {window_width}")
    cumsum_vec = np.cumsum(np.insert(data, 0, 0))
    return (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width
| 62 | |
| 63 | |
def update_max(current: float, values: np.ndarray) -> float:
    """Fold the largest element of ``values`` into a running maximum.

    Args:
        current (float): Running maximum so far.
        values (array-like): New values to compare against it.

    Returns:
        float: ``current`` unless the largest of ``values`` is strictly
        greater, in which case that value.
    """
    candidate = np.max(values)
    return candidate if candidate > current else current
| 75 | |
| 76 | |
def get_indices(
    num_samples: int, num_cols: int, plot_num: int
) -> Tuple[int, int, int, int]:
    """Map a flat plot number to its sample pair and 1-based grid cell.

    The sample pair walks the ``num_samples x num_samples`` comparison
    matrix in row-major order. The grid cell walks a grid that is
    ``num_cols`` wide, also in row-major order.

    Args:
        num_samples (int): Number of samples.
        num_cols (int): Number of columns in the subplot grid.
        plot_num (int): Zero-based plot number.

    Returns:
        tuple: ``(i, j, col, row)`` with sample indices ``i``/``j`` and
        1-based grid coordinates ``col``/``row``.
    """
    first_sample, second_sample = divmod(plot_num, num_samples)
    zero_based_row, zero_based_col = divmod(plot_num, num_cols)
    return first_sample, second_sample, zero_based_col + 1, zero_based_row + 1
| 95 | |
| 96 | |
def create_subplot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    i: int,
    j: int,
) -> Dict:
    """Compute the data needed to draw one cell of the MA-plot grid.

    For every row, the natural log of the mean of samples ``i`` and ``j`` is
    the x-value. A diagonal cell (``i == j``) becomes a histogram of those
    values plus a moving-average curve. An off-diagonal cell becomes
    log2(sample_i / sample_j) plotted against it, coloured by point density
    and overlaid with a LOWESS fit.

    Rows whose log mean or log2 fold change is not finite (missing values,
    zero intensities, ...) stay in ``"mean"`` / ``"log_fold_change"`` so
    those series remain aligned with the feature names. They are excluded
    from the histogram, the maximum, the density estimate and the LOWESS
    fit, which fail or become meaningless on NaN/inf input. For all-finite
    input the result is unchanged.

    Args:
        frac (float): LOWESS smoothing parameter.
        it (int): Number of iterations for LOWESS smoothing.
        num_bins (int): Number of bins for the histogram (diagonal cells).
        window_width (int): Window width for the moving average (diagonal
            cells).
        samples (DataFrame): Sample intensities, one column per sample.
        i (int): Positional index of the first sample.
        j (int): Positional index of the second sample.

    Returns:
        dict: Always contains ``"mean"``.
            Diagonal cells add ``"bins"``, ``"counts"``,
            ``"counts_smoothed"`` and ``"max_counts"``.
            Off-diagonal cells add ``"log_fold_change"``,
            ``"max_log_fold_change"``, ``"densities"`` (NaN for excluded
            rows) and ``"regression"``, the LOWESS output whose columns are
            used as x and fitted y by the plotting code.
    """
    subplot_data = {}
    mean = np.log(samples.iloc[:, [i, j]].mean(axis=1))
    subplot_data["mean"] = mean
    finite_mean = np.isfinite(mean)

    if i == j:
        # np.histogram cannot auto-detect a range that contains NaN/inf.
        counts, bins = np.histogram(mean[finite_mean], bins=num_bins)
        subplot_data["bins"] = bins
        subplot_data["counts"] = counts
        subplot_data["counts_smoothed"] = movingaverage(counts, window_width)
        subplot_data["max_counts"] = np.max(counts)
        return subplot_data

    log_fold_change = np.log2(samples.iloc[:, i] / samples.iloc[:, j])
    subplot_data["log_fold_change"] = log_fold_change

    # Only points with a finite position can enter the KDE / LOWESS / max.
    valid = np.asarray(finite_mean & np.isfinite(log_fold_change))
    mean_valid = mean[valid]
    log_fold_change_valid = log_fold_change[valid]

    subplot_data["max_log_fold_change"] = np.max(log_fold_change_valid)

    point_density = np.full(len(log_fold_change), np.nan)
    point_density[valid] = densities(mean_valid, log_fold_change_valid)
    subplot_data["densities"] = point_density

    subplot_data["regression"] = sm.nonparametric.lowess(
        log_fold_change_valid, mean_valid, frac=frac, it=it
    )
    return subplot_data
| 140 | |
| 141 | |
def create_plot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    num_samples: int,
    num_plots: int,
    num_cols: int,
) -> List[Dict]:
    """Compute the per-cell data for every plot of the MA-plot grid.

    Each plot number is mapped to its sample pair, and the pair is turned
    into subplot data. The result is ordered by plot number.

    Args:
        frac (float): LOWESS smoothing parameter.
        it (int): Number of iterations for LOWESS smoothing.
        num_bins (int): Number of bins for the histograms.
        window_width (int): Window width for the moving average.
        samples (DataFrame): Sample intensities, one column per sample.
        num_samples (int): Number of samples.
        num_plots (int): Number of plots to produce.
        num_cols (int): Number of columns in the subplot grid.

    Returns:
        list: One subplot-data dict per plot number.
    """
    sample_pairs = (
        get_indices(num_samples, num_cols, plot_num)[:2]
        for plot_num in range(num_plots)
    )
    return [
        create_subplot_data(frac, it, num_bins, window_width, samples, first, second)
        for first, second in sample_pairs
    ]
| 175 | |
| 176 | |
def ma_plots_plotly(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    plots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    features: np.ndarray,
) -> go.Figure:
    """Draw the grid of MA plots as an interactive Plotly figure.

    Diagonal cells show a histogram of the log mean intensity with its
    moving-average curve. Off-diagonal cells show a density-coloured scatter
    of log2 fold change against log mean intensity, plus the LOWESS fit.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        plots_data (list): Per-plot data, indexed by plot number.
        sample_names (list): List of sample names.
        size (int): Size of one subplot in pixels.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Symmetric Y-axis limit for MA plots.
        features (array-like): Feature names, shown on hover.

    Returns:
        Figure: Plotly figure object.
    """
    fig = sp.make_subplots(
        rows=num_rows,
        cols=num_cols,
        shared_xaxes="all",
        subplot_titles=make_subplot_titles(sample_names),
    )

    for plot_num in range(num_plots):
        i, j, col, row = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = plots_data[plot_num]

        if i == j:
            # Histogram on the diagonal. np.histogram returns bin *edges*
            # (one more than counts), so centre each bar on its bin and
            # make it span the full bin width.
            bins = np.asarray(subplot_data["bins"])
            counts = subplot_data["counts"]
            counts_smoothed = subplot_data["counts_smoothed"]

            hist_bar = go.Bar(
                x=(bins[:-1] + bins[1:]) / 2,
                y=counts,
                width=np.diff(bins),
            )
            fig.add_trace(hist_bar, row=row, col=col)

            # The moving average over w consecutive bins yields
            # len(counts) - w + 1 values. Value k covers bins[k]..bins[k + w],
            # so place it at the midpoint of those outer edges.
            n_smoothed = len(counts_smoothed)
            window = len(counts) - n_smoothed + 1
            smoothed_x = (bins[:n_smoothed] + bins[window:window + n_smoothed]) / 2

            hist_line = go.Scatter(
                x=smoothed_x,
                y=counts_smoothed,
                marker=dict(
                    color="red",
                ),
                # The marker colour alone does not colour the connecting
                # line, so set it explicitly.
                line=dict(color="red"),
            )
            fig.add_trace(hist_line, row=row, col=col)
            fig.update_yaxes(
                title_text="Counts",
                range=[0, ylim_hist],
                matches="y1",
                showticklabels=True,
                row=row,
                col=col,
            )
        else:
            # MA scatter plot off the diagonal.
            scatter = go.Scatter(
                x=subplot_data["mean"],
                y=subplot_data["log_fold_change"],
                mode="markers",
                marker=dict(
                    color=subplot_data["densities"], symbol="circle", colorscale="jet"
                ),
                name=f"{sample_names[i]} vs {sample_names[j]}",
                text=features,
                hovertemplate="<b>%{text}</b><br>Log Mean: %{x}<br>Log2 Fold Change: %{y}<extra></extra>",
            )
            fig.add_trace(scatter, row=row, col=col)

            regression = subplot_data["regression"]
            line = go.Scatter(
                x=regression[:, 0],
                y=regression[:, 1],
                mode="lines",
                line=dict(color="red"),
                name=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )
            fig.add_trace(line, row=row, col=col)

            fig.update_yaxes(
                title_text="Log2 Fold Change",
                range=[-ylim_ma, ylim_ma],
                matches="y2",
                showticklabels=True,
                row=row,
                col=col,
            )
        fig.update_xaxes(
            title_text="Log Mean Intensity", showticklabels=True, row=row, col=col
        )

    # Layout for the entire figure.
    fig.update_layout(
        height=size * num_rows,
        width=size * num_cols,
        showlegend=False,
        template="simple_white",
    )
    return fig
| 286 | |
| 287 | |
def ma_plots_matplotlib(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    pots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    window_width: int,
) -> plt.Figure:
    """Draw the grid of MA plots as a static Matplotlib figure.

    Diagonal cells show a histogram of the log mean intensity with its
    moving-average curve. Off-diagonal cells show a density-coloured scatter
    of log2 fold change against log mean intensity, plus the LOWESS fit.
    Grid cells beyond ``num_plots`` are hidden.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        pots_data (list): Per-plot data, indexed by plot number (the name is
            kept as-is for keyword-argument compatibility).
        sample_names (list): List of sample names.
        size (int): Size of one subplot; the figure size in inches is
            ``size / 100`` per cell.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Symmetric Y-axis limit for MA plots.
        window_width (int): Window width that was used for the moving
            average of the histogram counts.

    Returns:
        Figure: Matplotlib figure object.
    """
    subplot_titles = make_subplot_titles(sample_names)
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(size * num_cols / 100, size * num_rows / 100),
        dpi=300,
        sharex="all",
        # Always get a 2-D array: with the default squeeze=True a 1x1 grid
        # (a single sample) returns a bare Axes with no .flatten().
        squeeze=False,
    )
    axes = axes.flatten()

    for plot_num in range(num_plots):
        i, j, _, _ = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = pots_data[plot_num]
        ax = axes[plot_num]

        if i == j:
            # Histogram on the diagonal; bars start at each left bin edge.
            bins = np.asarray(subplot_data["bins"])
            counts_smoothed = subplot_data["counts_smoothed"]
            ax.bar(
                bins[:-1],
                subplot_data["counts"],
                width=np.diff(bins),
                edgecolor="black",
                align="edge",
            )

            # Moving-average value k averages bins k..k+w-1, i.e. it spans
            # edges bins[k]..bins[k + w]; plot it at that span's midpoint.
            n_smoothed = len(counts_smoothed)
            smoothed_x = (
                bins[:n_smoothed] + bins[window_width:window_width + n_smoothed]
            ) / 2
            ax.plot(smoothed_x, counts_smoothed, color="red")

            ax.set_ylabel("Counts")
            ax.set_ylim(0, ylim_hist)
        else:
            # Density-coloured MA scatter plot.
            ax.scatter(
                subplot_data["mean"],
                subplot_data["log_fold_change"],
                c=subplot_data["densities"],
                cmap="jet",
                edgecolor="black",
                label=f"{sample_names[i]} vs {sample_names[j]}",
            )

            # LOWESS regression line.
            regression = subplot_data["regression"]
            ax.plot(
                regression[:, 0],
                regression[:, 1],
                color="red",
                label=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )

            ax.set_ylabel("Log2 Fold Change")
            ax.set_ylim(-ylim_ma, ylim_ma)

        ax.set_xlabel("Log Mean Intensity")
        ax.tick_params(labelbottom=True)  # Force x-tick labels despite sharex
        ax.set_title(subplot_titles[plot_num])

    # Hide grid cells that are left over when the plots do not fill the
    # last row (possible when --max_num_cols limits the grid width).
    for unused_ax in axes[num_plots:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    return fig
| 382 | |
| 383 | |
def main():
    """Parse the command line, build the MA-plot grid and write it to disk.

    The first column of the input table is assumed to hold feature names and
    every remaining column one sample. Every ordered pair of samples gets one
    grid cell. The figure is written with Plotly when ``--interactive`` is
    given and with Matplotlib otherwise; HTML output is only possible with
    Plotly.

    Returns:
        int: 0 on success. Invalid arguments make argparse exit with status
        2; an unsupported file extension or an input without sample columns
        raises ValueError.
    """
    parser = argparse.ArgumentParser(description="Generate MA plots.")
    parser.add_argument(
        "--file_path", type=str, required=True, help="Path to the input CSV file"
    )
    parser.add_argument(
        "--file_extension", type=str, required=True, help="File extension"
    )
    parser.add_argument(
        "--frac", type=float, default=4 / 5, help="LOESS smoothing parameter"
    )
    parser.add_argument(
        "--it", type=int, default=5, help="Number of iterations for LOESS smoothing"
    )
    parser.add_argument(
        "--num_bins", type=int, default=100, help="Number of bins for histogram"
    )
    parser.add_argument(
        "--window_width", type=int, default=5, help="Window width for moving average"
    )
    parser.add_argument("--size", type=int, default=500, help="Size of the plot")
    parser.add_argument(
        "--scale", type=int, default=3, help="Scale factor for the plot"
    )
    parser.add_argument(
        "--y_scale_factor", type=float, default=1.1, help="Y-axis scale factor"
    )
    parser.add_argument(
        "--max_num_cols",
        type=int,
        default=100,
        help="Maximum number of columns in the plot",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Generate interactive plot using Plotly",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="pdf",
        choices=["pdf", "png", "html"],
        help="Output format for the plot",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="ma_plot",
        help="Output file name without extension",
    )

    args = parser.parse_args()

    # Matplotlib's savefig cannot produce HTML; only the Plotly path can.
    if args.output_format == "html" and not args.interactive:
        parser.error("--output_format html requires --interactive")
    if args.max_num_cols < 1:
        parser.error("--max_num_cols must be at least 1")

    # Load the data.
    file_extension = args.file_extension.lower()
    if file_extension == "csv":
        data = pd.read_csv(args.file_path)
    elif file_extension in ["txt", "tsv", "tabular"]:
        data = pd.read_csv(args.file_path, sep="\t")
    elif file_extension == "parquet":
        data = pd.read_parquet(args.file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")

    features = data.iloc[:, 0]  # Assuming the first column is the feature names
    samples = data.iloc[:, 1:]  # and the rest are samples

    num_samples = samples.shape[1]
    if num_samples == 0:
        # Otherwise num_cols would be 0 and the grid size divides by zero.
        raise ValueError(
            "Input table needs at least one sample column after the feature column"
        )

    # Grid layout: one cell per ordered sample pair.
    sample_names = samples.columns
    num_plots = num_samples**2
    num_cols = min(num_samples, args.max_num_cols)
    num_rows = int(np.ceil(num_plots / num_cols))

    plots_data = create_plot_data(
        args.frac,
        args.it,
        args.num_bins,
        args.window_width,
        samples,
        num_samples,
        num_plots,
        num_cols,
    )

    count_max = np.max([x.get("max_counts", 0) for x in plots_data])
    log_fold_change_max = np.max([x.get("max_log_fold_change", 0) for x in plots_data])

    ylim_hist = count_max * args.y_scale_factor
    ylim_ma = log_fold_change_max * args.y_scale_factor

    # In both branches the output file is written before the figure is
    # shown, so the result is on disk even if displaying blocks (interactive
    # backends) or fails (headless environments).
    if args.interactive:
        fig = ma_plots_plotly(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            features,
        )
        if args.output_format == "html":
            fig.write_html(args.output_file)
        else:
            pio.write_image(
                fig,
                args.output_file,
                format=args.output_format,
                width=args.size * num_cols,
                height=args.size * num_rows,
                scale=args.scale,
            )
        fig.show()
    else:
        fig = ma_plots_matplotlib(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            args.window_width,
        )
        fig.savefig(args.output_file, format=args.output_format, dpi=300)
        plt.show()
    return 0
| 512 | |
| 513 | |
if __name__ == "__main__":
    # Script entry point: use main()'s return value as the exit status.
    raise SystemExit(main())
