Mercurial > repos > ebi-gxa > decoupler_pathway_inference
comparison decoupler_pseudobulk.py @ 0:77d680b36e23 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 1034a450c97dcbb77871050cf0c6d3da90dac823
| author | ebi-gxa |
|---|---|
| date | Fri, 15 Mar 2024 12:17:49 +0000 |
| parents | |
| children | c6787c2aee46 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:77d680b36e23 |
|---|---|
| 1 import argparse | |
| 2 | |
| 3 import anndata | |
| 4 import decoupler | |
| 5 import pandas as pd | |
| 6 | |
| 7 | |
def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    Thin wrapper that hands every argument to ``decoupler.get_pseudobulk``.

    Parameters
    ----------
    adata
        Object passed to decoupler as its first positional argument
        (the script loads it with ``anndata.read_h5ad``).
    sample_col, groups_col
        Forwarded under the same keyword names; the command line fills them
        from ``--sample_key`` and ``--groupby`` respectively.
    layer, mode, min_cells, min_counts, use_raw
        Forwarded under the same keyword names without modification.

    Returns
    -------
    Whatever ``decoupler.get_pseudobulk`` returns.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """
    # Collect the keyword options once so the call site stays short.
    options = {
        "sample_col": sample_col,
        "groups_col": groups_col,
        "layer": layer,
        "mode": mode,
        "use_raw": use_raw,
        "min_cells": min_cells,
        "min_counts": min_counts,
    }
    return decoupler.get_pseudobulk(adata, **options)
| 35 | |
| 36 | |
def prepend_c_to_index(index_value):
    """
    Return ``index_value`` prefixed with "C" when its first character is a digit.

    Falsy values (for example the empty string) and values that do not start
    with a digit are returned unchanged.
    """
    starts_with_digit = bool(index_value) and index_value[0].isdigit()
    return "C" + index_value if starts_with_digit else index_value
| 41 | |
| 42 | |
| 43 # write results for loading into DESeq2 | |
def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
    """
    Write the pseudobulk object as three tab-separated files for DESeq2.

    Files written (all prefixed with ``output_dir``):

    * ``col_metadata.tsv``  - ``pdata.obs`` (optionally restricted to
      ``factor_fields``) with a sanitised index.
    * ``gene_metadata.tsv`` - ``pdata.var`` as is.
    * ``counts_matrix.tsv`` - the transposed matrix (rows follow
      ``pdata.var.index``, columns follow the sanitised obs index).

    Parameters
    ----------
    pdata
        Object exposing ``obs``, ``var``, ``X`` and ``layers``.
    layer
        Key in ``pdata.layers`` to export; ``pdata.X`` is used when None.
    output_dir
        Destination prefix. A trailing "/" is appended when the value is
        non-empty and does not already end with one.
    factor_fields
        When truthy, only these ``obs`` columns (in this order) are written
        to the column metadata file.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> write_DESeq2_inputs(pseudobulk)
    """
    # Only a non-empty prefix that lacks the trailing slash gets one appended.
    prefix = output_dir
    if prefix != "" and not prefix.endswith("/"):
        prefix = prefix + "/"

    # rename() hands back a new frame, so pdata.obs itself is left alone.
    # Index values starting with a digit receive a leading "C".
    sample_table = pdata.obs.rename(index=prepend_c_to_index)
    # Dashes are read as points in R column names; spaces are replaced as well.
    for unwanted in ("-", " "):
        sample_table.index = sample_table.index.str.replace(unwanted, "_")
    sample_names = sample_table.index

    # Column metadata: the index plus either the requested factors or everything.
    metadata_out = sample_table[factor_fields] if factor_fields else sample_table
    metadata_out.to_csv(f"{prefix}col_metadata.tsv", sep="\t", index=True)

    # Gene metadata.
    pdata.var.to_csv(f"{prefix}gene_metadata.tsv", sep="\t", index=True)

    # Counts matrix, transposed so that the sanitised sample names are columns.
    matrix = pdata.X if layer is None else pdata.layers[layer]
    counts = pd.DataFrame(matrix.T, index=pdata.var.index, columns=sample_names)
    counts.to_csv(f"{prefix}counts_matrix.tsv", sep="\t", index_label="")
| 79 | |
| 80 | |
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    Draw the figure produced by ``decoupler.plot_psbulk_samples``.

    When ``save_path`` is truthy the figure is saved as
    ``<save_path>/pseudobulk_samples.png``; otherwise ``fig.show()`` is called.
    Nothing is returned.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
    """
    figure = decoupler.plot_psbulk_samples(
        pseudobulk_data,
        groupby=groupby,
        figsize=figsize,
        return_fig=True,
    )
    if not save_path:
        # No destination given: display instead of writing to disk.
        figure.show()
        return
    figure.savefig(f"{save_path}/pseudobulk_samples.png")
