Mercurial > repos > ebi-gxa > decoupler_pathway_inference
comparison decoupler_pseudobulk.py @ 3:c6787c2aee46 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author | ebi-gxa |
---|---|
date | Mon, 15 Jul 2024 10:56:37 +0000 |
parents | 77d680b36e23 |
children | 6c30272fb587 |
comparison
equal
deleted
inserted
replaced
2:82b7cd3e1bbd | 3:c6787c2aee46 |
---|---|
38 if index_value and index_value[0].isdigit(): | 38 if index_value and index_value[0].isdigit(): |
39 return "C" + index_value | 39 return "C" + index_value |
40 return index_value | 40 return index_value |
41 | 41 |
42 | 42 |
43 def genes_to_ignore_per_contrast_field( | |
44 count_matrix_df, | |
45 samples_metadata, | |
46 sample_metadata_col_contrasts, | |
47 min_counts_per_sample=5, | |
48 use_cpms=False, | |
49 ): | |
50 """ | |
51 # This function calculates the genes to ignore per contrast field | |
52 # (e.g., bulk_labels, louvain). | |
53 # It does this by first getting the count matrix for each group, | |
54 # then identifying genes with a count below a specified threshold. | |
55 # The genes to ignore are those that are present in more than a specified | |
56 # number of groups. | |
57 | |
58 >>> import pandas as pd | |
59 >>> samples_metadata = pd.DataFrame({'sample': | |
60 ... ['S1', 'S2', 'S3', | |
61 ... 'S4', 'S5', 'S6'], | |
62 ... 'contrast_field': | |
63 ... ['A', 'A', 'A', 'B', 'B', 'B']}) | |
64 >>> count_matrix_df = pd.DataFrame( | |
65 ... {'S1': | |
66 ... [30, 1, 40, 50, 30], | |
67 ... 'S2': | |
68 ... [40, 2, 60, 50, 80], | |
69 ... 'S3': | |
70 ... [80, 1, 60, 50, 50], | |
71 ... 'S4': [1, 50, 50, 50, 2], | |
72 ... 'S5': [3, 40, 40, 40, 2], | |
73 ... 'S6': [0, 50, 50, 50, 1]}) | |
74 >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5'] | |
75 >>> df = genes_to_ignore_per_contrast_field(count_matrix_df, | |
76 ... samples_metadata, min_counts_per_sample=5, | |
77 ... sample_metadata_col_contrasts='contrast_field') | |
78 >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0] | |
79 'Gene2' | |
80 >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0] | |
81 'Gene1' | |
82 >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1] | |
83 'Gene5' | |
84 """ | |
85 | |
86 # Initialize a dictionary to store the genes to ignore per contrast field | |
87 contrast_fields = [] | |
88 genes_to_ignore = [] | |
89 | |
90 # Iterate over the contrast fields | |
91 for contrast_field in samples_metadata[ | |
92 sample_metadata_col_contrasts | |
93 ].unique(): | |
94 # Get the count matrix for the current contrast field | |
95 count_matrix_field = count_matrix_df.loc[ | |
96 :, | |
97 ( | |
98 samples_metadata[sample_metadata_col_contrasts] | |
99 == contrast_field | |
100 ).tolist(), | |
101 ] | |
102 | |
103 # We derive min_counts from the number of samples with that | |
104 # contrast_field value | |
105 min_counts = count_matrix_field.shape[1] * min_counts_per_sample | |
106 | |
107 if use_cpms: | |
108 # Convert counts to counts per million (CPM) | |
109 count_matrix_field = ( | |
110 count_matrix_field.div(count_matrix_field.sum(axis=1), axis=0) | |
111 * 1e6 | |
112 ) | |
113 min_counts = 1 # use 1 CPM | |
114 | |
115 # Calculate the total number of cells in the current contrast field | |
116 # (this produces a vector of counts per gene) | |
117 total_counts_per_gene = count_matrix_field.sum(axis=1) | |
118 | |
119 # Identify genes with a count below the specified threshold | |
120 genes = total_counts_per_gene[ | |
121 total_counts_per_gene < min_counts | |
122 ].index.tolist() | |
123 if len(genes) > 0: | |
124 # genes_to_ignore[contrast_field] = " ".join(genes) | |
125 for gene in genes: | |
126 genes_to_ignore.append(gene) | |
127 contrast_fields.append(contrast_field) | |
128 # transform gene_to_ignore to a DataFrame | |
129 # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(), | |
130 # columns=["contrast_field", "genes_to_ignore"]) | |
131 genes_to_ignore_df = pd.DataFrame( | |
132 {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore} | |
133 ) | |
134 return genes_to_ignore_df | |
135 | |
136 | |
43 # write results for loading into DESeq2 | 137 # write results for loading into DESeq2 |
44 def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None): | 138 def write_DESeq2_inputs( |
139 pdata, | |
140 layer=None, | |
141 output_dir="", | |
142 factor_fields=None, | |
143 min_counts_per_sample_marking=20, | |
144 ): | |
45 """ | 145 """ |
46 >>> import scanpy as sc | 146 >>> import scanpy as sc |
47 >>> adata = sc.datasets.pbmc68k_reduced() | 147 >>> adata = sc.datasets.pbmc68k_reduced() |
48 >>> adata.X = abs(adata.X).astype(int) | 148 >>> adata.X = abs(adata.X).astype(int) |
49 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 149 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") |
60 obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_") | 160 obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_") |
61 col_metadata_file = f"{output_dir}col_metadata.tsv" | 161 col_metadata_file = f"{output_dir}col_metadata.tsv" |
62 # write obs to a col_metadata file | 162 # write obs to a col_metadata file |
63 if factor_fields: | 163 if factor_fields: |
64 # only output the index plus the columns in factor_fields in that order | 164 # only output the index plus the columns in factor_fields in that order |
65 obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True) | 165 obs_for_deseq[factor_fields].to_csv( |
166 col_metadata_file, sep="\t", index=True | |
167 ) | |
66 else: | 168 else: |
67 obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) | 169 obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) |
68 # write var to a gene_metadata file | 170 # write var to a gene_metadata file |
69 pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True) | 171 pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True) |
70 # write the counts matrix of a specified layer to file | 172 # write the counts matrix of a specified layer to file |
71 if layer is None: | 173 if layer is None: |
72 # write the X numpy matrix transposed to file | 174 # write the X numpy matrix transposed to file |
73 df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index) | 175 df = pd.DataFrame( |
176 pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index | |
177 ) | |
74 else: | 178 else: |
75 df = pd.DataFrame( | 179 df = pd.DataFrame( |
76 pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index | 180 pdata.layers[layer].T, |
181 index=pdata.var.index, | |
182 columns=obs_for_deseq.index, | |
77 ) | 183 ) |
78 df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") | 184 df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") |
185 | |
186 if factor_fields: | |
187 df_genes_ignore = genes_to_ignore_per_contrast_field( | |
188 count_matrix_df=df, | |
189 samples_metadata=obs_for_deseq, | |
190 sample_metadata_col_contrasts=factor_fields[0], | |
191 min_counts_per_sample=min_counts_per_sample_marking, | |
192 ) | |
193 df_genes_ignore.to_csv( | |
194 f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t" | |
195 ) | |
79 | 196 |
80 | 197 |
81 def plot_pseudobulk_samples( | 198 def plot_pseudobulk_samples( |
82 pseudobulk_data, | 199 pseudobulk_data, |
83 groupby, | 200 groupby, |
87 """ | 204 """ |
88 >>> import scanpy as sc | 205 >>> import scanpy as sc |
89 >>> adata = sc.datasets.pbmc68k_reduced() | 206 >>> adata = sc.datasets.pbmc68k_reduced() |
90 >>> adata.X = abs(adata.X).astype(int) | 207 >>> adata.X = abs(adata.X).astype(int) |
91 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 208 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") |
92 >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10)) | 209 >>> plot_pseudobulk_samples(pseudobulk, |
210 ... groupby=["bulk_labels", "louvain"], | |
211 ... figsize=(10, 10)) | |
93 """ | 212 """ |
94 fig = decoupler.plot_psbulk_samples( | 213 fig = decoupler.plot_psbulk_samples( |
95 pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True | 214 pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True |
96 ) | 215 ) |
97 if save_path: | 216 if save_path: |
99 else: | 218 else: |
100 fig.show() | 219 fig.show() |
101 | 220 |
102 | 221 |
103 def plot_filter_by_expr( | 222 def plot_filter_by_expr( |
104 pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None | 223 pseudobulk_data, |
224 group, | |
225 min_count=None, | |
226 min_total_count=None, | |
227 save_path=None, | |
105 ): | 228 ): |
106 """ | 229 """ |
107 >>> import scanpy as sc | 230 >>> import scanpy as sc |
108 >>> adata = sc.datasets.pbmc68k_reduced() | 231 >>> adata = sc.datasets.pbmc68k_reduced() |
109 >>> adata.X = abs(adata.X).astype(int) | 232 >>> adata.X = abs(adata.X).astype(int) |
110 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 233 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") |
111 >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200) | 234 >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", |
235 ... min_count=10, min_total_count=200) | |
112 """ | 236 """ |
113 fig = decoupler.plot_filter_by_expr( | 237 fig = decoupler.plot_filter_by_expr( |
114 pseudobulk_data, | 238 pseudobulk_data, |
115 group=group, | 239 group=group, |
116 min_count=min_count, | 240 min_count=min_count, |
127 """ | 251 """ |
128 >>> import scanpy as sc | 252 >>> import scanpy as sc |
129 >>> adata = sc.datasets.pbmc68k_reduced() | 253 >>> adata = sc.datasets.pbmc68k_reduced() |
130 >>> adata.X = abs(adata.X).astype(int) | 254 >>> adata.X = abs(adata.X).astype(int) |
131 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | 255 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") |
132 >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200) | 256 >>> pdata_filt = filter_by_expr(pseudobulk, |
257 ... min_count=10, min_total_count=200) | |
133 """ | 258 """ |
134 genes = decoupler.filter_by_expr( | 259 genes = decoupler.filter_by_expr( |
135 pdata, min_count=min_count, min_total_count=min_total_count | 260 pdata, min_count=min_count, min_total_count=min_total_count |
136 ) | 261 ) |
137 return pdata[:, genes].copy() | 262 return pdata[:, genes].copy() |
148 if context: | 273 if context: |
149 legend = f", passed in {context}," | 274 legend = f", passed in {context}," |
150 if obs: | 275 if obs: |
151 if not set(fields).issubset(set(adata.obs.columns)): | 276 if not set(fields).issubset(set(adata.obs.columns)): |
152 raise ValueError( | 277 raise ValueError( |
153 f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}" | 278 f"Some of the following fields {legend} are not present \ |
279 in adata.obs: {fields}. \ | |
280 Possible fields are: {list(set(adata.obs.columns))}" | |
154 ) | 281 ) |
155 else: | 282 else: |
156 if not set(fields).issubset(set(adata.var.columns)): | 283 if not set(fields).issubset(set(adata.var.columns)): |
157 raise ValueError( | 284 raise ValueError( |
158 f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}" | 285 f"Some of the following fields {legend} are not present \ |
286 in adata.var: {fields}. \ | |
287 Possible fields are: {list(set(adata.var.columns))}" | |
159 ) | 288 ) |
160 | 289 |
161 | 290 |
162 def main(args): | 291 def main(args): |
163 # Load AnnData object from file | 292 # Load AnnData object from file |
217 | 346 |
218 pseudobulk_data = filtered_adata | 347 pseudobulk_data = filtered_adata |
219 | 348 |
220 # Save the pseudobulk data | 349 # Save the pseudobulk data |
221 if args.anndata_output_path: | 350 if args.anndata_output_path: |
222 pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip") | 351 pseudobulk_data.write_h5ad( |
352 args.anndata_output_path, compression="gzip" | |
353 ) | |
223 | 354 |
224 write_DESeq2_inputs( | 355 write_DESeq2_inputs( |
225 pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields | 356 pseudobulk_data, |
357 output_dir=args.deseq2_output_path, | |
358 factor_fields=factor_fields, | |
359 min_counts_per_sample_marking=args.min_counts_per_sample_marking, | |
226 ) | 360 ) |
227 | 361 |
228 | 362 |
229 def merge_adata_obs_fields(obs_fields_to_merge, adata): | 363 def merge_adata_obs_fields(obs_fields_to_merge, adata): |
230 """ | 364 """ |
252 dtype='object') | 386 dtype='object') |
253 """ | 387 """ |
254 field_name = "_".join(obs_fields_to_merge) | 388 field_name = "_".join(obs_fields_to_merge) |
255 for field in obs_fields_to_merge: | 389 for field in obs_fields_to_merge: |
256 if field not in adata.obs.columns: | 390 if field not in adata.obs.columns: |
257 raise ValueError(f"The '{field}' column is not present in adata.obs.") | 391 raise ValueError( |
392 f"The '{field}' column is not present in adata.obs." | |
393 ) | |
258 if field_name not in adata.obs.columns: | 394 if field_name not in adata.obs.columns: |
259 adata.obs[field_name] = adata.obs[field].astype(str) | 395 adata.obs[field_name] = adata.obs[field].astype(str) |
260 else: | 396 else: |
261 adata.obs[field_name] = ( | 397 adata.obs[field_name] = ( |
262 adata.obs[field_name] + "_" + adata.obs[field].astype(str) | 398 adata.obs[field_name] + "_" + adata.obs[field].astype(str) |
269 parser = argparse.ArgumentParser( | 405 parser = argparse.ArgumentParser( |
270 description="Perform pseudobulk analysis on an AnnData object" | 406 description="Perform pseudobulk analysis on an AnnData object" |
271 ) | 407 ) |
272 | 408 |
273 # Add arguments | 409 # Add arguments |
274 parser.add_argument("adata_file", type=str, help="Path to the AnnData file") | 410 parser.add_argument( |
411 "adata_file", type=str, help="Path to the AnnData file" | |
412 ) | |
275 parser.add_argument( | 413 parser.add_argument( |
276 "-m", | 414 "-m", |
277 "--adata_obs_fields_to_merge", | 415 "--adata_obs_fields_to_merge", |
278 type=str, | 416 type=str, |
279 help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;", | 417 help="Fields in adata.obs to merge, comma separated. \ |
418 You can have more than one set of fields, \ | |
419 separated by semi-colon ;", | |
280 ) | 420 ) |
281 parser.add_argument( | 421 parser.add_argument( |
282 "--groupby", | 422 "--groupby", |
283 type=str, | 423 type=str, |
284 required=True, | 424 required=True, |
326 "--min_counts", | 466 "--min_counts", |
327 type=int, | 467 type=int, |
328 help="Minimum count threshold for filtering by expression", | 468 help="Minimum count threshold for filtering by expression", |
329 ) | 469 ) |
330 parser.add_argument( | 470 parser.add_argument( |
471 "--min_counts_per_sample_marking", | |
472 type=int, | |
473 default=20, | |
474 help="Minimum count threshold per sample for \ | |
475 marking genes to be ignored after DE", | |
476 ) | |
477 parser.add_argument( | |
331 "--min_total_counts", | 478 "--min_total_counts", |
332 type=int, | 479 type=int, |
333 help="Minimum total count threshold for filtering by expression", | 480 help="Minimum total count threshold for filtering by expression", |
334 ) | 481 ) |
335 parser.add_argument( | 482 parser.add_argument( |
336 "--anndata_output_path", | 483 "--anndata_output_path", |
337 type=str, | 484 type=str, |
338 help="Path to save the filtered AnnData object or pseudobulk data", | 485 help="Path to save the filtered AnnData object or pseudobulk data", |
339 ) | 486 ) |
340 parser.add_argument( | 487 parser.add_argument( |
341 "--filter_expr", action="store_true", help="Enable filtering by expression" | 488 "--filter_expr", |
489 action="store_true", | |
490 help="Enable filtering by expression", | |
342 ) | 491 ) |
343 parser.add_argument( | 492 parser.add_argument( |
344 "--factor_fields", | 493 "--factor_fields", |
345 type=str, | 494 type=str, |
346 help="Comma separated list of fields for the factors", | 495 help="Comma separated list of fields for the factors", |
356 type=int, | 505 type=int, |
357 default=[10, 10], | 506 default=[10, 10], |
358 nargs=2, | 507 nargs=2, |
359 help="Size of the samples plot as a tuple (two arguments)", | 508 help="Size of the samples plot as a tuple (two arguments)", |
360 ) | 509 ) |
361 parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2) | 510 parser.add_argument( |
511 "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2 | |
512 ) | |
362 | 513 |
363 # Parse the command line arguments | 514 # Parse the command line arguments |
364 args = parser.parse_args() | 515 args = parser.parse_args() |
365 | 516 |
366 # Call the main function | 517 # Call the main function |