comparison of decoupler_pseudobulk.py @ 3:c6787c2aee46 (draft) against 2:82b7cd3e1bbd

commit message: planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author: ebi-gxa
date: Mon, 15 Jul 2024 10:56:37 +0000
parents: 77d680b36e23
children: 6c30272fb587
    if index_value and index_value[0].isdigit():
        return "C" + index_value
    return index_value

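The fragment above prefixes index values that start with a digit with a "C", presumably so that identifiers handed to R/DESeq2 downstream do not begin with a digit. A minimal standalone sketch of the same idea (the helper name and labels here are made up):

def _prepend_c(index_value):
    # Mirror of the fragment above: prefix values that start with a digit.
    if index_value and index_value[0].isdigit():
        return "C" + index_value
    return index_value

print([_prepend_c(v) for v in ["0", "12_Bcells", "CD4_T_cells"]])
# ['C0', 'C12_Bcells', 'CD4_T_cells']
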
def genes_to_ignore_per_contrast_field(
    count_matrix_df,
    samples_metadata,
    sample_metadata_col_contrasts,
    min_counts_per_sample=5,
    use_cpms=False,
):
    """
    Calculate the genes to ignore for each value of a contrast field
    (e.g. bulk_labels, louvain).

    For every group defined by the contrast field, the count matrix is
    subset to that group's samples, and genes whose summed counts across
    those samples fall below the threshold (number of samples times
    min_counts_per_sample, or 1 CPM when use_cpms is True) are marked to
    be ignored for that group.

    >>> import pandas as pd
    >>> samples_metadata = pd.DataFrame(
    ...     {'sample': ['S1', 'S2', 'S3', 'S4', 'S5', 'S6'],
    ...      'contrast_field': ['A', 'A', 'A', 'B', 'B', 'B']})
    >>> count_matrix_df = pd.DataFrame(
    ...     {'S1': [30, 1, 40, 50, 30],
    ...      'S2': [40, 2, 60, 50, 80],
    ...      'S3': [80, 1, 60, 50, 50],
    ...      'S4': [1, 50, 50, 50, 2],
    ...      'S5': [3, 40, 40, 40, 2],
    ...      'S6': [0, 50, 50, 50, 1]})
    >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5']
    >>> df = genes_to_ignore_per_contrast_field(count_matrix_df,
    ...     samples_metadata, min_counts_per_sample=5,
    ...     sample_metadata_col_contrasts='contrast_field')
    >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0]
    'Gene2'
    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0]
    'Gene1'
    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1]
    'Gene5'
    """

    # Lists to accumulate (contrast_field value, gene) pairs to ignore
    contrast_fields = []
    genes_to_ignore = []

    # Iterate over the values of the contrast field
    for contrast_field in samples_metadata[
        sample_metadata_col_contrasts
    ].unique():
        # Get the count matrix for the samples of the current group
        count_matrix_field = count_matrix_df.loc[
            :,
            (
                samples_metadata[sample_metadata_col_contrasts]
                == contrast_field
            ).tolist(),
        ]

        # We derive min_counts from the number of samples with that
        # contrast_field value
        min_counts = count_matrix_field.shape[1] * min_counts_per_sample

        if use_cpms:
            # Convert counts to counts per million (CPM): scale each
            # sample (column) to a total of one million counts
            count_matrix_field = (
                count_matrix_field.div(
                    count_matrix_field.sum(axis=0), axis=1
                )
                * 1e6
            )
            min_counts = 1  # use 1 CPM

        # Total counts per gene across the samples of the current group
        # (this produces a vector of counts per gene)
        total_counts_per_gene = count_matrix_field.sum(axis=1)

        # Identify genes with a count below the specified threshold
        genes = total_counts_per_gene[
            total_counts_per_gene < min_counts
        ].index.tolist()
        if len(genes) > 0:
            for gene in genes:
                genes_to_ignore.append(gene)
                contrast_fields.append(contrast_field)

    # Transform the collected pairs into a DataFrame
    genes_to_ignore_df = pd.DataFrame(
        {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore}
    )
    return genes_to_ignore_df

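The doctest above only exercises the raw-count path. Below is a small, pandas-only sketch of the per-sample CPM normalisation that the use_cpms=True branch is meant to apply (toy counts, made-up sample and gene names); with CPMs the threshold becomes 1 CPM summed over the group's samples:

import pandas as pd

counts = pd.DataFrame(
    {"S1": [1, 2_000_000, 500_000], "S2": [0, 3_000_000, 800_000]},
    index=["GeneA", "GeneB", "GeneC"],
)
# Per-sample CPM: scale every column (sample) to a total of one million.
cpm = counts.div(counts.sum(axis=0), axis=1) * 1e6
# Genes whose summed CPM across the group's samples stays below 1 would be
# marked to be ignored for this contrast field value.
print(cpm.sum(axis=1)[cpm.sum(axis=1) < 1].index.tolist())
# ['GeneA']
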
# write results for loading into DESeq2
def write_DESeq2_inputs(
    pdata,
    layer=None,
    output_dir="",
    factor_fields=None,
    min_counts_per_sample_marking=20,
):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    # [...]
    obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_")
    col_metadata_file = f"{output_dir}col_metadata.tsv"
    # write obs to a col_metadata file
    if factor_fields:
        # only output the index plus the columns in factor_fields,
        # in that order
        obs_for_deseq[factor_fields].to_csv(
            col_metadata_file, sep="\t", index=True
        )
    else:
        obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True)
    # write var to a gene_metadata file
    pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True)
    # write the counts matrix of a specified layer to file
    if layer is None:
        # write the X numpy matrix transposed to file
        df = pd.DataFrame(
            pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index
        )
    else:
        df = pd.DataFrame(
            pdata.layers[layer].T,
            index=pdata.var.index,
            columns=obs_for_deseq.index,
        )
    df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="")

    if factor_fields:
        df_genes_ignore = genes_to_ignore_per_contrast_field(
            count_matrix_df=df,
            samples_metadata=obs_for_deseq,
            sample_metadata_col_contrasts=factor_fields[0],
            min_counts_per_sample=min_counts_per_sample_marking,
        )
        df_genes_ignore.to_csv(
            f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t"
        )

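A quick sanity check one might run on the files written by write_DESeq2_inputs, assuming it was called with the default output_dir so the TSVs sit in the current directory (the genes_to_ignore file is only written when factor_fields was given):

import pandas as pd

counts = pd.read_csv("counts_matrix.tsv", sep="\t", index_col=0)
coldata = pd.read_csv("col_metadata.tsv", sep="\t", index_col=0)
ignored = pd.read_csv("genes_to_ignore_per_contrast_field.tsv", sep="\t")

# Columns of the counts matrix should line up with the sample metadata rows,
# which is what DESeq2 expects.
assert list(counts.columns) == list(coldata.index)
print(counts.shape, coldata.shape, ignored.shape)
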
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    # [...]
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk,
    ...                         groupby=["bulk_labels", "louvain"],
    ...                         figsize=(10, 10))
    """
    fig = decoupler.plot_psbulk_samples(
        pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True
    )
    if save_path:
        # [...]
    else:
        fig.show()


def plot_filter_by_expr(
    pseudobulk_data,
    group,
    min_count=None,
    min_total_count=None,
    save_path=None,
):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels",
    ...                     min_count=10, min_total_count=200)
    """
    fig = decoupler.plot_filter_by_expr(
        pseudobulk_data,
        group=group,
        min_count=min_count,
127 """ 251 """
128 >>> import scanpy as sc 252 >>> import scanpy as sc
129 >>> adata = sc.datasets.pbmc68k_reduced() 253 >>> adata = sc.datasets.pbmc68k_reduced()
130 >>> adata.X = abs(adata.X).astype(int) 254 >>> adata.X = abs(adata.X).astype(int)
131 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") 255 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
132 >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200) 256 >>> pdata_filt = filter_by_expr(pseudobulk,
257 ... min_count=10, min_total_count=200)
133 """ 258 """
134 genes = decoupler.filter_by_expr( 259 genes = decoupler.filter_by_expr(
135 pdata, min_count=min_count, min_total_count=min_total_count 260 pdata, min_count=min_count, min_total_count=min_total_count
136 ) 261 )
137 return pdata[:, genes].copy() 262 return pdata[:, genes].copy()
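The wrapper above subsets the AnnData object to the genes decoupler keeps. A toy sketch of that column subsetting (using a stand-in boolean mask; gene names are invented, and AnnData column indexing works the same whether the result is names or a mask):

import anndata as ad
import numpy as np

pdata = ad.AnnData(np.array([[5, 0, 12], [7, 1, 3]], dtype=np.float32))
pdata.var_names = ["GeneA", "GeneB", "GeneC"]

keep = np.array([True, False, True])  # stand-in for the decoupler result
pdata_filt = pdata[:, keep].copy()    # .copy() detaches the filtered view
print(pdata_filt.var_names.tolist())  # ['GeneA', 'GeneC']
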
# [...]
    if context:
        legend = f", passed in {context},"
    if obs:
        if not set(fields).issubset(set(adata.obs.columns)):
            raise ValueError(
                f"Some of the following fields{legend} are not present "
                f"in adata.obs: {fields}. "
                f"Possible fields are: {list(set(adata.obs.columns))}"
            )
    else:
        if not set(fields).issubset(set(adata.var.columns)):
            raise ValueError(
                f"Some of the following fields{legend} are not present "
                f"in adata.var: {fields}. "
                f"Possible fields are: {list(set(adata.var.columns))}"
            )

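The check above boils down to a set-inclusion test on the obs or var columns. A minimal illustration with a throwaway DataFrame (column and field names are invented):

import pandas as pd

obs = pd.DataFrame({"bulk_labels": ["B", "T"], "louvain": ["0", "1"]})
requested = ["bulk_labels", "batch"]  # "batch" is deliberately missing

if not set(requested).issubset(set(obs.columns)):
    missing = set(requested) - set(obs.columns)
    print(f"Not present in obs: {sorted(missing)}; "
          f"possible fields are: {list(obs.columns)}")
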
def main(args):
    # Load AnnData object from file
    # [...]

    pseudobulk_data = filtered_adata

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(
            args.anndata_output_path, compression="gzip"
        )

    write_DESeq2_inputs(
        pseudobulk_data,
        output_dir=args.deseq2_output_path,
        factor_fields=factor_fields,
        min_counts_per_sample_marking=args.min_counts_per_sample_marking,
    )

def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    # [...]
          dtype='object')
    """
    field_name = "_".join(obs_fields_to_merge)
    for field in obs_fields_to_merge:
        if field not in adata.obs.columns:
            raise ValueError(
                f"The '{field}' column is not present in adata.obs."
            )
        if field_name not in adata.obs.columns:
            adata.obs[field_name] = adata.obs[field].astype(str)
        else:
            adata.obs[field_name] = (
                adata.obs[field_name] + "_" + adata.obs[field].astype(str)
# [...]
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Add arguments
    parser.add_argument(
        "adata_file", type=str, help="Path to the AnnData file"
    )
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated. "
        "You can have more than one set of fields, "
        "separated by semi-colon ;",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
326 "--min_counts", 466 "--min_counts",
327 type=int, 467 type=int,
328 help="Minimum count threshold for filtering by expression", 468 help="Minimum count threshold for filtering by expression",
329 ) 469 )
330 parser.add_argument( 470 parser.add_argument(
471 "--min_counts_per_sample_marking",
472 type=int,
473 default=20,
474 help="Minimum count threshold per sample for \
475 marking genes to be ignored after DE",
476 )
477 parser.add_argument(
331 "--min_total_counts", 478 "--min_total_counts",
332 type=int, 479 type=int,
333 help="Minimum total count threshold for filtering by expression", 480 help="Minimum total count threshold for filtering by expression",
334 ) 481 )
335 parser.add_argument( 482 parser.add_argument(
336 "--anndata_output_path", 483 "--anndata_output_path",
337 type=str, 484 type=str,
338 help="Path to save the filtered AnnData object or pseudobulk data", 485 help="Path to save the filtered AnnData object or pseudobulk data",
339 ) 486 )
340 parser.add_argument( 487 parser.add_argument(
341 "--filter_expr", action="store_true", help="Enable filtering by expression" 488 "--filter_expr",
489 action="store_true",
490 help="Enable filtering by expression",
342 ) 491 )
343 parser.add_argument( 492 parser.add_argument(
344 "--factor_fields", 493 "--factor_fields",
345 type=str, 494 type=str,
346 help="Comma separated list of fields for the factors", 495 help="Comma separated list of fields for the factors",
    # [...]
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    parser.add_argument(
        "--plot_filtering_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the filtering plot as a tuple (two arguments)",
    )

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function