changeset 3:c6787c2aee46 draft default tip

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author ebi-gxa
date Mon, 15 Jul 2024 10:56:37 +0000
parents 82b7cd3e1bbd
children
files decoupler_aucell_score.py decoupler_pathway_inference.py decoupler_pseudobulk.py
diffstat 3 files changed, 354 insertions(+), 82 deletions(-)
--- a/decoupler_aucell_score.py	Tue Apr 16 11:49:19 2024 +0000
+++ b/decoupler_aucell_score.py	Mon Jul 15 10:56:37 2024 +0000
@@ -1,16 +1,15 @@
 import argparse
-import os
-import tempfile
 
 import anndata
 import decoupler as dc
+import numba as nb
 import pandas as pd
-import numba as nb
 
 
 def read_gmt_long(gmt_file):
-    """
-    Reads a GMT file and produce a Pandas DataFrame in long format, ready to be passed to the AUCell method.
+    r"""
+    Reads a GMT file and produces a Pandas DataFrame in long format, ready to
+    be passed to the AUCell method.
 
     Parameters
     ----------
@@ -20,9 +19,29 @@
     Returns
     -------
     pd.DataFrame
-        A DataFrame with the gene sets. Each row represents a gene set to gene assignment, and the columns are "gene_set_name" and "genes".
-    >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n"
-    >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n"
+        A DataFrame with the gene sets. Each row represents a gene set to
+        gene assignment, and the columns are "gene_set" and "gene".
+    >>> import os
+    >>> import tempfile
+    >>> line = "HALLMARK_NOTCH_SIGNALING\
+    ... \thttp://www.gsea-msigdb.org/\
+    ... gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\
+    ... \tJAG1\tNOTCH3\tNOTCH2\tAPH1A\tHES1\tCCND1\
+    ... \tFZD1\tPSEN2\tFZD7\tDTX1\tDLL1\tFZD5\tMAML2\
+    ... \tNOTCH1\tPSENEN\tWNT5A\tCUL1\tWNT2\tDTX4\
+    ... \tSAP30\tPPARD\tKAT2A\tHEYL\tSKP1\tRBX1\tTCF7L2\
+    ... \tARRB1\tLFNG\tPRKCA\tDTX2\tST3GAL6\tFBXW11\n"
+    >>> line2 = "HALLMARK_APICAL_SURFACE\
+    ... \thttp://www.gsea-msigdb.org/\
+    ... gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\
+    ... \tB4GALT1\tRHCG\tMAL\tLYPD3\tPKHD1\tATP6V0A4\
+    ... \tCRYBG1\tSHROOM2\tSRPX\tMDGA1\tTMEM8B\tTHY1\
+    ... \tPCSK9\tEPHB4\tDCBLD2\tGHRL\tLYN\tGAS1\tFLOT2\
+    ... \tPLAUR\tAKAP7\tATP8B1\tEFNA5\tSLC34A3\tAPP\
+    ... \tGSTM3\tHSPB1\tSLC2A4\tIL2RB\tRTN4RL1\tNCOA6\
+    ... \tSULF2\tADAM10\tBRCA1\tGATA3\tAFAP1L2\tIL2RG\
+    ... \tCD160\tADIPOR2\tSLC22A12\tNTNG1\tSCUBE1\tCX3CL1\
+    ... \tCROCC\n"
     >>> temp_dir = tempfile.gettempdir()
     >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt")
     >>> with open(temp_gmt, "w") as f:
@@ -36,7 +55,8 @@
     >>> len(df.loc[df["gene_set"] == "HALLMARK_APICAL_SURFACE"].gene.tolist())
     44
     """
-    # Create a list of dictionaries, where each dictionary represents a gene set
+    # Create a list of dictionaries, where each dictionary represents a
+    # gene set
     gene_sets = {}
 
     # Read the GMT file into a list of lines
@@ -46,12 +66,20 @@
             if not line:
                 break
             fields = line.strip().split("\t")
-            gene_sets[fields[0]]= fields[2:]
+            gene_sets[fields[0]] = fields[2:]
 
-    return pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items())
+    return pd.concat(
+        pd.DataFrame({"gene_set": k, "gene": v}) for k, v in gene_sets.items()
+    )
 
 
-def score_genes_aucell_mt(adata: anndata.AnnData, gene_set_gene: pd.DataFrame, use_raw=False, min_n_genes=5, var_gene_symbols_field=None):
+def score_genes_aucell_mt(
+    adata: anndata.AnnData,
+    gene_set_gene: pd.DataFrame,
+    use_raw=False,
+    min_n_genes=5,
+    var_gene_symbols_field=None,
+):
     """Score genes using Aucell.
 
     Parameters
@@ -60,17 +88,23 @@
     gene_set_gene: pd.DataFrame with columns gene_set and gene
     use_raw : bool, optional, False by default.
     min_n_genes : int, optional, 5 by default.
-    var_gene_symbols_field : str, optional, None by default. The field in var where gene symbols are stored
+    var_gene_symbols_field : str, optional, None by default. The field in
+    var where gene symbols are stored.
 
     >>> import scanpy as sc
     >>> import decoupler as dc
     >>> adata = sc.datasets.pbmc68k_reduced()
-    >>> r_gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist()
-    >>> m_gene_list = adata.var[adata.var.index.str.startswith("M")].index.tolist()
+    >>> r_gene_list = adata.var[
+    ...                  adata.var.index.str.startswith("RP")].index.tolist()
+    >>> m_gene_list = adata.var[
+    ...                  adata.var.index.str.startswith("M")].index.tolist()
     >>> gene_set = {}
     >>> gene_set["m"] = m_gene_list
     >>> gene_set["r"] = r_gene_list
-    >>> gene_set_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_set.items())
+    >>> gene_set_df = pd.concat(
+    ...                  pd.DataFrame(
+    ...                     {'gene_set':k, 'gene':v}
+    ...     ) for k, v in gene_set.items())
     >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False)
     >>> "AUCell_m" in adata.obs.columns
     True
@@ -78,47 +112,72 @@
     True
     """
 
-    # if var_gene_symbols_fiels is provided, transform gene_set_gene df so that gene contains gene ids instead of gene symbols
+    # if var_gene_symbols_field is provided, transform the gene_set_gene df
+    # so that gene contains gene ids instead of gene symbols
     if var_gene_symbols_field:
-        # merge the index of var to gene_set_gene df based on var_gene_symbols_field
+        # merge the index of var to gene_set_gene df based on
+        # var_gene_symbols_field
         var_id_symbols = adata.var[[var_gene_symbols_field]]
-        var_id_symbols['gene_id'] = var_id_symbols.index
+        var_id_symbols["gene_id"] = var_id_symbols.index
 
-        gene_set_gene = gene_set_gene.merge(var_id_symbols, left_on='gene', right_on=var_gene_symbols_field, how='left')
-        # this will still produce some empty gene_ids (genes in the gene_set_gene df that are not in the var df), fill those
-        # with the original gene symbol from the gene_set to avoid deforming the AUCell calculation
-        gene_set_gene['gene_id'] = gene_set_gene['gene_id'].fillna(gene_set_gene['gene'])
-        gene_set_gene['gene'] = gene_set_gene['gene_id']
-    
+        gene_set_gene = gene_set_gene.merge(
+            var_id_symbols,
+            left_on="gene",
+            right_on=var_gene_symbols_field,
+            how="left",
+        )
+        # this will still produce some empty gene_ids (genes in the
+        # gene_set_gene df that are not in the var df), fill those
+        # with the original gene symbol from the gene_set to avoid
+        # deforming the AUCell calculation
+        gene_set_gene["gene_id"] = gene_set_gene["gene_id"].fillna(
+            gene_set_gene["gene"]
+        )
+        gene_set_gene["gene"] = gene_set_gene["gene_id"]
+
     # run decoupler's run_aucell
     dc.run_aucell(
-            adata, net=gene_set_gene, source="gene_set", target="gene", use_raw=use_raw, min_n=min_n_genes
-        )
+        adata,
+        net=gene_set_gene,
+        source="gene_set",
+        target="gene",
+        use_raw=use_raw,
+        min_n=min_n_genes,
+    )
     for gs in gene_set_gene.gene_set.unique():
-        if gs in adata.obsm['aucell_estimate'].keys():
+        if gs in adata.obsm["aucell_estimate"].keys():
             adata.obs[f"AUCell_{gs}"] = adata.obsm["aucell_estimate"][gs]
 
 
 def run_for_genelists(
-    adata, gene_lists, score_names, use_raw=False, gene_symbols_field=None, min_n_genes=5
+    adata,
+    gene_lists,
+    score_names,
+    use_raw=False,
+    gene_symbols_field=None,
+    min_n_genes=5,
 ):
     if len(gene_lists) == len(score_names):
         for gene_list, score_names in zip(gene_lists, score_names):
             genes = gene_list.split(",")
             gene_sets = {}
             gene_sets[score_names] = genes
-            gene_set_gene_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items())
-            
+            gene_set_gene_df = pd.concat(
+                pd.DataFrame({"gene_set": k, "gene": v})
+                for k, v in gene_sets.items()
+            )
+
             score_genes_aucell_mt(
                 adata,
                 gene_set_gene_df,
                 use_raw,
                 min_n_genes,
-                var_gene_symbols_field=gene_symbols_field
+                var_gene_symbols_field=gene_symbols_field,
             )
     else:
         raise ValueError(
-            "The number of gene lists (separated by :) and score names (separated by :) must be the same"
+            "The number of gene lists (separated by :) and score names \
+                (separated by :) must be the same"
         )
 
 
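For reference, a minimal standalone sketch (with made-up gene names and score names) of how run_for_genelists turns a colon-separated gene-list value into the long-format net with columns "gene_set" and "gene" that score_genes_aucell_mt expects:

import pandas as pd

gene_lists = "CD3D,CD3E,CD2:LYZ,CST3".split(":")  # two comma-separated lists
score_names = ["tcell_score", "mono_score"]       # one score name per list

for gene_list, score_name in zip(gene_lists, score_names):
    genes = gene_list.split(",")
    # build the long-format net for this list, one row per gene
    net = pd.concat(
        pd.DataFrame({"gene_set": k, "gene": v})
        for k, v in {score_name: genes}.items()
    )
    print(net)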
@@ -126,32 +185,41 @@
     # Create command-line arguments parser
     parser = argparse.ArgumentParser(description="Score genes using Aucell")
     parser.add_argument(
-        "--input_file", type=str, help="Path to input AnnData file", required=True
+        "--input_file",
+        type=str,
+        help="Path to input AnnData file",
+        required=True,
     )
     parser.add_argument(
         "--output_file", type=str, help="Path to output file", required=True
     )
-    parser.add_argument("--gmt_file", type=str, help="Path to GMT file", required=False)
+    parser.add_argument(
+        "--gmt_file", type=str, help="Path to GMT file", required=False
+    )
     # add argument for gene sets to score
     parser.add_argument(
         "--gene_sets_to_score",
         type=str,
         required=False,
-        help="Optional comma separated list of gene sets to score (the need to be in the gmt file)",
+        help="Optional comma separated list of gene sets to score \
+            (the need to be in the gmt file)",
     )
     # add argument for gene list (comma separated) to score
     parser.add_argument(
         "--gene_lists_to_score",
         type=str,
         required=False,
-        help="Comma separated list of genes to score. You can have more than one set of genes, separated by colon :",
+        help="Comma separated list of genes to score. You can have more \
+            than one set of genes, separated by colon :",
     )
     # argument for the score name when using the gene list
     parser.add_argument(
         "--score_names",
         type=str,
         required=False,
-        help="Name of the score column when using the gene list. You can have more than one set of score names, separated by colon :. It should be the same length as the number of gene lists.",
+        help="Name of the score column when using the gene list. You can \
+            have more than one set of score names, separated by colon :. \
+                It should be the same length as the number of gene lists.",
     )
     parser.add_argument(
         "--gene_symbols_field",
@@ -159,7 +227,8 @@
         help="Name of the gene symbols field in the AnnData object",
         required=True,
     )
-    # argument for min_n Minimum of targets per source. If less, sources are removed.
+    # argument for min_n: minimum number of targets per source. If less,
+    # sources are removed.
     parser.add_argument(
         "--min_n",
         type=int,
@@ -169,11 +238,18 @@
     )
     parser.add_argument("--use_raw", action="store_true", help="Use raw data")
     parser.add_argument(
-        "--write_anndata", action="store_true", help="Write the modified AnnData object"
+        "--write_anndata",
+        action="store_true",
+        help="Write the modified AnnData object",
     )
     # argument for number of max concurrent processes
-    parser.add_argument("--max_threads", type=int, required=False, default=1, help="Number of max concurrent threads")
-
+    parser.add_argument(
+        "--max_threads",
+        type=int,
+        required=False,
+        default=1,
+        help="Number of max concurrent threads",
+    )
 
     # Parse command-line arguments
     args = parser.parse_args()
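For reference, a small sketch of the gene-set filtering applied below when --gene_sets_to_score is given: the long-format GMT table is simply restricted with pandas isin. Set and gene names here are invented for illustration:

import pandas as pd

msigdb = pd.DataFrame(
    {"gene_set": ["HALLMARK_A", "HALLMARK_A", "HALLMARK_B"],
     "gene": ["G1", "G2", "G3"]}
)
gene_sets_to_score = ["HALLMARK_A"]
# keep only the requested gene sets
msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)]
print(msigdb)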
@@ -189,23 +265,40 @@
         msigdb = read_gmt_long(args.gmt_file)
 
         gene_sets_to_score = (
-            args.gene_sets_to_score.split(",") if args.gene_sets_to_score else []
+            args.gene_sets_to_score.split(",")
+            if args.gene_sets_to_score
+            else []
         )
         if gene_sets_to_score:
-            # we limit the GMT file read to the genesets specified in the gene_sets_to_score argument
+            # we limit the GMT file read to the genesets specified in the
+            # gene_sets_to_score argument
             msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)]
-        
-        score_genes_aucell_mt(adata, msigdb, args.use_raw, args.min_n, var_gene_symbols_field=args.gene_symbols_field)
+
+        score_genes_aucell_mt(
+            adata,
+            msigdb,
+            args.use_raw,
+            args.min_n,
+            var_gene_symbols_field=args.gene_symbols_field,
+        )
     elif args.gene_lists_to_score is not None and args.score_names is not None:
         gene_lists = args.gene_lists_to_score.split(":")
         score_names = args.score_names.split(",")
         run_for_genelists(
-            adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field, args.min_n
+            adata,
+            gene_lists,
+            score_names,
+            args.use_raw,
+            args.gene_symbols_field,
+            args.min_n,
         )
 
-    # Save the modified AnnData object or generate a file with cells as rows and the new score_names columns
+    # Save the modified AnnData object or generate a file with cells as rows
+    # and the new score_names columns
     if args.write_anndata:
         adata.write_h5ad(args.output_file)
     else:
-        new_columns = [col for col in adata.obs.columns if col.startswith("AUCell_")]
+        new_columns = [
+            col for col in adata.obs.columns if col.startswith("AUCell_")
+        ]
         adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True)
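For reference, a standalone sketch of the gene-symbol remapping performed in score_genes_aucell_mt when a gene symbols field is given: symbols in the net are translated to the gene ids used as the var index, and unmatched symbols fall back to the original symbol so the AUCell calculation is not deformed. All identifiers below are toy values:

import pandas as pd

var = pd.DataFrame(
    {"gene_symbols": ["CD3D", "LYZ"]},
    index=["ENSG00000167286", "ENSG00000090382"],
)
net = pd.DataFrame(
    {"gene_set": ["t", "t", "m"], "gene": ["CD3D", "CD2", "LYZ"]}
)

var_id_symbols = var[["gene_symbols"]].copy()
var_id_symbols["gene_id"] = var_id_symbols.index

net = net.merge(
    var_id_symbols, left_on="gene", right_on="gene_symbols", how="left"
)
net["gene_id"] = net["gene_id"].fillna(net["gene"])  # keep unmatched symbols
net["gene"] = net["gene_id"]
print(net[["gene_set", "gene"]])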
--- a/decoupler_pathway_inference.py	Tue Apr 16 11:49:19 2024 +0000
+++ b/decoupler_pathway_inference.py	Mon Jul 15 10:56:37 2024 +0000
@@ -20,24 +20,34 @@
 
 # output file prefix
 parser.add_argument(
-    "-o", "--output",
+    "-o",
+    "--output",
     help="output files prefix",
     default=None,
 )
 
 # path to save Activities AnnData file
 parser.add_argument(
-    "-a", "--activities_path", help="Path to save Activities AnnData file", default=None
+    "-a",
+    "--activities_path",
+    help="Path to save Activities AnnData file",
+    default=None,
 )
 
 # Column name in net with source nodes
 parser.add_argument(
-    "-s", "--source", help="Column name in net with source nodes.", default="source"
+    "-s",
+    "--source",
+    help="Column name in net with source nodes.",
+    default="source",
 )
 
 # Column name in net with target nodes
 parser.add_argument(
-    "-t", "--target", help="Column name in net with target nodes.", default="target"
+    "-t",
+    "--target",
+    help="Column name in net with target nodes.",
+    default="target",
 )
 
 # Column name in net with weights.
@@ -47,17 +57,27 @@
 
 # add boolean argument for use_raw
 parser.add_argument(
-    "--use_raw", action="store_true", default=False, help="Whether to use the raw part of the AnnData object"
+    "--use_raw",
+    action="store_true",
+    default=False,
+    help="Whether to use the raw part of the AnnData object",
 )
 
 # add argument for min_cells
 parser.add_argument(
-    "--min_n", help="Minimum of targets per source. If less, sources are removed.", default=5, type=int
+    "--min_n",
+    help="Minimum of targets per source. If less, sources are removed.",
+    default=5,
+    type=int,
 )
 
 # add activity inference method option
 parser.add_argument(
-    "-m", "--method", help="Activity inference method", default="mlm", required=True
+    "-m",
+    "--method",
+    help="Activity inference method",
+    default="mlm",
+    required=True,
 )
 args = parser.parse_args()
 
@@ -69,7 +89,7 @@
 adata = ad.read_h5ad(args.input_anndata)
 
 # read in the network input file
-network = pd.read_csv(args.input_network, sep='\t')
+network = pd.read_csv(args.input_network, sep="\t")
 
 if (
     args.source not in network.columns
@@ -92,17 +112,21 @@
         weight=args.weight,
         verbose=True,
         min_n=args.min_n,
-        use_raw=args.use_raw 
+        use_raw=args.use_raw,
     )
 
     if args.output is not None:
-        # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the output network files
-        combined_df = pd.concat([adata.obsm["mlm_estimate"], adata.obsm["mlm_pvals"]], axis=1)
+        # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the
+        # output network files
+        combined_df = pd.concat(
+            [adata.obsm["mlm_estimate"], adata.obsm["mlm_pvals"]], axis=1
+        )
 
         # Save the combined dataframe to a file
         combined_df.to_csv(args.output + ".tsv", sep="\t")
 
-    # if args.activities_path is specified, generate the activities AnnData and save the AnnData object to the specified path
+    # if args.activities_path is specified, generate the activities AnnData
+    # and save the AnnData object to the specified path
     if args.activities_path is not None:
         acts = dc.get_acts(adata, obsm_key="mlm_estimate")
         acts.write_h5ad(args.activities_path)
@@ -116,17 +140,21 @@
         weight=args.weight,
         verbose=True,
         min_n=args.min_n,
-        use_raw=args.use_raw 
+        use_raw=args.use_raw,
     )
 
     if args.output is not None:
-        # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the output network files
-        combined_df = pd.concat([adata.obsm["ulm_estimate"], adata.obsm["ulm_pvals"]], axis=1)
+        # write adata.obsm[ulm_key] and adata.obsm[ulm_pvals_key] to the
+        # output network files
+        combined_df = pd.concat(
+            [adata.obsm["ulm_estimate"], adata.obsm["ulm_pvals"]], axis=1
+        )
 
         # Save the combined dataframe to a file
         combined_df.to_csv(args.output + ".tsv", sep="\t")
 
-    # if args.activities_path is specified, generate the activities AnnData and save the AnnData object to the specified path
+    # if args.activities_path is specified, generate the activities AnnData
+    # and save the AnnData object to the specified path
     if args.activities_path is not None:
         acts = dc.get_acts(adata, obsm_key="ulm_estimate")
         acts.write_h5ad(args.activities_path)
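For reference, a self-contained sketch of what the ulm branch produces, assuming the classic decoupler API used by these scripts (dc.run_ulm / dc.get_acts): estimates and p-values land in adata.obsm, and get_acts repackages the estimates as the activities AnnData saved via -a/--activities_path. The matrix and network are toy values:

import anndata as ad
import decoupler as dc
import numpy as np
import pandas as pd

# toy expression matrix: 10 cells x 3 genes
adata = ad.AnnData(
    np.abs(np.random.randn(10, 3)).astype(np.float32),
    var=pd.DataFrame(index=["GeneA", "GeneB", "GeneC"]),
)
net = pd.DataFrame(
    {
        "source": ["TF1", "TF1", "TF1"],
        "target": ["GeneA", "GeneB", "GeneC"],
        "weight": [1.0, -0.5, 2.0],
    }
)
dc.run_ulm(
    adata, net=net, source="source", target="target",
    weight="weight", min_n=2, use_raw=False, verbose=False,
)
# estimates and p-values side by side, as written to the output TSV
combined_df = pd.concat(
    [adata.obsm["ulm_estimate"], adata.obsm["ulm_pvals"]], axis=1
)
acts = dc.get_acts(adata, obsm_key="ulm_estimate")
print(acts.to_df().head())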
--- a/decoupler_pseudobulk.py	Tue Apr 16 11:49:19 2024 +0000
+++ b/decoupler_pseudobulk.py	Mon Jul 15 10:56:37 2024 +0000
@@ -40,8 +40,108 @@
     return index_value
 
 
+def genes_to_ignore_per_contrast_field(
+    count_matrix_df,
+    samples_metadata,
+    sample_metadata_col_contrasts,
+    min_counts_per_sample=5,
+    use_cpms=False,
+):
+    """
+    Calculates, per contrast field value (e.g., bulk_labels, louvain),
+    the genes to ignore: for each group it takes the count matrix of that
+    group's samples and flags genes whose total count falls below a
+    threshold derived from the group size and min_counts_per_sample.
+
+    >>> import pandas as pd
+    >>> samples_metadata = pd.DataFrame({'sample':
+    ...                                    ['S1', 'S2', 'S3',
+    ...                                     'S4', 'S5', 'S6'],
+    ...                                  'contrast_field':
+    ...                                    ['A', 'A', 'A', 'B', 'B', 'B']})
+    >>> count_matrix_df = pd.DataFrame(
+    ...                       {'S1':
+    ...                          [30, 1, 40, 50, 30],
+    ...                        'S2':
+    ...                          [40, 2, 60, 50, 80],
+    ...                        'S3':
+    ...                          [80, 1, 60, 50, 50],
+    ...                        'S4': [1, 50, 50, 50, 2],
+    ...                        'S5': [3, 40, 40, 40, 2],
+    ...                        'S6': [0, 50, 50, 50, 1]})
+    >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5']
+    >>> df = genes_to_ignore_per_contrast_field(count_matrix_df,
+    ...             samples_metadata, min_counts_per_sample=5,
+    ...             sample_metadata_col_contrasts='contrast_field')
+    >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0]
+    'Gene2'
+    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0]
+    'Gene1'
+    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1]
+    'Gene5'
+    """
+
+    # Initialize a dictionary to store the genes to ignore per contrast field
+    contrast_fields = []
+    genes_to_ignore = []
+
+    # Iterate over the contrast fields
+    for contrast_field in samples_metadata[
+        sample_metadata_col_contrasts
+    ].unique():
+        # Get the count matrix for the current contrast field
+        count_matrix_field = count_matrix_df.loc[
+            :,
+            (
+                samples_metadata[sample_metadata_col_contrasts]
+                == contrast_field
+            ).tolist(),
+        ]
+
+        # We derive min_counts from the number of samples with that
+        # contrast_field value
+        min_counts = count_matrix_field.shape[1] * min_counts_per_sample
+
+        if use_cpms:
+            # Convert counts to counts per million (CPM), normalising
+            # each sample (column) by its total counts
+            count_matrix_field = (
+                count_matrix_field.div(count_matrix_field.sum(axis=0), axis=1)
+                * 1e6
+            )
+            min_counts = 1  # use 1 CPM
+
+        # Calculate the total counts in the current contrast field
+        # (this produces a vector of counts per gene)
+        total_counts_per_gene = count_matrix_field.sum(axis=1)
+
+        # Identify genes with a count below the specified threshold
+        genes = total_counts_per_gene[
+            total_counts_per_gene < min_counts
+        ].index.tolist()
+        if len(genes) > 0:
+            # genes_to_ignore[contrast_field] = " ".join(genes)
+            for gene in genes:
+                genes_to_ignore.append(gene)
+                contrast_fields.append(contrast_field)
+    # transform gene_to_ignore to a DataFrame
+    # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(),
+    #                           columns=["contrast_field", "genes_to_ignore"])
+    genes_to_ignore_df = pd.DataFrame(
+        {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore}
+    )
+    return genes_to_ignore_df
+
+
 # write results for loading into DESeq2
-def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
+def write_DESeq2_inputs(
+    pdata,
+    layer=None,
+    output_dir="",
+    factor_fields=None,
+    min_counts_per_sample_marking=20,
+):
     """
     >>> import scanpy as sc
     >>> adata = sc.datasets.pbmc68k_reduced()
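Worked example of the marking rule introduced above: min_counts is derived from the group size, so with three samples in a group and min_counts_per_sample=5 a gene needs at least 15 counts summed across that group's samples to escape marking. Gene2 in group A of the doctest sums to 1 + 2 + 1 = 4, hence it is marked:

n_samples_in_group = 3
min_counts_per_sample = 5
min_counts = n_samples_in_group * min_counts_per_sample  # 15
total_counts_for_gene = 1 + 2 + 1  # Gene2 across S1-S3
print(total_counts_for_gene < min_counts)  # True -> marked to ignore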
@@ -62,7 +162,9 @@
     # write obs to a col_metadata file
     if factor_fields:
         # only output the index plus the columns in factor_fields in that order
-        obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True)
+        obs_for_deseq[factor_fields].to_csv(
+            col_metadata_file, sep="\t", index=True
+        )
     else:
         obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True)
     # write var to a gene_metadata file
@@ -70,13 +172,28 @@
     # write the counts matrix of a specified layer to file
     if layer is None:
         # write the X numpy matrix transposed to file
-        df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index)
+        df = pd.DataFrame(
+            pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index
+        )
     else:
         df = pd.DataFrame(
-            pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index
+            pdata.layers[layer].T,
+            index=pdata.var.index,
+            columns=obs_for_deseq.index,
         )
     df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="")
 
+    if factor_fields:
+        df_genes_ignore = genes_to_ignore_per_contrast_field(
+            count_matrix_df=df,
+            samples_metadata=obs_for_deseq,
+            sample_metadata_col_contrasts=factor_fields[0],
+            min_counts_per_sample=min_counts_per_sample_marking,
+        )
+        df_genes_ignore.to_csv(
+            f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t"
+        )
+
 
 def plot_pseudobulk_samples(
     pseudobulk_data,
@@ -89,7 +206,9 @@
     >>> adata = sc.datasets.pbmc68k_reduced()
     >>> adata.X = abs(adata.X).astype(int)
     >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
-    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
+    >>> plot_pseudobulk_samples(pseudobulk,
+    ...                         groupby=["bulk_labels", "louvain"],
+    ...                         figsize=(10, 10))
     """
     fig = decoupler.plot_psbulk_samples(
         pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True
@@ -101,14 +220,19 @@
 
 
 def plot_filter_by_expr(
-    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
+    pseudobulk_data,
+    group,
+    min_count=None,
+    min_total_count=None,
+    save_path=None,
 ):
     """
     >>> import scanpy as sc
     >>> adata = sc.datasets.pbmc68k_reduced()
     >>> adata.X = abs(adata.X).astype(int)
     >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
-    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
+    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels",
+    ...                     min_count=10, min_total_count=200)
     """
     fig = decoupler.plot_filter_by_expr(
         pseudobulk_data,
@@ -129,7 +253,8 @@
     >>> adata = sc.datasets.pbmc68k_reduced()
     >>> adata.X = abs(adata.X).astype(int)
     >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
-    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
+    >>> pdata_filt = filter_by_expr(pseudobulk,
+    ...                             min_count=10, min_total_count=200)
     """
     genes = decoupler.filter_by_expr(
         pdata, min_count=min_count, min_total_count=min_total_count
@@ -150,12 +275,16 @@
     if obs:
         if not set(fields).issubset(set(adata.obs.columns)):
             raise ValueError(
-                f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}"
+                f"Some of the following fields {legend} are not present \
+                    in adata.obs: {fields}. \
+                        Possible fields are: {list(set(adata.obs.columns))}"
             )
     else:
         if not set(fields).issubset(set(adata.var.columns)):
             raise ValueError(
-                f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}"
+                f"Some of the following fields {legend} are not present \
+                    in adata.var: {fields}. \
+                        Possible fields are: {list(set(adata.var.columns))}"
             )
 
 
@@ -219,10 +348,15 @@
 
     # Save the pseudobulk data
     if args.anndata_output_path:
-        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")
+        pseudobulk_data.write_h5ad(
+            args.anndata_output_path, compression="gzip"
+        )
 
     write_DESeq2_inputs(
-        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
+        pseudobulk_data,
+        output_dir=args.deseq2_output_path,
+        factor_fields=factor_fields,
+        min_counts_per_sample_marking=args.min_counts_per_sample_marking,
     )
 
 
@@ -254,7 +388,9 @@
     field_name = "_".join(obs_fields_to_merge)
     for field in obs_fields_to_merge:
         if field not in adata.obs.columns:
-            raise ValueError(f"The '{field}' column is not present in adata.obs.")
+            raise ValueError(
+                f"The '{field}' column is not present in adata.obs."
+            )
         if field_name not in adata.obs.columns:
             adata.obs[field_name] = adata.obs[field].astype(str)
         else:
@@ -271,12 +407,16 @@
     )
 
     # Add arguments
-    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
+    parser.add_argument(
+        "adata_file", type=str, help="Path to the AnnData file"
+    )
     parser.add_argument(
         "-m",
         "--adata_obs_fields_to_merge",
         type=str,
-        help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;",
+        help="Fields in adata.obs to merge, comma separated. \
+            You can have more than one set of fields, \
+                separated by semi-colon ;",
     )
     parser.add_argument(
         "--groupby",
@@ -328,6 +468,13 @@
         help="Minimum count threshold for filtering by expression",
     )
     parser.add_argument(
+        "--min_counts_per_sample_marking",
+        type=int,
+        default=20,
+        help="Minimum count threshold per sample for \
+            marking genes to be ignored after DE",
+    )
+    parser.add_argument(
         "--min_total_counts",
         type=int,
         help="Minimum total count threshold for filtering by expression",
@@ -338,7 +485,9 @@
         help="Path to save the filtered AnnData object or pseudobulk data",
     )
     parser.add_argument(
-        "--filter_expr", action="store_true", help="Enable filtering by expression"
+        "--filter_expr",
+        action="store_true",
+        help="Enable filtering by expression",
     )
     parser.add_argument(
         "--factor_fields",
@@ -358,7 +507,9 @@
         nargs=2,
         help="Size of the samples plot as a tuple (two arguments)",
     )
-    parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)
+    parser.add_argument(
+        "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2
+    )
 
     # Parse the command line arguments
     args = parser.parse_args()
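For reference, a minimal end-to-end sketch of the pseudobulk-and-filter flow this script wraps, calling decoupler's public API directly as the doctests above do; min_cells/min_counts are relaxed here only because the pbmc68k_reduced pseudo counts are small, and the script's own get_pseudobulk helper may use different defaults:

import decoupler
import scanpy as sc

adata = sc.datasets.pbmc68k_reduced()
adata.X = abs(adata.X).astype(int)  # pseudo counts, as in the doctests
pdata = decoupler.get_pseudobulk(
    adata,
    sample_col="bulk_labels",
    groups_col="louvain",
    min_cells=0,
    min_counts=0,
)
# keep only genes passing the expression filter
genes = decoupler.filter_by_expr(
    pdata, group="bulk_labels", min_count=10, min_total_count=200
)
pdata = pdata[:, genes].copy()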