Mercurial > repos > ebi-gxa > decoupler_pathway_inference
comparison decoupler_pseudobulk.py @ 3:c6787c2aee46 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
| author | ebi-gxa | 
|---|---|
| date | Mon, 15 Jul 2024 10:56:37 +0000 | 
| parents | 77d680b36e23 | 
| children | 6c30272fb587 | 
   comparison
  equal
  deleted
  inserted
  replaced
| 2:82b7cd3e1bbd | 3:c6787c2aee46 | 
|---|---|
| 38 if index_value and index_value[0].isdigit(): | 38 if index_value and index_value[0].isdigit(): | 
| 39 return "C" + index_value | 39 return "C" + index_value | 
| 40 return index_value | 40 return index_value | 
| 41 | 41 | 
| 42 | 42 | 
| 43 def genes_to_ignore_per_contrast_field( | |
| 44 count_matrix_df, | |
| 45 samples_metadata, | |
| 46 sample_metadata_col_contrasts, | |
| 47 min_counts_per_sample=5, | |
| 48 use_cpms=False, | |
| 49 ): | |
| 50 """ | |
| 51 # This function calculates the genes to ignore per contrast field | |
| 52 # (e.g., bulk_labels, louvain). | |
| 53 # It does this by first getting the count matrix for each group, | |
| 54 # then identifying genes with a count below a specified threshold. | |
| 55 # The genes to ignore for a group are those whose summed count across | |
| 56 # that group's samples falls below the threshold. | |
| 57 | |
| 58 >>> import pandas as pd | |
| 59 >>> samples_metadata = pd.DataFrame({'sample': | |
| 60 ... ['S1', 'S2', 'S3', | |
| 61 ... 'S4', 'S5', 'S6'], | |
| 62 ... 'contrast_field': | |
| 63 ... ['A', 'A', 'A', 'B', 'B', 'B']}) | |
| 64 >>> count_matrix_df = pd.DataFrame( | |
| 65 ... {'S1': | |
| 66 ... [30, 1, 40, 50, 30], | |
| 67 ... 'S2': | |
| 68 ... [40, 2, 60, 50, 80], | |
| 69 ... 'S3': | |
| 70 ... [80, 1, 60, 50, 50], | |
| 71 ... 'S4': [1, 50, 50, 50, 2], | |
| 72 ... 'S5': [3, 40, 40, 40, 2], | |
| 73 ... 'S6': [0, 50, 50, 50, 1]}) | |
| 74 >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5'] | |
| 75 >>> df = genes_to_ignore_per_contrast_field(count_matrix_df, | |
| 76 ... samples_metadata, min_counts_per_sample=5, | |
| 77 ... sample_metadata_col_contrasts='contrast_field') | |
| 78 >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0] | |
| 79 'Gene2' | |
| 80 >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0] | |
| 81 'Gene1' | |
| 82 >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1] | |
| 83 'Gene5' | |
| 84 """ | |
| 85 | |
| 86 # Initialize parallel lists of contrast field values and genes to ignore | |
| 87 contrast_fields = [] | |
| 88 genes_to_ignore = [] | |
| 89 | |
| 90 # Iterate over the contrast fields | |
| 91 for contrast_field in samples_metadata[ | |
| 92 sample_metadata_col_contrasts | |
| 93 ].unique(): | |
| 94 # Get the count matrix for the current contrast field | |
| 95 count_matrix_field = count_matrix_df.loc[ | |
| 96 :, | |
| 97 ( | |
| 98 samples_metadata[sample_metadata_col_contrasts] | |
| 99 == contrast_field | |
| 100 ).tolist(), | |
| 101 ] | |
| 102 | |
| 103 # We derive min_counts from the number of samples with that | |
| 104 # contrast_field value | |
| 105 min_counts = count_matrix_field.shape[1] * min_counts_per_sample | |
| 106 | |
| 107 if use_cpms: | |
| 108 # Convert counts to counts per million (CPM) | |
| 109 count_matrix_field = ( | |
| 110 count_matrix_field.div(count_matrix_field.sum(axis=1), axis=0) | |
| 111 * 1e6 | |
| 112 ) | |
| 113 min_counts = 1 # use 1 CPM | |
| 114 | |
| 115 # Calculate the total counts per gene across the samples of the | |
| 116 # current contrast field (this produces a vector of counts per gene) | |
| 117 total_counts_per_gene = count_matrix_field.sum(axis=1) | |
| 118 | |
| 119 # Identify genes with a count below the specified threshold | |
| 120 genes = total_counts_per_gene[ | |
| 121 total_counts_per_gene < min_counts | |
| 122 ].index.tolist() | |
| 123 if len(genes) > 0: | |
| 124 # genes_to_ignore[contrast_field] = " ".join(genes) | |
| 125 for gene in genes: | |
| 126 genes_to_ignore.append(gene) | |
| 127 contrast_fields.append(contrast_field) | |
| 128 # transform genes_to_ignore to a DataFrame | |
| 129 # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(), | |
| 130 # columns=["contrast_field", "genes_to_ignore"]) | |
| 131 genes_to_ignore_df = pd.DataFrame( | |
| 132 {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore} | |
| 133 ) | |
| 134 return genes_to_ignore_df | |
| 135 | |
| 136 | |
| 43 # write results for loading into DESeq2 | 137 # write results for loading into DESeq2 | 
| 44 def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None): | 138 def write_DESeq2_inputs( | 
| 139 pdata, | |
| 140 layer=None, | |
| 141 output_dir="", | |
| 142 factor_fields=None, | |
| 143 min_counts_per_sample_marking=20, | |
| 144 ): | |
| 45 """ | 145 """ | 
| 46 >>> import scanpy as sc | 146 >>> import scanpy as sc | 
| 47 >>> adata = sc.datasets.pbmc68k_reduced() | 147 >>> adata = sc.datasets.pbmc68k_reduced() | 
| 48 >>> adata.X = abs(adata.X).astype(int) | 148 >>> adata.X = abs(adata.X).astype(int) | 
| 49 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 149 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 
| 60 obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_") | 160 obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_") | 
| 61 col_metadata_file = f"{output_dir}col_metadata.tsv" | 161 col_metadata_file = f"{output_dir}col_metadata.tsv" | 
| 62 # write obs to a col_metadata file | 162 # write obs to a col_metadata file | 
| 63 if factor_fields: | 163 if factor_fields: | 
| 64 # only output the index plus the columns in factor_fields in that order | 164 # only output the index plus the columns in factor_fields in that order | 
| 65 obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True) | 165 obs_for_deseq[factor_fields].to_csv( | 
| 166 col_metadata_file, sep="\t", index=True | |
| 167 ) | |
| 66 else: | 168 else: | 
| 67 obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) | 169 obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) | 
| 68 # write var to a gene_metadata file | 170 # write var to a gene_metadata file | 
| 69 pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True) | 171 pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True) | 
| 70 # write the counts matrix of a specified layer to file | 172 # write the counts matrix of a specified layer to file | 
| 71 if layer is None: | 173 if layer is None: | 
| 72 # write the X numpy matrix transposed to file | 174 # write the X numpy matrix transposed to file | 
| 73 df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index) | 175 df = pd.DataFrame( | 
| 176 pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index | |
| 177 ) | |
| 74 else: | 178 else: | 
| 75 df = pd.DataFrame( | 179 df = pd.DataFrame( | 
| 76 pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index | 180 pdata.layers[layer].T, | 
| 181 index=pdata.var.index, | |
| 182 columns=obs_for_deseq.index, | |
| 77 ) | 183 ) | 
| 78 df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") | 184 df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") | 
| 185 | |
| 186 if factor_fields: | |
| 187 df_genes_ignore = genes_to_ignore_per_contrast_field( | |
| 188 count_matrix_df=df, | |
| 189 samples_metadata=obs_for_deseq, | |
| 190 sample_metadata_col_contrasts=factor_fields[0], | |
| 191 min_counts_per_sample=min_counts_per_sample_marking, | |
| 192 ) | |
| 193 df_genes_ignore.to_csv( | |
| 194 f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t" | |
| 195 ) | |
| 79 | 196 | 
| 80 | 197 | 
| 81 def plot_pseudobulk_samples( | 198 def plot_pseudobulk_samples( | 
| 82 pseudobulk_data, | 199 pseudobulk_data, | 
| 83 groupby, | 200 groupby, | 
| 87 """ | 204 """ | 
| 88 >>> import scanpy as sc | 205 >>> import scanpy as sc | 
| 89 >>> adata = sc.datasets.pbmc68k_reduced() | 206 >>> adata = sc.datasets.pbmc68k_reduced() | 
| 90 >>> adata.X = abs(adata.X).astype(int) | 207 >>> adata.X = abs(adata.X).astype(int) | 
| 91 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 208 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 
| 92 >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10)) | 209 >>> plot_pseudobulk_samples(pseudobulk, | 
| 210 ... groupby=["bulk_labels", "louvain"], | |
| 211 ... figsize=(10, 10)) | |
| 93 """ | 212 """ | 
| 94 fig = decoupler.plot_psbulk_samples( | 213 fig = decoupler.plot_psbulk_samples( | 
| 95 pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True | 214 pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True | 
| 96 ) | 215 ) | 
| 97 if save_path: | 216 if save_path: | 
| 99 else: | 218 else: | 
| 100 fig.show() | 219 fig.show() | 
| 101 | 220 | 
| 102 | 221 | 
| 103 def plot_filter_by_expr( | 222 def plot_filter_by_expr( | 
| 104 pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None | 223 pseudobulk_data, | 
| 224 group, | |
| 225 min_count=None, | |
| 226 min_total_count=None, | |
| 227 save_path=None, | |
| 105 ): | 228 ): | 
| 106 """ | 229 """ | 
| 107 >>> import scanpy as sc | 230 >>> import scanpy as sc | 
| 108 >>> adata = sc.datasets.pbmc68k_reduced() | 231 >>> adata = sc.datasets.pbmc68k_reduced() | 
| 109 >>> adata.X = abs(adata.X).astype(int) | 232 >>> adata.X = abs(adata.X).astype(int) | 
| 110 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 233 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 
| 111 >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200) | 234 >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", | 
| 235 ... min_count=10, min_total_count=200) | |
| 112 """ | 236 """ | 
| 113 fig = decoupler.plot_filter_by_expr( | 237 fig = decoupler.plot_filter_by_expr( | 
| 114 pseudobulk_data, | 238 pseudobulk_data, | 
| 115 group=group, | 239 group=group, | 
| 116 min_count=min_count, | 240 min_count=min_count, | 
| 127 """ | 251 """ | 
| 128 >>> import scanpy as sc | 252 >>> import scanpy as sc | 
| 129 >>> adata = sc.datasets.pbmc68k_reduced() | 253 >>> adata = sc.datasets.pbmc68k_reduced() | 
| 130 >>> adata.X = abs(adata.X).astype(int) | 254 >>> adata.X = abs(adata.X).astype(int) | 
| 131 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 255 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 
| 132 >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200) | 256 >>> pdata_filt = filter_by_expr(pseudobulk, | 
| 257 ... min_count=10, min_total_count=200) | |
| 133 """ | 258 """ | 
| 134 genes = decoupler.filter_by_expr( | 259 genes = decoupler.filter_by_expr( | 
| 135 pdata, min_count=min_count, min_total_count=min_total_count | 260 pdata, min_count=min_count, min_total_count=min_total_count | 
| 136 ) | 261 ) | 
| 137 return pdata[:, genes].copy() | 262 return pdata[:, genes].copy() | 
| 148 if context: | 273 if context: | 
| 149 legend = f", passed in {context}," | 274 legend = f", passed in {context}," | 
| 150 if obs: | 275 if obs: | 
| 151 if not set(fields).issubset(set(adata.obs.columns)): | 276 if not set(fields).issubset(set(adata.obs.columns)): | 
| 152 raise ValueError( | 277 raise ValueError( | 
| 153 f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}" | 278 f"Some of the following fields {legend} are not present \ | 
| 279 in adata.obs: {fields}. \ | |
| 280 Possible fields are: {list(set(adata.obs.columns))}" | |
| 154 ) | 281 ) | 
| 155 else: | 282 else: | 
| 156 if not set(fields).issubset(set(adata.var.columns)): | 283 if not set(fields).issubset(set(adata.var.columns)): | 
| 157 raise ValueError( | 284 raise ValueError( | 
| 158 f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}" | 285 f"Some of the following fields {legend} are not present \ | 
| 286 in adata.var: {fields}. \ | |
| 287 Possible fields are: {list(set(adata.var.columns))}" | |
| 159 ) | 288 ) | 
| 160 | 289 | 
| 161 | 290 | 
| 162 def main(args): | 291 def main(args): | 
| 163 # Load AnnData object from file | 292 # Load AnnData object from file | 
| 217 | 346 | 
| 218 pseudobulk_data = filtered_adata | 347 pseudobulk_data = filtered_adata | 
| 219 | 348 | 
| 220 # Save the pseudobulk data | 349 # Save the pseudobulk data | 
| 221 if args.anndata_output_path: | 350 if args.anndata_output_path: | 
| 222 pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip") | 351 pseudobulk_data.write_h5ad( | 
| 352 args.anndata_output_path, compression="gzip" | |
| 353 ) | |
| 223 | 354 | 
| 224 write_DESeq2_inputs( | 355 write_DESeq2_inputs( | 
| 225 pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields | 356 pseudobulk_data, | 
| 357 output_dir=args.deseq2_output_path, | |
| 358 factor_fields=factor_fields, | |
| 359 min_counts_per_sample_marking=args.min_counts_per_sample_marking, | |
| 226 ) | 360 ) | 
| 227 | 361 | 
| 228 | 362 | 
| 229 def merge_adata_obs_fields(obs_fields_to_merge, adata): | 363 def merge_adata_obs_fields(obs_fields_to_merge, adata): | 
| 230 """ | 364 """ | 
| 252 dtype='object') | 386 dtype='object') | 
| 253 """ | 387 """ | 
| 254 field_name = "_".join(obs_fields_to_merge) | 388 field_name = "_".join(obs_fields_to_merge) | 
| 255 for field in obs_fields_to_merge: | 389 for field in obs_fields_to_merge: | 
| 256 if field not in adata.obs.columns: | 390 if field not in adata.obs.columns: | 
| 257 raise ValueError(f"The '{field}' column is not present in adata.obs.") | 391 raise ValueError( | 
| 392 f"The '{field}' column is not present in adata.obs." | |
| 393 ) | |
| 258 if field_name not in adata.obs.columns: | 394 if field_name not in adata.obs.columns: | 
| 259 adata.obs[field_name] = adata.obs[field].astype(str) | 395 adata.obs[field_name] = adata.obs[field].astype(str) | 
| 260 else: | 396 else: | 
| 261 adata.obs[field_name] = ( | 397 adata.obs[field_name] = ( | 
| 262 adata.obs[field_name] + "_" + adata.obs[field].astype(str) | 398 adata.obs[field_name] + "_" + adata.obs[field].astype(str) | 
| 269 parser = argparse.ArgumentParser( | 405 parser = argparse.ArgumentParser( | 
| 270 description="Perform pseudobulk analysis on an AnnData object" | 406 description="Perform pseudobulk analysis on an AnnData object" | 
| 271 ) | 407 ) | 
| 272 | 408 | 
| 273 # Add arguments | 409 # Add arguments | 
| 274 parser.add_argument("adata_file", type=str, help="Path to the AnnData file") | 410 parser.add_argument( | 
| 411 "adata_file", type=str, help="Path to the AnnData file" | |
| 412 ) | |
| 275 parser.add_argument( | 413 parser.add_argument( | 
| 276 "-m", | 414 "-m", | 
| 277 "--adata_obs_fields_to_merge", | 415 "--adata_obs_fields_to_merge", | 
| 278 type=str, | 416 type=str, | 
| 279 help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;", | 417 help="Fields in adata.obs to merge, comma separated. \ | 
| 418 You can have more than one set of fields, \ | |
| 419 separated by semi-colon ;", | |
| 280 ) | 420 ) | 
| 281 parser.add_argument( | 421 parser.add_argument( | 
| 282 "--groupby", | 422 "--groupby", | 
| 283 type=str, | 423 type=str, | 
| 284 required=True, | 424 required=True, | 
| 326 "--min_counts", | 466 "--min_counts", | 
| 327 type=int, | 467 type=int, | 
| 328 help="Minimum count threshold for filtering by expression", | 468 help="Minimum count threshold for filtering by expression", | 
| 329 ) | 469 ) | 
| 330 parser.add_argument( | 470 parser.add_argument( | 
| 471 "--min_counts_per_sample_marking", | |
| 472 type=int, | |
| 473 default=20, | |
| 474 help="Minimum count threshold per sample for \ | |
| 475 marking genes to be ignored after DE", | |
| 476 ) | |
| 477 parser.add_argument( | |
| 331 "--min_total_counts", | 478 "--min_total_counts", | 
| 332 type=int, | 479 type=int, | 
| 333 help="Minimum total count threshold for filtering by expression", | 480 help="Minimum total count threshold for filtering by expression", | 
| 334 ) | 481 ) | 
| 335 parser.add_argument( | 482 parser.add_argument( | 
| 336 "--anndata_output_path", | 483 "--anndata_output_path", | 
| 337 type=str, | 484 type=str, | 
| 338 help="Path to save the filtered AnnData object or pseudobulk data", | 485 help="Path to save the filtered AnnData object or pseudobulk data", | 
| 339 ) | 486 ) | 
| 340 parser.add_argument( | 487 parser.add_argument( | 
| 341 "--filter_expr", action="store_true", help="Enable filtering by expression" | 488 "--filter_expr", | 
| 489 action="store_true", | |
| 490 help="Enable filtering by expression", | |
| 342 ) | 491 ) | 
| 343 parser.add_argument( | 492 parser.add_argument( | 
| 344 "--factor_fields", | 493 "--factor_fields", | 
| 345 type=str, | 494 type=str, | 
| 346 help="Comma separated list of fields for the factors", | 495 help="Comma separated list of fields for the factors", | 
| 356 type=int, | 505 type=int, | 
| 357 default=[10, 10], | 506 default=[10, 10], | 
| 358 nargs=2, | 507 nargs=2, | 
| 359 help="Size of the samples plot as a tuple (two arguments)", | 508 help="Size of the samples plot as a tuple (two arguments)", | 
| 360 ) | 509 ) | 
| 361 parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2) | 510 parser.add_argument( | 
| 511 "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2 | |
| 512 ) | |
| 362 | 513 | 
| 363 # Parse the command line arguments | 514 # Parse the command line arguments | 
| 364 args = parser.parse_args() | 515 args = parser.parse_args() | 
| 365 | 516 | 
| 366 # Call the main function | 517 # Call the main function | 
