Mercurial > repos > ebi-gxa > decoupler_pseudobulk
comparison decoupler_pseudobulk.py @ 0:59a7f3f83aec draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 20f4a739092bd05106d5de170523ad61d66e41fc
| author | ebi-gxa |
|---|---|
| date | Sun, 24 Sep 2023 08:44:24 +0000 |
| parents | |
| children | 046d8ff974ff |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:59a7f3f83aec |
|---|---|
| 1 import argparse | |
| 2 | |
| 3 import anndata | |
| 4 import decoupler | |
| 5 import pandas as pd | |
| 6 | |
| 7 | |
def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    Aggregate cells into pseudobulk profiles with ``decoupler.get_pseudobulk``.

    Parameters
    ----------
    adata : anndata.AnnData
        Single-cell data to aggregate.
    sample_col : str
        Column of ``adata.obs`` identifying the samples.
    groups_col : str
        Column of ``adata.obs`` identifying the groups.
    layer : str, optional
        Layer to aggregate instead of ``X``.
    mode : str
        Aggregation mode forwarded to decoupler (the CLI offers "sum",
        "mean" and "median").
    min_cells : int or None
        Forwarded to decoupler; ``None`` means the argument is not passed,
        so decoupler's own default applies.
    min_counts : int or None
        Forwarded to decoupler; ``None`` means the argument is not passed,
        so decoupler's own default applies.
    use_raw : bool
        Whether to use the raw part of the AnnData object.

    Returns
    -------
    The pseudobulk object returned by ``decoupler.get_pseudobulk``.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """
    # Only forward thresholds that are actually set. The CLI leaves
    # --min_counts as None when omitted, and an explicit None would replace
    # decoupler's numeric default (presumably failing in its comparisons).
    thresholds = {
        name: value
        for name, value in (("min_cells", min_cells), ("min_counts", min_counts))
        if value is not None
    }
    return decoupler.get_pseudobulk(
        adata,
        sample_col=sample_col,
        groups_col=groups_col,
        layer=layer,
        mode=mode,
        use_raw=use_raw,
        **thresholds,
    )
| 35 | |
| 36 | |
def prepend_c_to_index(index_value):
    """
    Prefix an index label with "C" when its first character is a digit.

    Presumably used so that sample labels survive as R column names when
    the tables are read for DESeq2. Empty/falsy labels and labels that do
    not start with a digit are returned unchanged.

    >>> prepend_c_to_index("1_B")
    'C1_B'
    >>> prepend_c_to_index("B_1")
    'B_1'
    """
    if not index_value:
        return index_value
    if not index_value[0].isdigit():
        return index_value
    return "C" + index_value
| 41 | |
| 42 | |
| 43 # write results for loading into DESeq2 | |
| 44 def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None): | |
| 45 """ | |
| 46 >>> import scanpy as sc | |
| 47 >>> adata = sc.datasets.pbmc68k_reduced() | |
| 48 >>> adata.X = abs(adata.X).astype(int) | |
| 49 >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") | |
| 50 >>> write_DESeq2_inputs(pseudobulk) | |
| 51 """ | |
| 52 # add / to output_dir if is not empty or if it doesn't end with / | |
| 53 if output_dir != "" and not output_dir.endswith("/"): | |
| 54 output_dir = output_dir + "/" | |
| 55 obs_for_deseq = pdata.obs.copy() | |
| 56 # replace any index starting with digits to start with C instead. | |
| 57 obs_for_deseq.rename(index=prepend_c_to_index, inplace=True) | |
| 58 # avoid dash that is read as point on R colnames. | |
| 59 obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_") | |
| 60 obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_") | |
| 61 col_metadata_file = f"{output_dir}col_metadata.csv" | |
| 62 # write obs to a col_metadata file | |
| 63 if factor_fields: | |
| 64 # only output the index plus the columns in factor_fields in that order | |
| 65 obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep=",", index=True) | |
| 66 else: | |
| 67 obs_for_deseq.to_csv(col_metadata_file, sep=",", index=True) | |
| 68 # write var to a gene_metadata file | |
| 69 pdata.var.to_csv(f"{output_dir}gene_metadata.csv", sep=",", index=True) | |
| 70 # write the counts matrix of a specified layer to file | |
| 71 if layer is None: | |
| 72 # write the X numpy matrix transposed to file | |
| 73 df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index) | |
| 74 else: | |
| 75 df = pd.DataFrame( | |
| 76 pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index | |
| 77 ) | |
| 78 df.to_csv(f"{output_dir}counts_matrix.csv", sep=",", index_label="") | |
| 79 | |
| 80 | |
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    Draw decoupler's pseudobulk-samples plot and save or show it.

    When ``save_path`` is truthy the figure is written to
    ``<save_path>/pseudobulk_samples.png``; otherwise ``fig.show()`` is
    called. Returns None.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
    """
    figure = decoupler.plot_psbulk_samples(
        pseudobulk_data,
        groupby=groupby,
        figsize=figsize,
        return_fig=True,
    )
    if not save_path:
        figure.show()
        return
    figure.savefig(f"{save_path}/pseudobulk_samples.png")