| 101 | |
| 102 | |
def plot_filter_by_expr(
    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
):
    """
    Draw the figure produced by ``decoupler.plot_filter_by_expr``.

    Parameters
    ----------
    pseudobulk_data
        Object passed to decoupler as its first positional argument.
    group
        Forwarded as decoupler's ``group`` keyword.
    min_count, min_total_count
        Thresholds forwarded to decoupler. A value of None means "not
        specified": the keyword is omitted from the call so that decoupler
        falls back to its own default instead of receiving an explicit None.
    save_path
        When truthy the figure is saved as ``<save_path>/filter_by_expr.png``;
        otherwise ``fig.show()`` is called.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
    """
    # Only forward thresholds the caller actually set; an explicit None would
    # override decoupler's own defaults (see decoupler API docs).
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    fig = decoupler.plot_filter_by_expr(
        pseudobulk_data,
        group=group,
        return_fig=True,
        **thresholds,
    )
    if save_path:
        fig.savefig(f"{save_path}/filter_by_expr.png")
    else:
        fig.show()
| 124 | |
| 125 | |
def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    Return a copy of ``pdata`` restricted to the genes selected by
    ``decoupler.filter_by_expr``.

    Parameters
    ----------
    pdata
        Object passed to decoupler and then column-sliced with its result.
    min_count, min_total_count
        Thresholds forwarded to decoupler. A value of None means "not
        specified": the keyword is omitted from the call so that decoupler
        falls back to its own default instead of receiving an explicit None.

    Returns
    -------
    ``pdata[:, genes].copy()`` where ``genes`` is decoupler's selection.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
    """
    # Only forward thresholds the caller actually set; an explicit None would
    # override decoupler's own defaults (see decoupler API docs).
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    genes = decoupler.filter_by_expr(pdata, **thresholds)
    return pdata[:, genes].copy()
| 138 | |
| 139 | |
def check_fields(fields, adata, obs=True, context=None):
    """
    Verify that every entry of ``fields`` is a column of ``adata.obs``
    (or of ``adata.var`` when ``obs`` is False).

    Parameters
    ----------
    fields : iterable of str
        Column names that must be present.
    adata
        Object exposing ``obs`` and ``var`` data frames.
    obs : bool
        Check ``adata.obs`` when True (default), ``adata.var`` otherwise.
    context : str, optional
        Short description of where ``fields`` came from (for example the
        command line option); included in the error message when given.

    Raises
    ------
    ValueError
        If at least one field is missing. The message lists the requested
        fields, the missing ones and the available columns in their
        original order.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """
    legend = f", passed in {context}," if context else ""
    attribute = "obs" if obs else "var"
    # Keep the frame's column order so the error message is deterministic.
    available = list(getattr(adata, attribute).columns)
    known = set(available)
    missing = [field for field in fields if field not in known]
    if missing:
        raise ValueError(
            f"Some of the following fields{legend} are not present in "
            f"adata.{attribute}: {fields}. Missing fields: {missing}. "
            f"Possible fields are: {available}"
        )
| 160 | |
| 161 | |
def main(args):
    """
    Run the pseudobulk workflow described by the parsed command line.

    Steps, in order: load the AnnData file, optionally merge groups of obs
    columns, validate the requested obs columns, compute the pseudobulk,
    produce the two diagnostic plots, optionally filter genes by expression,
    optionally write the pseudobulk as h5ad, and write the DESeq2 inputs.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments as produced by this script's argument parser.
    """
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge.
    # Groups are separated by ":" and the fields inside a group by ",".
    if args.adata_obs_fields_to_merge:
        for group in args.adata_obs_fields_to_merge.split(":"):
            fields = group.split(",")
            check_fields(fields, adata, context="--adata_obs_fields_to_merge")
            adata = merge_adata_obs_fields(fields, adata)

    # Name the offending option in the error message of each check.
    check_fields([args.groupby], adata, context="--groupby")
    check_fields([args.sample_key], adata, context="--sample_key")

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata, context="--factor_fields")

    print(f"Using mode: {args.mode}")
    # Perform pseudobulk analysis
    # NOTE(review): --min_counts has no parser default, so None is forwarded
    # here when the option is omitted - confirm decoupler accepts that.
    pseudobulk_data = get_pseudobulk(
        adata,
        sample_col=args.sample_key,
        groups_col=args.groupby,
        layer=args.layer,
        mode=args.mode,
        use_raw=args.use_raw,
        min_cells=args.min_cells,
        min_counts=args.min_counts,
    )

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    # Plot the effect of the expression filter.
    # NOTE(review): args.plot_filtering_figsize is parsed but not used here.
    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    if args.filter_expr:
        pseudobulk_data = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")

    write_DESeq2_inputs(
        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
    )
| 227 | |
| 228 | |
def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Add an obs column that concatenates several existing obs columns.

    The new column is named by joining the field names with "_" and holds,
    per row, the string values of those fields joined with "_". The column
    is always rebuilt from the source fields, so calling the function again
    (or loading a file that already contains the merged column) yields the
    same values instead of appending to the existing ones.

    Parameters
    ----------
    obs_fields_to_merge : list of str
        Names of the columns in adata.obs to merge, in order.
    adata : anndata.AnnData
        The AnnData object; its ``obs`` is modified in place.

    Returns
    -------
    anndata.AnnData
        The same object that was passed in, with the merged column added.
        With fewer than two fields there is nothing to merge and ``obs`` is
        left untouched.

    Raises
    ------
    ValueError
        If any field is not a column of adata.obs. All fields are validated
        before ``obs`` is modified.

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
           'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
          dtype='object')
    """
    fields = list(obs_fields_to_merge)

    # Validate everything up front so a bad name cannot leave a partial column.
    for field in fields:
        if field not in adata.obs.columns:
            raise ValueError(f"The '{field}' column is not present in adata.obs.")

    # A single field would merge onto itself; leave the column as it is.
    if len(fields) < 2:
        return adata

    # Build the merged values from scratch rather than from any existing
    # column that happens to carry the merged name.
    merged = adata.obs[fields[0]].astype(str)
    for field in fields[1:]:
        merged = merged + "_" + adata.obs[field].astype(str)

    adata.obs["_".join(fields)] = merged
    return adata
| 265 | |
| 266 | |
if __name__ == "__main__":
    # Command line entry point: build the parser, parse, and run main().
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Input file (positional)
    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
    # main() splits this value on ":" into groups and each group on ",".
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by colon :",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    parser.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )
    # add argument for layer
    parser.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    # add argument for mode
    parser.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    # add boolean argument for use_raw
    parser.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    # add argument for min_cells
    parser.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )
    parser.add_argument(
        "--save_path", type=str, help="Path to save the plot (optional)"
    )
    # No default: the value is None when the option is omitted.
    parser.add_argument(
        "--min_counts",
        type=int,
        help="Minimum count threshold for filtering by expression",
    )
    # No default: the value is None when the option is omitted.
    parser.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    parser.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    parser.add_argument(
        "--filter_expr", action="store_true", help="Enable filtering by expression"
    )
    parser.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    parser.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    parser.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    # NOTE(review): accepted by the parser but not read anywhere in main().
    parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function
    main(args)
