changeset 0:77d680b36e23 draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 1034a450c97dcbb77871050cf0c6d3da90dac823
author ebi-gxa
date Fri, 15 Mar 2024 12:17:49 +0000
parents
children e9b06a8fb73a
files decoupler_aucell_score.py decoupler_pathway_inference.py decoupler_pathway_inference.xml decoupler_pseudobulk.py get_test_data.sh test-data/mouse_hallmark_ss.gmt test-data/progeny_test.tsv test-data/progeny_test_2.tsv
diffstat 8 files changed, 995 insertions(+), 0 deletions(-)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/decoupler_aucell_score.py	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,186 @@
+import argparse
+import os
+import tempfile
+
+import anndata
+import decoupler as dc
+import pandas as pd
+
+
+def read_gmt(gmt_file):
+    """
+    Reads a GMT file into a Pandas DataFrame.
+
+    Parameters
+    ----------
+    gmt_file : str
+        Path to the GMT file.
+
+    Returns
+    -------
+    pd.DataFrame
+        A DataFrame with the gene sets. Each row represents a gene set; the columns are "gene_set_name" and "genes" (a comma-separated string of gene symbols).
+
+    >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n"
+    >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n"
+    >>> temp_dir = tempfile.gettempdir()
+    >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt")
+    >>> with open(temp_gmt, "w") as f:
+    ...   f.write(line)
+    ...   f.write(line2)
+    288
+    380
+    >>> df = read_gmt(temp_gmt)
+    >>> df.shape[0]
+    2
+    >>> df.columns == ["gene_set_name", "genes"]
+    array([ True,  True])
+    >>> df.loc[df["gene_set_name"] == "HALLMARK_APICAL_SURFACE"].genes.tolist()[0].startswith("B4GALT1")
+    True
+    """
+    # Read the GMT file into a list of lines
+    with open(gmt_file, "r") as f:
+        lines = f.readlines()
+
+    # Create a list of dictionaries, where each dictionary represents a gene set
+    gene_sets = []
+    for line in lines:
+        fields = line.strip().split("\t")
+        gene_set = {"gene_set_name": fields[0], "genes": ",".join(fields[2:])}
+        gene_sets.append(gene_set)
+
+    # Convert the list of dictionaries to a DataFrame
+    return pd.DataFrame(gene_sets)
+
+
+def score_genes_aucell(
+    adata: anndata.AnnData, gene_list: list, score_name: str, use_raw=False
+):
+    """Score genes using Aucell.
+
+    Parameters
+    ----------
+    adata : anndata.AnnData
+    gene_list : list
+    score_name : str
+    use_raw : bool, optional
+
+    >>> import scanpy as sc
+    >>> import decoupler as dc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist()
+    >>> score_genes_aucell(adata, gene_list, "ribosomal_aucell", use_raw=False)
+    >>> "ribosomal_aucell" in adata.obs.columns
+    True
+    """
+    # build a DataFrame with two columns, "geneset" and "gene_id": one row per gene, with "geneset" set to score_name
+    geneset_df = pd.DataFrame(
+        {
+            "gene_id": gene_list,
+            "geneset": score_name,
+        }
+    )
+    # run decoupler's run_aucell
+    dc.run_aucell(
+        adata, net=geneset_df, source="geneset", target="gene_id", use_raw=use_raw
+    )
+    # copy .obsm['aucell_estimate'] matrix columns to adata.obs using the column names
+    adata.obs[score_name] = adata.obsm["aucell_estimate"][score_name]
+
+
+def run_for_genelists(
+    adata, gene_lists, score_names, use_raw=False, gene_symbols_field="gene_symbols"
+):
+    if len(gene_lists) == len(score_names):
+        for gene_list, score_name in zip(gene_lists, score_names):
+            genes = gene_list.split(",")
+            ens_gene_ids = adata.var[adata.var[gene_symbols_field].isin(genes)].index
+            score_genes_aucell(
+                adata,
+                ens_gene_ids,
+                f"AUCell_{score_names}",
+                use_raw,
+            )
+    else:
+        raise ValueError(
+            "The number of gene lists (separated by :) and score names (separated by :) must be the same"
+        )
+
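+# Example (illustrative): --gene_lists_to_score "CD3D,CD3E:LYZ,CST3" together with
+# --score_names "Tcell_markers:Mono_markers" adds the obs columns
+# "AUCell_Tcell_markers" and "AUCell_Mono_markers".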
+
+if __name__ == "__main__":
+    # Create command-line arguments parser
+    parser = argparse.ArgumentParser(description="Score genes using Aucell")
+    parser.add_argument(
+        "--input_file", type=str, help="Path to input AnnData file", required=True
+    )
+    parser.add_argument(
+        "--output_file", type=str, help="Path to output file", required=True
+    )
+    parser.add_argument("--gmt_file", type=str, help="Path to GMT file", required=False)
+    # add argument for gene sets to score
+    parser.add_argument(
+        "--gene_sets_to_score",
+        type=str,
+        required=False,
+        help="Optional comma separated list of gene sets to score (the need to be in the gmt file)",
+    )
+    # add argument for gene list (comma separated) to score
+    parser.add_argument(
+        "--gene_lists_to_score",
+        type=str,
+        required=False,
+        help="Comma separated list of genes to score. You can have more than one set of genes, separated by colon :",
+    )
+    # argument for the score name when using the gene list
+    parser.add_argument(
+        "--score_names",
+        type=str,
+        required=False,
+        help="Name of the score column when using the gene list. You can have more than one set of score names, separated by colon :. It should be the same length as the number of gene lists.",
+    )
+    parser.add_argument(
+        "--gene_symbols_field",
+        type=str,
+        help="Name of the gene symbols field in the AnnData object",
+        required=True,
+    )
+    parser.add_argument("--use_raw", action="store_true", help="Use raw data")
+    parser.add_argument(
+        "--write_anndata", action="store_true", help="Write the modified AnnData object"
+    )
+
+    # Parse command-line arguments
+    args = parser.parse_args()
+
+    # Load input AnnData object
+    adata = anndata.read_h5ad(args.input_file)
+
+    if args.gmt_file is not None:
+        # Load MSigDB file in GMT format
+        msigdb = read_gmt(args.gmt_file)
+
+        gene_sets_to_score = args.gene_sets_to_score.split(",") if args.gene_sets_to_score else []
+        # Score genes by their ensembl ids using the score_genes_aucell function
+        for _, row in msigdb.iterrows():
+            gene_set_name = row["gene_set_name"]
+            if not gene_sets_to_score or gene_set_name in gene_sets_to_score:
+                genes = row["genes"].split(",")
+                # Map the gene set's gene symbols to the adata.var index (e.g. Ensembl IDs) via the gene symbols column
+                ens_gene_ids = adata.var[
+                    adata.var[args.gene_symbols_field].isin(genes)
+                ].index
+                score_genes_aucell(
+                    adata, ens_gene_ids, f"AUCell_{gene_set_name}", args.use_raw
+                )
+    elif args.gene_lists_to_score is not None and args.score_names is not None:
+        gene_lists = args.gene_lists_to_score.split(":")
+        score_names = args.score_names.split(":")
+        run_for_genelists(
+            adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field
+        )
+
+    # Save the modified AnnData object or generate a file with cells as rows and the new score_names columns
+    if args.write_anndata:
+        adata.write_h5ad(args.output_file)
+    else:
+        new_columns = [col for col in adata.obs.columns if col.startswith("AUCell_")]
+        adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True)
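+
+# Example invocation (illustrative file names; the gene symbols field must match a
+# column in adata.var):
+#   python decoupler_aucell_score.py --input_file input.h5ad --gmt_file hallmarks.gmt \
+#     --gene_symbols_field gene_symbols --output_file aucell_scores.tsv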
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/decoupler_pathway_inference.py	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,132 @@
+# import the necessary packages
+import argparse
+
+import anndata as ad
+import decoupler as dc
+import pandas as pd
+
+# define arguments for the script
+parser = argparse.ArgumentParser()
+
+# add AnnData input file option
+parser.add_argument(
+    "-i", "--input_anndata", help="AnnData input file", required=True
+)
+
+# add network input file option
+parser.add_argument(
+    "-n", "--input_network", help="Network input file", required=True
+)
+
+# output file prefix
+parser.add_argument(
+    "-o", "--output",
+    help="output files prefix",
+    default=None,
+)
+
+# path to save Activities AnnData file
+parser.add_argument(
+    "-a", "--activities_path", help="Path to save Activities AnnData file", default=None
+)
+
+# Column name in net with source nodes
+parser.add_argument(
+    "-s", "--source", help="Column name in net with source nodes.", default="source"
+)
+
+# Column name in net with target nodes
+parser.add_argument(
+    "-t", "--target", help="Column name in net with target nodes.", default="target"
+)
+
+# Column name in net with weights.
+parser.add_argument(
+    "-w", "--weight", help="Column name in net with weights.", default="weight"
+)
+
+# add boolean argument for use_raw
+parser.add_argument(
+    "--use_raw", action="store_true", default=False, help="Whether to use the raw part of the AnnData object"
+)
+
+# add argument for min_cells
+parser.add_argument(
+    "--min_n", help="Minimum of targets per source. If less, sources are removed.", default=5, type=int
+)
+
+# add activity inference method option
+parser.add_argument(
+    "-m", "--method", help="Activity inference method", default="mlm", required=True
+)
+args = parser.parse_args()
+
+# check that the output prefix is specified
+if args.output is None:
+    raise ValueError("Please specify the output prefix with -o/--output")
+
+# read in the AnnData input file
+adata = ad.read_h5ad(args.input_anndata)
+
+# read in the network input file
+network = pd.read_csv(args.input_network, sep='\t')
+
+if (
+    args.source not in network.columns
+    or args.target not in network.columns
+    or args.weight not in network.columns
+):
+    raise ValueError(
+        "Source, target, and weight columns are not present in the network"
+    )
+
+
+if args.method == "mlm":
+    dc.run_mlm(
+        mat=adata,
+        net=network,
+        source=args.source,
+        target=args.target,
+        weight=args.weight,
+        verbose=True,
+        min_n=args.min_n,
+        use_raw=args.use_raw,
+    )
+
+    if args.output is not None:
+        # write the mlm estimates and p-values (adata.obsm) to the output table
+        combined_df = pd.concat([adata.obsm["mlm_estimate"], adata.obsm["mlm_pvals"]], axis=1)
+
+        # Save the combined dataframe to a file
+        combined_df.to_csv(args.output + ".tsv", sep="\t")
+
+    # if args.activities_path is specified, generate the activities AnnData and save the AnnData object to the specified path
+    if args.activities_path is not None:
+        acts = dc.get_acts(adata, obsm_key="mlm_estimate")
+        acts.write_h5ad(args.activities_path)
+
+elif args.method == "ulm":
+    dc.run_ulm(
+        mat=adata,
+        net=network,
+        source=args.source,
+        target=args.target,
+        weight=args.weight,
+        verbose=True,
+        min_n=args.min_n,
+        use_raw=args.use_raw,
+    )
+
+    if args.output is not None:
+        # write the ulm estimates and p-values (adata.obsm) to the output table
+        combined_df = pd.concat([adata.obsm["ulm_estimate"], adata.obsm["ulm_pvals"]], axis=1)
+
+        # Save the combined dataframe to a file
+        combined_df.to_csv(args.output + ".tsv", sep="\t")
+
+    # if args.activities_path is specified, generate the activities AnnData and save the AnnData object to the specified path
+    if args.activities_path is not None:
+        acts = dc.get_acts(adata, obsm_key="ulm_estimate")
+        acts.write_h5ad(args.activities_path)
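+
+# Example invocation (illustrative file names):
+#   python decoupler_pathway_inference.py -i input.h5ad -n progeny.tsv \
+#     --method mlm --min_n 5 -o inference --activities_path activities.h5ad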
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/decoupler_pathway_inference.xml	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,129 @@
+<tool id="decoupler_pathway_inference" name="Decoupler Pathway Inference" version="1.4.0+galaxy0" profile="20.05" license="MIT">
+    <description>
+        of functional gene sets/pathways for scRNA-seq data.
+    </description>
+    <requirements>
+        <requirement type="package" version="1.4.0">decoupler</requirement>
+    </requirements>
+    <command>
+        python '$__tool_directory__/decoupler_pathway_inference.py'
+            -i '$input_anndata'
+            -n '$input_network_file'
+            --min_n "$min_n"
+            --method '$method'
+            $use_raw
+            --source '$source'
+            --target '$target'
+            --weight '$weight'
+            --output "inference"
+            $write_activities_path
+    </command>
+    <inputs>
+        <param name="input_anndata" type="data" format="h5ad" label="Input AnnData file" />
+        <param name="input_network_file" type="data" format="tabular" label="Input Network file" help="Tabular file with columns Source, Target and Weight. A source gene/pathway regulates/contains a target gene, weights can be either positive or negative. The source element needs to be part of the network, the target is a gene in the network and in the dataset" />
+        <param name="min_n" type="integer" min="0" value="5" label="Minimum targets per source." help="If targets are less than minimum, sources are removed" />
+        <param name="method" type="select" label="Activity inference method">
+            <option value="mlm" selected="true">Multivariate linear model (MLM)</option>
+            <option value="ulm">Univariate linear model (ULM)</option>
+        </param>
+        <param name="use_raw" type="boolean" truevalue="--use_raw" falsevalue="" checked="false" label="Use the raw part of the AnnData object" />
+        <param name="write_activities_path" type="boolean" truevalue="--activities_path anndata_activities_path.h5ad" falsevalue="" checked="true" label="Write the activities AnnData object (contains the MLM/ULM activity results for each pathway and each cell in the main matrix, it is not a replacement of the original AnnData provided as input)." />
+        <param name="source" type="text" value='source' label="Column name in network with source nodes." help="If empty then default is 'source' is used." />
+        <param name="target" type="text" value='target' label="Column name in network with target nodes." help="If empty then default is 'target' is used." />
+        <param name="weight" type="text" value='weight' label="Column name in network with weight." help="If empty then default is 'weight' is used." />
+    </inputs>
+    <outputs>
+        <data name="output_ad" format="h5ad" from_work_dir="anndata_activities_path.h5ad" label="${tool.name} on ${on_string}: Regulators/Pathways activity AnnData file">
+            <filter>write_activities_path</filter>
+        </data>
+        <data name="output_table" format="tabular" from_work_dir="inference.tsv" label="${tool.name} on ${on_string}: Output estimate table" />
+    </outputs>
+    <tests>
+        <!-- Hint: You can use [ctrl+alt+t] after defining the inputs/outputs to auto-scaffold some basic test cases. -->
+
+    <test expect_num_outputs="2">
+        <param name="input_anndata" value="pbmc3k_processed.h5ad"/>
+        <param name="input_network_file" value="progeny_test.tsv"/>
+        <param name="min_n" value="0"/>
+        <param name="method" value="mlm"/>
+        <param name="use_raw" value="false"/>
+        <param name="write_activities_path" value="true"/>
+        <param name="source" value="source"/>
+        <param name="target" value="target"/>
+        <param name="weight" value="weight"/>
+        <output name="output_ad">
+            <assert_contents>
+                <has_h5_keys keys="obsm/mlm_estimate"/>
+            </assert_contents>
+        </output>
+        <output name="output_table">
+            <assert_contents>
+                <has_n_columns n="5"/>
+            </assert_contents>
+        </output>
+    </test>
+    <test>
+        <param name="input_anndata" value="pbmc3k_processed.h5ad"/>
+        <param name="input_network_file" value="progeny_test_2.tsv"/>
+        <param name="min_n" value="0"/>
+        <param name="method" value="ulm"/>
+        <param name="use_raw" value="false"/>
+        <param name="write_activities_path" value="true"/>
+        <param name="source" value="source"/>
+        <param name="target" value="target"/>
+        <param name="weight" value="weight"/>
+        <output name="output_ad">
+            <assert_contents>
+                <has_h5_keys keys="obsm/ulm_estimate"/>
+            </assert_contents>
+        </output>
+        <output name="output_table">
+            <assert_contents>
+                <has_n_columns n="5"/>
+            </assert_contents>
+        </output>
+    </test>
+    </tests>
+    <help>
+**What it does**
+
+This tool infers pathway/gene-set activities from scRNA-seq data with decoupler, using either a multivariate (MLM) or univariate (ULM) linear model.
+
+**Input**
+
+The main input is an AnnData object in H5AD format, containing either raw or normalized data.
+
+The tool also takes a network file containing a collection of pathways and their target genes, with a weight for each interaction, for example:
+
+::
+
+         source    target    weight
+    0    T1        G01        1.0
+    1    T1        G02        1.0
+    2    T1        G03        0.7
+    3    T2        G04        1.0
+    4    T2        G06       -0.5
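+
+A network like this can be obtained, for instance, with the decoupler Python package (a minimal sketch, assuming decoupler-py is installed; the output file name is illustrative):
+
+::
+
+    import decoupler as dc
+
+    # PROGENy pathway-target network, top weighted targets per pathway
+    net = dc.get_progeny(organism="human", top=500)
+    net.to_csv("progeny_human.tsv", sep="\t", index=False)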
+
+You can also choose to use the raw data in the AnnData object instead of the X matrix ("use_raw" parameter) and set the minimum number of targets per source ("min_n").
+
+
+**Output**
+
+The tool always outputs a tab-separated table with the activity estimates and p-values for each cell.
+
+If the "write_activities_path" option is enabled, the tool also writes an activities AnnData object whose main matrix contains the pathway/regulator activity for each cell; this is not a replacement for the original input AnnData.
+
+
+
+    </help>
+    <citations>
+        <citation type="doi">10.1093/bioadv/vbac016 </citation>
+    </citations>
+</tool>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/decoupler_pseudobulk.py	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,367 @@
+import argparse
+
+import anndata
+import decoupler
+import pandas as pd
+
+
+def get_pseudobulk(
+    adata,
+    sample_col,
+    groups_col,
+    layer=None,
+    mode="sum",
+    min_cells=10,
+    min_counts=1000,
+    use_raw=False,
+):
+    """
+    >>> import scanpy as sc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> adata.X = abs(adata.X).astype(int)
+    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
+    """
+
+    return decoupler.get_pseudobulk(
+        adata,
+        sample_col=sample_col,
+        groups_col=groups_col,
+        layer=layer,
+        mode=mode,
+        use_raw=use_raw,
+        min_cells=min_cells,
+        min_counts=min_counts,
+    )
+
+
+def prepend_c_to_index(index_value):
+    if index_value and index_value[0].isdigit():
+        return "C" + index_value
+    return index_value
+
+
+# write results for loading into DESeq2
+def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
+    """
+    >>> import scanpy as sc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> adata.X = abs(adata.X).astype(int)
+    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
+    >>> write_DESeq2_inputs(pseudobulk)
+    """
+    # append "/" to output_dir if it is non-empty and does not already end with "/"
+    if output_dir != "" and not output_dir.endswith("/"):
+        output_dir = output_dir + "/"
+    obs_for_deseq = pdata.obs.copy()
+    # prefix index values that start with a digit with "C" (R column names cannot start with a digit)
+    obs_for_deseq.rename(index=prepend_c_to_index, inplace=True)
+    # avoid dashes and spaces, which R converts to dots in column names
+    obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_")
+    obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_")
+    col_metadata_file = f"{output_dir}col_metadata.tsv"
+    # write obs to a col_metadata file
+    if factor_fields:
+        # only output the index plus the columns in factor_fields in that order
+        obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True)
+    else:
+        obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True)
+    # write var to a gene_metadata file
+    pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True)
+    # write the counts matrix of a specified layer to file
+    if layer is None:
+        # write the X numpy matrix transposed to file
+        df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index)
+    else:
+        df = pd.DataFrame(
+            pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index
+        )
+    df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="")
+
+
+def plot_pseudobulk_samples(
+    pseudobulk_data,
+    groupby,
+    figsize=(10, 10),
+    save_path=None,
+):
+    """
+    >>> import scanpy as sc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> adata.X = abs(adata.X).astype(int)
+    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
+    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
+    """
+    fig = decoupler.plot_psbulk_samples(
+        pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True
+    )
+    if save_path:
+        fig.savefig(f"{save_path}/pseudobulk_samples.png")
+    else:
+        fig.show()
+
+
+def plot_filter_by_expr(
+    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
+):
+    """
+    >>> import scanpy as sc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> adata.X = abs(adata.X).astype(int)
+    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
+    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
+    """
+    fig = decoupler.plot_filter_by_expr(
+        pseudobulk_data,
+        group=group,
+        min_count=min_count,
+        min_total_count=min_total_count,
+        return_fig=True,
+    )
+    if save_path:
+        fig.savefig(f"{save_path}/filter_by_expr.png")
+    else:
+        fig.show()
+
+
+def filter_by_expr(pdata, min_count=None, min_total_count=None):
+    """
+    >>> import scanpy as sc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> adata.X = abs(adata.X).astype(int)
+    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
+    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
+    """
+    genes = decoupler.filter_by_expr(
+        pdata, min_count=min_count, min_total_count=min_total_count
+    )
+    return pdata[:, genes].copy()
+
+
+def check_fields(fields, adata, obs=True, context=None):
+    """
+    >>> import scanpy as sc
+    >>> adata = sc.datasets.pbmc68k_reduced()
+    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
+    """
+
+    legend = ""
+    if context:
+        legend = f", passed in {context},"
+    if obs:
+        if not set(fields).issubset(set(adata.obs.columns)):
+            raise ValueError(
+                f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}"
+            )
+    else:
+        if not set(fields).issubset(set(adata.var.columns)):
+            raise ValueError(
+                f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}"
+            )
+
+
+def main(args):
+    # Load AnnData object from file
+    adata = anndata.read_h5ad(args.adata_file)
+
+    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
+    if args.adata_obs_fields_to_merge:
+        # first split potential groups by ":" and iterate over them
+        for group in args.adata_obs_fields_to_merge.split(":"):
+            fields = group.split(",")
+            check_fields(fields, adata)
+            adata = merge_adata_obs_fields(fields, adata)
+
+    check_fields([args.groupby, args.sample_key], adata)
+
+    factor_fields = None
+    if args.factor_fields:
+        factor_fields = args.factor_fields.split(",")
+        check_fields(factor_fields, adata)
+
+    print(f"Using mode: {args.mode}")
+    # Perform pseudobulk analysis
+    pseudobulk_data = get_pseudobulk(
+        adata,
+        sample_col=args.sample_key,
+        groups_col=args.groupby,
+        layer=args.layer,
+        mode=args.mode,
+        use_raw=args.use_raw,
+        min_cells=args.min_cells,
+        min_counts=args.min_counts,
+    )
+
+    # Plot pseudobulk samples
+    plot_pseudobulk_samples(
+        pseudobulk_data,
+        args.groupby,
+        save_path=args.save_path,
+        figsize=args.plot_samples_figsize,
+    )
+
+    plot_filter_by_expr(
+        pseudobulk_data,
+        group=args.groupby,
+        min_count=args.min_counts,
+        min_total_count=args.min_total_counts,
+        save_path=args.save_path,
+    )
+
+    # Filter by expression if enabled
+    if args.filter_expr:
+        filtered_adata = filter_by_expr(
+            pseudobulk_data,
+            min_count=args.min_counts,
+            min_total_count=args.min_total_counts,
+        )
+
+        pseudobulk_data = filtered_adata
+
+    # Save the pseudobulk data
+    if args.anndata_output_path:
+        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")
+
+    write_DESeq2_inputs(
+        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
+    )
+
+
+def merge_adata_obs_fields(obs_fields_to_merge, adata):
+    """
+    Merge the given adata.obs fields into a single, underscore-joined column.
+
+    Parameters
+    ----------
+    obs_fields_to_merge : list of str
+        Fields in adata.obs to merge
+    adata : anndata.AnnData
+        The AnnData object
+
+    Returns
+    -------
+    anndata.AnnData
+        The merged AnnData object
+
+    docstring tests:
+    >>> import scanpy as sc
+    >>> ad = sc.datasets.pbmc68k_reduced()
+    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
+    >>> ad.obs.columns
+    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
+           'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
+          dtype='object')
+    """
+    field_name = "_".join(obs_fields_to_merge)
+    for field in obs_fields_to_merge:
+        if field not in adata.obs.columns:
+            raise ValueError(f"The '{field}' column is not present in adata.obs.")
+        if field_name not in adata.obs.columns:
+            adata.obs[field_name] = adata.obs[field].astype(str)
+        else:
+            adata.obs[field_name] = (
+                adata.obs[field_name] + "_" + adata.obs[field].astype(str)
+            )
+    return adata
+
+
+if __name__ == "__main__":
+    # Create argument parser
+    parser = argparse.ArgumentParser(
+        description="Perform pseudobulk analysis on an AnnData object"
+    )
+
+    # Add arguments
+    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
+    parser.add_argument(
+        "-m",
+        "--adata_obs_fields_to_merge",
+        type=str,
+        help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;",
+    )
+    parser.add_argument(
+        "--groupby",
+        type=str,
+        required=True,
+        help="The column in adata.obs that defines the groups",
+    )
+    parser.add_argument(
+        "--sample_key",
+        required=True,
+        type=str,
+        help="The column in adata.obs that defines the samples",
+    )
+    # add argument for layer
+    parser.add_argument(
+        "--layer",
+        type=str,
+        default=None,
+        help="The name of the layer of the AnnData object to use",
+    )
+    # add argument for mode
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="sum",
+        help="The mode for Decoupler pseudobulk analysis",
+        choices=["sum", "mean", "median"],
+    )
+    # add boolean argument for use_raw
+    parser.add_argument(
+        "--use_raw",
+        action="store_true",
+        default=False,
+        help="Whether to use the raw part of the AnnData object",
+    )
+    # add argument for min_cells
+    parser.add_argument(
+        "--min_cells",
+        type=int,
+        default=10,
+        help="Minimum number of cells for pseudobulk analysis",
+    )
+    parser.add_argument(
+        "--save_path", type=str, help="Path to save the plot (optional)"
+    )
+    parser.add_argument(
+        "--min_counts",
+        type=int,
+        help="Minimum count threshold for filtering by expression",
+    )
+    parser.add_argument(
+        "--min_total_counts",
+        type=int,
+        help="Minimum total count threshold for filtering by expression",
+    )
+    parser.add_argument(
+        "--anndata_output_path",
+        type=str,
+        help="Path to save the filtered AnnData object or pseudobulk data",
+    )
+    parser.add_argument(
+        "--filter_expr", action="store_true", help="Enable filtering by expression"
+    )
+    parser.add_argument(
+        "--factor_fields",
+        type=str,
+        help="Comma separated list of fields for the factors",
+    )
+    parser.add_argument(
+        "--deseq2_output_path",
+        type=str,
+        help="Path to save the DESeq2 inputs",
+        required=True,
+    )
+    parser.add_argument(
+        "--plot_samples_figsize",
+        type=int,
+        default=[10, 10],
+        nargs=2,
+        help="Size of the samples plot as a tuple (two arguments)",
+    )
+    parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)
+
+    # Parse the command line arguments
+    args = parser.parse_args()
+
+    # Call the main function
+    main(args)
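+
+# Example invocation (illustrative file names; the DESeq2 output directory must exist):
+#   python decoupler_pseudobulk.py input.h5ad --groupby louvain --sample_key batch \
+#     --deseq2_output_path deseq2_inputs --anndata_output_path pseudobulk.h5ad \
+#     --filter_expr --min_counts 10 --min_total_counts 200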
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/get_test_data.sh	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+BASENAME_FILE='mito_counted_anndata.h5ad'
+
+MTX_LINK='https://zenodo.org/record/7053673/files/Mito-counted_AnnData'
+
+# convenience for getting data
+function get_data {
+  local link=$1
+  local fname=$2
+
+  if [ ! -f "$fname" ]; then
+    echo "$fname not available locally, downloading..."
+    wget -O "$fname" --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 -t 3 "$link"
+  fi
+}
+
+# get matrix data
+mkdir -p test-data
+pushd test-data
+get_data $MTX_LINK $BASENAME_FILE
+
+
+# Download input anndata for decoupler-pathway_inference
+BASENAME_FILE='pbmc3k_processed.h5ad'
+
+MTX_LINK='https://zenodo.org/records/3752813/files/pbmc3k_processed.h5ad'
+
+get_data $MTX_LINK $BASENAME_FILE
+
+# Download output anndata for decoupler-pathway_inference
+BASENAME_FILE='test.h5ad'
+
+MTX_LINK='https://zenodo.org/records/10401958/files/test.h5ad'
+
+get_data $MTX_LINK $BASENAME_FILE
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/mouse_hallmark_ss.gmt	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,2 @@
+HALLMARK_NOTCH_SIGNALING	http://www.gsea-msigdb.org/gsea/msigdb/mouse/geneset/HALLMARK_NOTCH_SIGNALING	Jag1	Notch1	Notch2	Notch3	Ccnd1	Tcf7l2	Wnt5a	Lfng	Psenen	Psen2	Heyl	Fzd1	Rbx1	Hes1	Arrb1	Ppard	Prkca	Wnt2	Fzd5	Dtx1	Sap30	Dtx2	Kat2a	Dll1	Fzd7	St3gal6	Fbxw11	Cul1	Aph1a	Dtx4	Skp1	Maml2
+HALLMARK_APICAL_SURFACE	http://www.gsea-msigdb.org/gsea/msigdb/mouse/geneset/HALLMARK_APICAL_SURFACE	Adam10	Gata3	B4galt1	Hspb1	App	Il2rg	Atp8b1	Brca1	Slc34a3	Atp6v0a4	Ghrl	Ncoa6	Rhcg	Scube1	Efna5	Crybg1	Ephb4	Flot2	Gas1	Gstm5	Il2rb	Shroom2	Lyn	Mal	Plaur	Slc2a4	Thy1	Akap7	Srpx	Cx3cl1	Dcbld2	Lypd3	Pkhd1	Sulf2	Pcsk9	Tmem8b	Rtn4rl1	Ntng1	Slc22a12	Adipor2	Afap1l2	Mdga1	Cd160	Crocc
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/progeny_test.tsv	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,71 @@
+	source	target	weight	p_value
+0	Androgen	TMPRSS2	11.490631	0.0
+1	Androgen	NKX3-1	10.622551	2.2e-44
+2	Androgen	MBOAT2	10.472733	4.6e-44
+3	Androgen	KLK2	10.176186	1.94441e-40
+4	Androgen	SARG	11.386852	2.79021e-40
+5	EGFR	LZTFL1	-1.8738769	2.0809955e-18
+6	EGFR	PHLDA2	3.5051384	2.0530624e-17
+7	EGFR	DUSP6	12.6293125	6.537324e-17
+8	EGFR	DUSP5	7.9430394	6.86669e-17
+9	EGFR	PHLDA1	6.619626	3.4106933e-16
+10	Estrogen	GREB1	17.240173	0.0
+11	Estrogen	RET	10.718027	0.0
+12	Estrogen	TFF1	14.430255	0.0
+13	Estrogen	HEY2	11.482369	3.1e-44
+14	Estrogen	RAPGEFL1	10.544896	5.2e-43
+15	Hypoxia	FAM162A	8.335551	0.0
+16	Hypoxia	NDRG1	22.08712	0.0
+17	Hypoxia	ENO2	14.32694	0.0
+18	Hypoxia	PDK1	13.120449	0.0
+19	Hypoxia	ANKRD37	8.484976	0.0
+20	JAK-STAT	OAS1	15.028714	1.058e-41
+21	JAK-STAT	HERC6	8.769676	1.3450407e-38
+22	JAK-STAT	OAS3	10.618842	1.2143582e-37
+23	JAK-STAT	PLSCR1	8.481604	8.955206e-37
+24	JAK-STAT	DDX60	12.198234	9.150971e-36
+25	MAPK	DUSP6	16.859016	0.0
+26	MAPK	SPRED2	3.5018346	0.0
+27	MAPK	SPRY2	9.481585	9.19e-43
+28	MAPK	ETV5	5.9887094	6.7425e-41
+29	MAPK	EPHA2	6.3140125	3.7492e-40
+30	NFkB	NFKB1	9.513637	0.0
+31	NFkB	CXCL3	22.946114	0.0
+32	NFkB	NFKB2	5.5155754	0.0
+33	NFkB	NFKBIA	11.444533	0.0
+34	NFkB	BCL2A1	14.416924	0.0
+35	PI3K	MLANA	-9.985743	1.84e-43
+36	PI3K	PMEL	-6.5903482	6.8747866e-36
+37	PI3K	FAXDC2	-12.421274	3.297515e-34
+38	PI3K	HSD17B8	-8.601571	9.948224e-34
+39	PI3K	CTSF	-9.172143	1.0235212e-31
+40	TGFb	LINC00312	4.428987	2.0074443e-17
+41	TGFb	TSPAN2	5.502326	3.1451768e-16
+42	TGFb	SMAD7	7.6311436	7.3087106e-16
+43	TGFb	NOX4	5.913813	3.8292238e-15
+44	TGFb	COL4A1	6.3374896	9.052501e-15
+45	TNFa	CSF2	8.35548	0.0
+46	TNFa	CXCL5	10.0813675	0.0
+47	TNFa	NFKBIE	10.356205	0.0
+48	TNFa	TNFAIP3	35.40072	0.0
+49	TNFa	EFNA1	18.63111	0.0
+50	Trail	FRMPD1	-2.2346141	9.378505e-07
+51	Trail	WT1-AS	2.2251053	2.0316747e-06
+52	Trail	WNT8A	-1.8469616	3.795469e-05
+53	Trail	GPR18	3.240805	6.1090715e-05
+54	Trail	TEC	2.0513217	6.32898e-05
+55	VEGF	CRACD	-4.87119	6.7185365e-25
+56	VEGF	VWA8	-3.6068044	1.4495265e-18
+57	VEGF	NLGN1	-5.618075	2.6587072e-18
+58	VEGF	NRG3	-5.823747	1.0848074e-16
+59	VEGF	KCNK10	2.8833063	1.8129868e-16
+60	WNT	BMP4	5.936831	2.511717e-10
+61	WNT	SIGLEC6	2.0207362	2.347858e-09
+62	WNT	NPY2R	1.3872339	8.666917e-09
+63	WNT	CSF3R	1.9323153	3.0219417e-07
+64	WNT	KRT23	4.1216116	5.463989e-07
+65	p53	GLS2	6.452465	7.444302e-37
+66	p53	MDM2	8.193488	2.1194304e-35
+67	p53	ZNF79	4.020263	4.5987433e-34
+68	p53	FDXR	11.994496	5.589482e-32
+69	p53	LCE1B	11.813737	7.8095406e-30
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/progeny_test_2.tsv	Fri Mar 15 12:17:49 2024 +0000
@@ -0,0 +1,71 @@
+source	target	weight	p_value
+Androgen	TMPRSS2	11.490631	0.0
+Androgen	NKX3-1	10.622551	2.2e-44
+Androgen	MBOAT2	10.472733	4.6e-44
+Androgen	KLK2	10.176186	1.94441e-40
+Androgen	SARG	11.386852	2.79021e-40
+EGFR	LZTFL1	-1.8738769	2.0809955e-18
+EGFR	PHLDA2	3.5051384	2.0530624e-17
+EGFR	DUSP6	12.6293125	6.537324e-17
+EGFR	DUSP5	7.9430394	6.86669e-17
+EGFR	PHLDA1	6.619626	3.4106933e-16
+Estrogen	GREB1	17.240173	0.0
+Estrogen	RET	10.718027	0.0
+Estrogen	TFF1	14.430255	0.0
+Estrogen	HEY2	11.482369	3.1e-44
+Estrogen	RAPGEFL1	10.544896	5.2e-43
+Hypoxia	FAM162A	8.335551	0.0
+Hypoxia	NDRG1	22.08712	0.0
+Hypoxia	ENO2	14.32694	0.0
+Hypoxia	PDK1	13.120449	0.0
+Hypoxia	ANKRD37	8.484976	0.0
+JAK-STAT	OAS1	15.028714	1.058e-41
+JAK-STAT	HERC6	8.769676	1.3450407e-38
+JAK-STAT	OAS3	10.618842	1.2143582e-37
+JAK-STAT	PLSCR1	8.481604	8.955206e-37
+JAK-STAT	DDX60	12.198234	9.150971e-36
+MAPK	DUSP6	16.859016	0.0
+MAPK	SPRED2	3.5018346	0.0
+MAPK	SPRY2	9.481585	9.19e-43
+MAPK	ETV5	5.9887094	6.7425e-41
+MAPK	EPHA2	6.3140125	3.7492e-40
+NFkB	NFKB1	9.513637	0.0
+NFkB	CXCL3	22.946114	0.0
+NFkB	NFKB2	5.5155754	0.0
+NFkB	NFKBIA	11.444533	0.0
+NFkB	BCL2A1	14.416924	0.0
+PI3K	MLANA	-9.985743	1.84e-43
+PI3K	PMEL	-6.5903482	6.8747866e-36
+PI3K	FAXDC2	-12.421274	3.297515e-34
+PI3K	HSD17B8	-8.601571	9.948224e-34
+PI3K	CTSF	-9.172143	1.0235212e-31
+TGFb	LINC00312	4.428987	2.0074443e-17
+TGFb	TSPAN2	5.502326	3.1451768e-16
+TGFb	SMAD7	7.6311436	7.3087106e-16
+TGFb	NOX4	5.913813	3.8292238e-15
+TGFb	COL4A1	6.3374896	9.052501e-15
+TNFa	CSF2	8.35548	0.0
+TNFa	CXCL5	10.0813675	0.0
+TNFa	NFKBIE	10.356205	0.0
+TNFa	TNFAIP3	35.40072	0.0
+TNFa	EFNA1	18.63111	0.0
+Trail	FRMPD1	-2.2346141	9.378505e-07
+Trail	WT1-AS	2.2251053	2.0316747e-06
+Trail	WNT8A	-1.8469616	3.795469e-05
+Trail	GPR18	3.240805	6.1090715e-05
+Trail	TEC	2.0513217	6.32898e-05
+VEGF	CRACD	-4.87119	6.7185365e-25
+VEGF	VWA8	-3.6068044	1.4495265e-18
+VEGF	NLGN1	-5.618075	2.6587072e-18
+VEGF	NRG3	-5.823747	1.0848074e-16
+VEGF	KCNK10	2.8833063	1.8129868e-16
+WNT	BMP4	5.936831	2.511717e-10
+WNT	SIGLEC6	2.0207362	2.347858e-09
+WNT	NPY2R	1.3872339	8.666917e-09
+WNT	CSF3R	1.9323153	3.0219417e-07
+WNT	KRT23	4.1216116	5.463989e-07
+p53	GLS2	6.452465	7.444302e-37
+p53	MDM2	8.193488	2.1194304e-35
+p53	ZNF79	4.020263	4.5987433e-34
+p53	FDXR	11.994496	5.589482e-32
+p53	LCE1B	11.813737	7.8095406e-30