| 101 | |
| 102 | |
def plot_filter_by_expr(
    pseudobulk_data,
    group,
    min_count=None,
    min_total_count=None,
    save_path=None,
    figsize=None,
):
    """
    Draw decoupler's filter-by-expression diagnostic plot and save or show it.

    Parameters
    ----------
    pseudobulk_data : anndata.AnnData
        Pseudobulk object.
    group : str
        Column of ``pseudobulk_data.obs`` passed to decoupler as ``group``.
    min_count, min_total_count : int or None
        Thresholds forwarded to decoupler; ``None`` means the argument is not
        passed, so decoupler's own default applies.
    save_path : str, optional
        When truthy, the figure is written to ``<save_path>/filter_by_expr.png``;
        otherwise ``fig.show()`` is called.
    figsize : tuple, optional
        Figure size forwarded to decoupler when given.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
    """
    # An explicit None would override decoupler's numeric defaults, so only
    # forward the options that were actually set.
    optional_kwargs = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
            ("figsize", figsize),
        )
        if value is not None
    }
    fig = decoupler.plot_filter_by_expr(
        pseudobulk_data,
        group=group,
        return_fig=True,
        **optional_kwargs,
    )
    if save_path:
        fig.savefig(f"{save_path}/filter_by_expr.png")
    else:
        fig.show()
