Mercurial > repos > ebi-gxa > decoupler_pathway_inference
comparison decoupler_aucell_score.py @ 0:77d680b36e23 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 1034a450c97dcbb77871050cf0c6d3da90dac823
author | ebi-gxa |
---|---|
date | Fri, 15 Mar 2024 12:17:49 +0000 |
parents | |
children | e9b06a8fb73a |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:77d680b36e23 |
---|---|
1 import argparse | |
2 import os | |
3 import tempfile | |
4 | |
5 import anndata | |
6 import decoupler as dc | |
7 import pandas as pd | |
8 | |
9 | |
def read_gmt(gmt_file):
    """
    Read a GMT gene-set file into a Pandas DataFrame.

    Each non-blank line is tab-separated: the gene-set name, a second field
    (a description/URL, which is ignored), then one gene per remaining field.

    Parameters
    ----------
    gmt_file : str
        Path to the GMT file (read as UTF-8).

    Returns
    -------
    pd.DataFrame
        One row per gene set with the columns "gene_set_name" and "genes",
        where "genes" holds the set's genes joined with commas. An empty file
        yields an empty DataFrame that still has both columns.

    >>> import os, tempfile
    >>> path = os.path.join(tempfile.mkdtemp(), "example.gmt")
    >>> with open(path, "w") as f:
    ...     _ = f.write("SET_A\\thttp://example.org/A\\tJAG1\\tNOTCH3\\n")
    ...     _ = f.write("\\n")
    ...     _ = f.write("SET_B\\tNA\\tB4GALT1\\tRHCG\\tMAL\\n")
    >>> df = read_gmt(path)
    >>> df.shape[0]
    2
    >>> list(df.columns)
    ['gene_set_name', 'genes']
    >>> df.loc[df["gene_set_name"] == "SET_B", "genes"].iloc[0]
    'B4GALT1,RHCG,MAL'
    """
    gene_sets = []
    with open(gmt_file, "r", encoding="utf-8") as handle:
        # Stream the file line by line instead of materialising it with readlines().
        for line in handle:
            stripped = line.strip()
            # Blank lines (e.g. a trailing empty line) would otherwise become
            # a bogus gene set whose name is the empty string.
            if not stripped:
                continue
            fields = stripped.split("\t")
            # Drop empty fields produced by doubled or trailing tabs so they do
            # not show up as empty genes (",,") in the joined string.
            genes = [gene.strip() for gene in fields[2:] if gene.strip()]
            gene_sets.append({"gene_set_name": fields[0], "genes": ",".join(genes)})

    # Explicit columns keep the schema stable even when no gene sets were read.
    return pd.DataFrame(gene_sets, columns=["gene_set_name", "genes"])
53 | |
54 | |
def score_genes_aucell(
    adata: anndata.AnnData, gene_list: list, score_name: str, use_raw=False
):
    """Score one gene set in every cell with decoupler's AUCell and store it in ``adata.obs``.

    decoupler's ``run_aucell`` writes its estimates to
    ``adata.obsm["aucell_estimate"]``; the column for this gene set is then
    copied to ``adata.obs[score_name]``.

    Parameters
    ----------
    adata : anndata.AnnData
        Data to score; modified in place.
    gene_list : list
        Identifiers of the genes in the set (a list or a pandas Index). Callers
        in this script pass ``adata.var`` index values; presumably they must
        match the feature names decoupler sees (raw ones when ``use_raw``) —
        NOTE(review): confirm against the decoupler docs.
    score_name : str
        Name used both as the decoupler source label and as the new
        ``adata.obs`` column.
    use_raw : bool, optional
        Forwarded to ``decoupler.run_aucell``.

    Raises
    ------
    ValueError
        If ``gene_list`` is empty, because there is nothing to score.

    >>> import scanpy as sc
    >>> import decoupler as dc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist()
    >>> score_genes_aucell(adata, gene_list, "ribosomal_aucell", use_raw=False)
    >>> "ribosomal_aucell" in adata.obs.columns
    True
    """
    # len() rather than truthiness: gene_list may be a pandas Index, whose
    # truth value is ambiguous and raises.
    if len(gene_list) == 0:
        raise ValueError(
            f"Cannot compute AUCell score '{score_name}': the gene list is empty "
            "(none of the requested genes were found in the data)."
        )

    # Long-format network: one row per gene, all belonging to the single
    # source named score_name.
    geneset_df = pd.DataFrame(
        {
            "gene_id": gene_list,
            "geneset": score_name,
        }
    )
    # NOTE(review): decoupler filters out sources with too few targets
    # (its min_n argument); such a set would not appear in the estimate
    # table and the lookup below would raise KeyError.
    dc.run_aucell(
        adata, net=geneset_df, source="geneset", target="gene_id", use_raw=use_raw
    )
    adata.obs[score_name] = adata.obsm["aucell_estimate"][score_name]
