comparison decoupler_pseudobulk.py @ 0:77d680b36e23 draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 1034a450c97dcbb77871050cf0c6d3da90dac823
author ebi-gxa
date Fri, 15 Mar 2024 12:17:49 +0000
parents
children c6787c2aee46
comparison
equal deleted inserted replaced
-1:000000000000 0:77d680b36e23
1 import argparse
2
3 import anndata
4 import decoupler
5 import pandas as pd
6
7
def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    Build pseudobulk profiles by delegating to ``decoupler.get_pseudobulk``.

    Every argument is forwarded unchanged and the result of the decoupler
    call is returned as-is.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object to aggregate.
    sample_col : str
        The column in adata.obs that defines the samples.
    groups_col : str
        The column in adata.obs that defines the groups.
    layer : str, optional
        The name of the layer of the AnnData object to use.
    mode : str
        The mode for the decoupler pseudobulk analysis.
    min_cells : int
        Minimum number of cells, forwarded to decoupler.
    min_counts : int
        Minimum counts, forwarded to decoupler.
    use_raw : bool
        Whether to use the raw part of the AnnData object.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """
    forwarded_options = {
        "sample_col": sample_col,
        "groups_col": groups_col,
        "layer": layer,
        "mode": mode,
        "use_raw": use_raw,
        "min_cells": min_cells,
        "min_counts": min_counts,
    }
    return decoupler.get_pseudobulk(adata, **forwarded_options)
35
36
def prepend_c_to_index(index_value):
    """
    Return ``index_value`` with a leading "C" when it starts with a digit.

    Empty (falsy) values and values whose first character is not a digit
    are returned unchanged.
    """
    # Nothing to inspect on an empty label.
    if not index_value:
        return index_value
    first_char = index_value[0]
    if not first_char.isdigit():
        return index_value
    return "C" + index_value
41
42
43 # write results for loading into DESeq2
def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
    """
    Write the pseudobulk object as three TSV files for loading into DESeq2.

    Files written inside ``output_dir`` (or the current directory when it is
    empty): ``col_metadata.tsv`` (from pdata.obs), ``gene_metadata.tsv``
    (from pdata.var) and ``counts_matrix.tsv`` (genes as rows, samples as
    columns).

    Parameters
    ----------
    pdata : anndata.AnnData
        Pseudobulk object providing ``obs``, ``var``, ``X`` and ``layers``.
    layer : str, optional
        Name of the layer to write as counts; ``pdata.X`` is used when None.
    output_dir : str
        Directory for the output files; created if it does not exist.
    factor_fields : list of str, optional
        When given, only these obs columns (in this order) are written to
        the column metadata file, alongside the index.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> write_DESeq2_inputs(pseudobulk)
    """
    import os

    # Make sure the target directory exists before any file is written.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    obs_for_deseq = pdata.obs.copy()
    # Sample names starting with a digit get a "C" prefix instead.
    obs_for_deseq.rename(index=prepend_c_to_index, inplace=True)
    # Avoid dashes (read as points in R colnames) and spaces. regex=False makes
    # the replacement literal regardless of the pandas version's default.
    obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_", regex=False)
    obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_", regex=False)

    # Column (sample) metadata: the index plus either the requested factor
    # columns, in the requested order, or every obs column.
    col_metadata = obs_for_deseq[factor_fields] if factor_fields else obs_for_deseq
    col_metadata.to_csv(
        os.path.join(output_dir, "col_metadata.tsv"), sep="\t", index=True
    )

    # Gene metadata.
    pdata.var.to_csv(
        os.path.join(output_dir, "gene_metadata.tsv"), sep="\t", index=True
    )

    # Counts matrix, transposed so that genes are rows and samples are columns.
    matrix = pdata.X if layer is None else pdata.layers[layer]
    if hasattr(matrix, "toarray"):
        # Densify sparse input so it can be labelled with row/column names.
        matrix = matrix.toarray()
    counts = pd.DataFrame(
        matrix.T, index=pdata.var.index, columns=obs_for_deseq.index
    )
    counts.to_csv(
        os.path.join(output_dir, "counts_matrix.tsv"), sep="\t", index_label=""
    )
79
80
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    Draw the decoupler pseudobulk samples plot and save or show it.

    The figure comes from ``decoupler.plot_psbulk_samples``. When
    ``save_path`` is truthy the figure is written to
    ``<save_path>/pseudobulk_samples.png``; otherwise it is shown.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
    """
    samples_figure = decoupler.plot_psbulk_samples(
        pseudobulk_data,
        groupby=groupby,
        figsize=figsize,
        return_fig=True,
    )
    # No destination given: display instead of saving.
    if not save_path:
        samples_figure.show()
        return
    samples_figure.savefig(f"{save_path}/pseudobulk_samples.png")
101
102
def plot_filter_by_expr(
    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
):
    """
    Draw the decoupler filter-by-expression plot and save or show it.

    ``group``, ``min_count`` and ``min_total_count`` are forwarded unchanged
    to ``decoupler.plot_filter_by_expr``. When ``save_path`` is truthy the
    figure is written to ``<save_path>/filter_by_expr.png``; otherwise it is
    shown.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
    """
    plot_options = {
        "group": group,
        "min_count": min_count,
        "min_total_count": min_total_count,
        "return_fig": True,
    }
    filtering_figure = decoupler.plot_filter_by_expr(pseudobulk_data, **plot_options)
    # No destination given: display instead of saving.
    if not save_path:
        filtering_figure.show()
        return
    filtering_figure.savefig(f"{save_path}/filter_by_expr.png")