| 124 | |
| 125 | |
def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    Keep only the genes selected by ``decoupler.filter_by_expr``.

    Parameters
    ----------
    pdata : anndata.AnnData
        Pseudobulk object (samples x genes).
    min_count, min_total_count : int or None
        Thresholds forwarded to decoupler; ``None`` means the argument is not
        passed, so decoupler's own default applies.

    Returns
    -------
    anndata.AnnData
        A copy of ``pdata`` restricted to the selected genes.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
    """
    # An explicit None would override decoupler's numeric defaults.
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    genes = decoupler.filter_by_expr(pdata, **thresholds)
    return pdata[:, genes].copy()
| 138 | |
| 139 | |
def check_fields(fields, adata, obs=True, context=None):
    """
    Raise if any of ``fields`` is not a column of ``adata.obs`` (or ``adata.var``).

    Parameters
    ----------
    fields : iterable of str
        Column names that must be present.
    adata : anndata.AnnData
        Object whose ``obs`` (default) or ``var`` table is inspected.
    obs : bool
        Check ``adata.obs`` when True, ``adata.var`` otherwise.
    context : str, optional
        Where the fields came from (e.g. a CLI option); included in the error.

    Raises
    ------
    ValueError
        Naming the missing fields and listing the available columns.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """
    slot = "obs" if obs else "var"
    available = list(getattr(adata, slot).columns)
    # materialise once so generators can be both checked and reported
    fields = list(fields)
    available_set = set(available)
    missing = [field for field in fields if field not in available_set]
    if not missing:
        return
    legend = f" (passed in {context})" if context else ""
    raise ValueError(
        f"Some of the following fields{legend} are not present in adata.{slot}: "
        f"{fields}. Missing fields: {missing}. "
        f"Possible fields are: {available}"
    )
| 160 | |
| 161 | |
def main(args):
    """
    Run the pseudobulk pipeline described by the parsed command-line ``args``.

    Steps: load the AnnData file, optionally merge obs fields, validate the
    requested obs columns, build the pseudobulk, draw the samples and
    filter-by-expression plots, optionally filter genes by expression, then
    write the optional h5ad output and the DESeq2 CSV inputs.
    """
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
    if args.adata_obs_fields_to_merge:
        fields = args.adata_obs_fields_to_merge.split(",")
        check_fields(fields, adata, context="--adata_obs_fields_to_merge")
        adata = merge_adata_obs_fields(fields, adata)

    check_fields(
        [args.groupby, args.sample_key], adata, context="--groupby/--sample_key"
    )

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata, context="--factor_fields")

    print(f"Using mode: {args.mode}")
    # --min_counts has no CLI default; only forward it when given so that
    # get_pseudobulk keeps its own default instead of receiving None.
    pseudobulk_kwargs = {}
    if args.min_counts is not None:
        pseudobulk_kwargs["min_counts"] = args.min_counts
    # Perform pseudobulk analysis
    pseudobulk_data = get_pseudobulk(
        adata,
        sample_col=args.sample_key,
        groups_col=args.groupby,
        layer=args.layer,
        mode=args.mode,
        use_raw=args.use_raw,
        min_cells=args.min_cells,
        **pseudobulk_kwargs,
    )

    if factor_fields:
        # The DESeq2 metadata is written from the pseudobulk obs, which may not
        # carry every column of the input obs; fail early with a clear message
        # rather than with a KeyError when col_metadata.csv is written.
        check_fields(
            factor_fields,
            pseudobulk_data,
            context="--factor_fields, after pseudobulking",
        )

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    # NOTE(review): args.min_counts is reused here as the per-gene min_count,
    # a different threshold from get_pseudobulk's per-sample min_counts above.
    # NOTE(review): --plot_filtering_figsize is parsed but not forwarded here.
    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    if args.filter_expr:
        pseudobulk_data = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")

    write_DESeq2_inputs(
        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
    )
| 225 | |
| 226 | |
def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Merge adata.obs fields into a single string column.

    The values of the given fields are converted to strings and joined
    row by row with "_". The result is stored in ``adata.obs`` under the
    field names joined with "_" (e.g. ["a", "b"] -> "a_b"). An existing
    column with that name is overwritten rather than appended to, so
    repeated calls give the same result. With a single field, that column
    is replaced by its string representation. With no fields, ``adata`` is
    returned unchanged.

    Parameters
    ----------
    obs_fields_to_merge : list of str
        Fields in adata.obs to merge, in the order they should appear.
    adata : anndata.AnnData
        The AnnData object; its ``obs`` is modified in place.

    Returns
    -------
    anndata.AnnData
        The same object, with the merged column set.

    Raises
    ------
    ValueError
        If any field is missing from adata.obs. All fields are checked
        before anything is modified.

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
           'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
          dtype='object')
    """
    fields = list(obs_fields_to_merge)
    # validate everything first so a bad field leaves adata.obs untouched
    for field in fields:
        if field not in adata.obs.columns:
            raise ValueError(f"The '{field}' column is not present in adata.obs.")
    if not fields:
        return adata
    # build the merged values from the source columns only, never from a
    # pre-existing output column (which may be one of the inputs itself)
    merged = adata.obs[fields[0]].astype(str)
    for field in fields[1:]:
        merged = merged + "_" + adata.obs[field].astype(str)
    adata.obs["_".join(fields)] = merged
    return adata
| 263 | |
| 264 | |
if __name__ == "__main__":
    # Create argument parser
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Add arguments
    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    parser.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )
    # add argument for layer
    parser.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    # add argument for mode
    parser.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    # add boolean argument for use_raw
    parser.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    # add argument for min_cells
    parser.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )
    parser.add_argument(
        "--save_path", type=str, help="Path to save the plot (optional)"
    )
    # main() applies this single value in two places: as the minimum total
    # counts per pseudobulk sample and as the per-gene minimum count when
    # filtering by expression.
    parser.add_argument(
        "--min_counts",
        type=int,
        help=(
            "Minimum count threshold: used as the minimum total counts per "
            "pseudobulk sample and as the per-gene minimum count when "
            "filtering by expression"
        ),
    )
    parser.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    parser.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    parser.add_argument(
        "--filter_expr", action="store_true", help="Enable filtering by expression"
    )
    parser.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    parser.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    parser.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    # NOTE(review): main() does not read this option yet.
    parser.add_argument(
        "--plot_filtering_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the filter-by-expression plot (two arguments); currently not used",
    )

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function
    main(args)