88 | |
89 | |
def run_for_genelists(
    adata, gene_lists, score_names, use_raw=False, gene_symbols_field="gene_symbols"
):
    """Compute one AUCell score per user-supplied gene list.

    Parameters
    ----------
    adata : anndata.AnnData
        Data to score; each score is added to ``adata.obs`` as
        ``AUCell_<score name>``.
    gene_lists : list of str
        Each element is a comma-separated list of gene symbols. Whitespace
        around the commas is ignored.
    score_names : list of str
        One name per gene list, paired positionally with ``gene_lists``.
    use_raw : bool, optional
        Forwarded to ``score_genes_aucell``.
    gene_symbols_field : str, optional
        Column of ``adata.var`` holding gene symbols. Symbols are mapped to
        the matching ``adata.var`` index values before scoring.

    Raises
    ------
    ValueError
        If ``gene_lists`` and ``score_names`` differ in length.
    """
    if len(gene_lists) != len(score_names):
        raise ValueError(
            "The number of gene lists (separated by :) and score names (separated by :) must be the same"
        )

    # Distinct loop name so the score_names parameter is not shadowed.
    for gene_list, score_name in zip(gene_lists, score_names):
        # Tolerate whitespace around separators, e.g. "CD3E, CD4".
        genes = [gene.strip() for gene in gene_list.split(",") if gene.strip()]
        ens_gene_ids = adata.var[adata.var[gene_symbols_field].isin(genes)].index
        score_genes_aucell(
            adata,
            ens_gene_ids,
            f"AUCell_{score_name}",
            use_raw,
        )
107 | |
108 | |
109 if __name__ == "__main__": | |
110 # Create command-line arguments parser | |
111 parser = argparse.ArgumentParser(description="Score genes using Aucell") | |
112 parser.add_argument( | |
113 "--input_file", type=str, help="Path to input AnnData file", required=True | |
114 ) | |
115 parser.add_argument( | |
116 "--output_file", type=str, help="Path to output file", required=True | |
117 ) | |
118 parser.add_argument("--gmt_file", type=str, help="Path to GMT file", required=False) | |
119 # add argument for gene sets to score | |
120 parser.add_argument( | |
121 "--gene_sets_to_score", | |
122 type=str, | |
123 required=False, | |
124 help="Optional comma separated list of gene sets to score (the need to be in the gmt file)", | |
125 ) | |
126 # add argument for gene list (comma separated) to score | |
127 parser.add_argument( | |
128 "--gene_lists_to_score", | |
129 type=str, | |
130 required=False, | |
131 help="Comma separated list of genes to score. You can have more than one set of genes, separated by colon :", | |
132 ) | |
133 # argument for the score name when using the gene list | |
134 parser.add_argument( | |
135 "--score_names", | |
136 type=str, | |
137 required=False, | |
138 help="Name of the score column when using the gene list. You can have more than one set of score names, separated by colon :. It should be the same length as the number of gene lists.", | |
139 ) | |
140 parser.add_argument( | |
141 "--gene_symbols_field", | |
142 type=str, | |
143 help="Name of the gene symbols field in the AnnData object", | |
144 required=True, | |
145 ) | |
146 parser.add_argument("--use_raw", action="store_true", help="Use raw data") | |
147 parser.add_argument( | |
148 "--write_anndata", action="store_true", help="Write the modified AnnData object" | |
149 ) | |
150 | |
151 # Parse command-line arguments | |
152 args = parser.parse_args() | |
153 | |
154 # Load input AnnData object | |
155 adata = anndata.read_h5ad(args.input_file) | |
156 | |
157 if args.gmt_file is not None: | |
158 # Load MSigDB file in GMT format | |
159 msigdb = read_gmt(args.gmt_file) | |
160 | |
161 gene_sets_to_score = args.gene_sets_to_score.split(",") if args.gene_sets_to_score else [] | |
162 # Score genes by their ensembl ids using the score_genes_aucell function | |
163 for _, row in msigdb.iterrows(): | |
164 gene_set_name = row["gene_set_name"] | |
165 if not gene_sets_to_score or gene_set_name in gene_sets_to_score: | |
166 genes = row["genes"].split(",") | |
167 # Convert gene symbols to ensembl ids by using the columns gene_symbols and index in adata.var specific to the gene set | |
168 ens_gene_ids = adata.var[ | |
169 adata.var[args.gene_symbols_field].isin(genes) | |
170 ].index | |
171 score_genes_aucell( | |
172 adata, ens_gene_ids, f"AUCell_{gene_set_name}", args.use_raw | |
173 ) | |
174 elif args.gene_lists_to_score is not None and args.score_names is not None: | |
175 gene_lists = args.gene_lists_to_score.split(":") | |
176 score_names = args.score_names.split(",") | |
177 run_for_genelists( | |
178 adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field | |
179 ) | |
180 | |
181 # Save the modified AnnData object or generate a file with cells as rows and the new score_names columns | |
182 if args.write_anndata: | |
183 adata.write_h5ad(args.output_file) | |
184 else: | |
185 new_columns = [col for col in adata.obs.columns if col.startswith("AUCell_")] | |
186 adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True) |