124
125
def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    Return a copy of ``pdata`` restricted to the genes decoupler keeps.

    The gene selection comes from ``decoupler.filter_by_expr``, called with
    the given ``min_count`` and ``min_total_count``; ``pdata`` is then
    subset on its variable axis with that selection and copied.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
    """
    kept_genes = decoupler.filter_by_expr(
        pdata,
        min_count=min_count,
        min_total_count=min_total_count,
    )
    gene_subset = pdata[:, kept_genes]
    return gene_subset.copy()
138
139
def check_fields(fields, adata, obs=True, context=None):
    """
    Check that every name in ``fields`` is a column of adata.obs or adata.var.

    Parameters
    ----------
    fields : iterable of str
        Column names that must be present.
    adata : anndata.AnnData
        Object whose ``obs`` (or ``var``) columns are inspected.
    obs : bool
        Check against ``adata.obs`` when True, ``adata.var`` otherwise.
    context : str, optional
        Where the fields came from (e.g. a command line option); included
        in the error message when given.

    Raises
    ------
    ValueError
        If at least one field is missing. The message lists the missing
        fields and the available columns, in their original order.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """
    if obs:
        table_name = "adata.obs"
        available = list(adata.obs.columns)
    else:
        table_name = "adata.var"
        available = list(adata.var.columns)

    known = set(available)
    # Report only the offending names, keeping the order they were given in.
    missing = [field for field in fields if field not in known]
    if not missing:
        return

    origin = f" (passed in {context})" if context else ""
    raise ValueError(
        f"The following fields{origin} are not present in {table_name}: "
        f"{missing}. Possible fields are: {available}"
    )
160
161
def main(args):
    """
    Run the pseudobulk workflow described by the parsed command line ``args``.

    Steps: read the AnnData file, optionally merge obs fields, validate the
    requested obs columns, compute the pseudobulk, draw the samples and
    filter-by-expression plots, optionally filter genes by expression,
    optionally write the resulting AnnData, and write the DESeq2 inputs.

    Raises
    ------
    ValueError
        If any requested field is not present in adata.obs; the message
        names the command line option the field came from.
    """
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
    if args.adata_obs_fields_to_merge:
        # groups are separated by ":" and the fields within a group by ","
        for group in args.adata_obs_fields_to_merge.split(":"):
            fields = group.split(",")
            check_fields(fields, adata, context="--adata_obs_fields_to_merge")
            adata = merge_adata_obs_fields(fields, adata)

    # Validated after merging, so either option may name a merged column.
    check_fields(
        [args.groupby, args.sample_key], adata, context="--groupby/--sample_key"
    )

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata, context="--factor_fields")

    print(f"Using mode: {args.mode}")
    # Perform pseudobulk analysis
    # NOTE(review): args.min_counts is used here and again below as the
    # min_count of the expression filter; it has no default on the command
    # line, so it is None unless given - confirm this is intended.
    pseudobulk_data = get_pseudobulk(
        adata,
        sample_col=args.sample_key,
        groups_col=args.groupby,
        layer=args.layer,
        mode=args.mode,
        use_raw=args.use_raw,
        min_cells=args.min_cells,
        min_counts=args.min_counts,
    )

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    # NOTE(review): args.plot_filtering_figsize is parsed but not used here.
    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    if args.filter_expr:
        pseudobulk_data = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")

    write_DESeq2_inputs(
        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
    )
227
228
def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Add an adata.obs column that joins several obs fields with "_".

    The new column is named after the joined field names (e.g.
    ``bulk_labels_louvain``) and holds, per row, the string values of the
    fields joined with "_". The column is rebuilt from scratch on every
    call, so calling this twice gives the same result. With fewer than two
    fields there is nothing to merge and ``adata`` is returned untouched.

    Parameters
    ----------
    obs_fields_to_merge : list of str
        Names of the adata.obs columns to merge, in order.
    adata : anndata.AnnData
        The AnnData object; its ``obs`` is modified in place.

    Returns
    -------
    anndata.AnnData
        The same AnnData object, with the merged column added.

    Raises
    ------
    ValueError
        If any of the fields is not a column of adata.obs. All fields are
        checked before adata.obs is modified.

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
    'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
    dtype='object')
    """
    fields = list(obs_fields_to_merge)

    # Validate everything up front so a bad field never leaves a half-built
    # column behind.
    for field in fields:
        if field not in adata.obs.columns:
            raise ValueError(f"The '{field}' column is not present in adata.obs.")

    # A single field would be "merged" onto itself; leave it alone.
    if len(fields) < 2:
        return adata

    merged = adata.obs[fields[0]].astype(str)
    for field in fields[1:]:
        merged = merged + "_" + adata.obs[field].astype(str)

    # Plain assignment overwrites a pre-existing column instead of appending.
    adata.obs["_".join(fields)] = merged
    return adata
265
266
if __name__ == "__main__":
    # Command line entry point: declare the options, parse them and hand the
    # resulting namespace to main().
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Add arguments
    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        # main() splits groups on ":" and the fields within a group on ","
        help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by colon :",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    parser.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )
    # add argument for layer
    parser.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    # add argument for mode
    parser.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    # add boolean argument for use_raw
    parser.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    # add argument for min_cells
    parser.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )
    parser.add_argument(
        "--save_path", type=str, help="Path to save the plot (optional)"
    )
    parser.add_argument(
        "--min_counts",
        type=int,
        help="Minimum count threshold for filtering by expression",
    )
    parser.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    parser.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    parser.add_argument(
        "--filter_expr", action="store_true", help="Enable filtering by expression"
    )
    parser.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    parser.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    parser.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    parser.add_argument(
        "--plot_filtering_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the filtering plot as a tuple (two arguments)",
    )

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function
    main(args)