diff celltypist_CLI.py @ 0:a7d6985ba791 draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 53651c0be9daeebb0921f3b5e542323304dfdc98-dirty
author ebi-gxa
date Fri, 14 Feb 2025 11:56:48 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/celltypist_CLI.py	Fri Feb 14 11:56:48 2025 +0000
@@ -0,0 +1,633 @@
+import argparse
+import os
+import pickle
+import sys
+import textwrap
+
+import celltypist
+import scanpy as sc
+
+
+def train(
+    input_path,
+    labels,
+    genes,
+    model_name,
+    output_path,
+    transpose_input,
+    normalize,
+    raw_counts_layer,
+    gene_symbols_field,
+    epochs,
+    max_iter,
+    C,
+    n_jobs,
+    use_SGD,
+    alpha,
+    use_GPU,
+    mini_batch,
+    batch_number,
+    batch_size,
+    balance_cell_type,
+    feature_selection,
+    top_genes,
+):
+    """
+    Train a CellTypist model and pickle it to
+    ``<output_path>/<model_name>.pkl``.
+
+    ``input_path`` must end in ``.h5ad`` (read with scanpy) or ``.mtx`` (the
+    path itself is handed to CellTypist together with ``genes``).
+    ``raw_counts_layer``, ``gene_symbols_field`` and ``normalize`` operate on
+    the loaded AnnData object, so they only apply to ``.h5ad`` input.
+    ``labels``, ``genes``, ``transpose_input`` and every parameter from
+    ``epochs`` onwards are forwarded unchanged to ``celltypist.train``;
+    please refer to the CLI help for their meaning.
+
+    Returns:
+        None. The trained model is written to disk as a side effect.
+
+    Raises:
+        ValueError: if the input extension is unsupported, if ``.mtx`` input
+            comes without ``genes``, if ``raw_counts_layer`` is ``"X"``, is
+            not a layer of the AnnData object or is combined with ``.mtx``
+            input, or if ``normalize`` is combined with ``.mtx`` input.
+    """
+    is_h5ad = input_path.endswith(".h5ad")
+    is_mtx = input_path.endswith(".mtx")
+
+    if is_h5ad:
+        print("Detected an AnnData object as input.")
+        data = sc.read_h5ad(input_path)
+    elif is_mtx:
+        print("Detected .mtx file as input.")
+        data = input_path
+    else:
+        raise ValueError("Invalid input file type detected.")
+
+    if is_mtx and genes is None:
+        raise ValueError("Missing a gene file for the provided data.")
+
+    if raw_counts_layer:
+        if is_mtx:
+            raise ValueError(
+                f"Raw counts layer {raw_counts_layer} provided but "
+                "the data format provided is .mtx. Please provide an "
+                "AnnData object."
+            )
+        if raw_counts_layer == "X" or raw_counts_layer not in data.layers:
+            raise ValueError(
+                f"Raw counts layer {raw_counts_layer} "
+                "should be both different to 'X' and an "
+                "existing layer in AnnData"
+            )
+        print(f"Using raw counts layer: {raw_counts_layer}")
+        data.X = data.layers[raw_counts_layer]
+
+    # A gene_symbols_field that is not a column of data.var (or that is given
+    # with .mtx input) is silently ignored, as before.
+    if gene_symbols_field and is_h5ad and gene_symbols_field in data.var.columns:
+        # Keep the previous identifiers around before re-indexing by symbol.
+        data.var["old_index"] = data.var.index
+        data.var.set_index(gene_symbols_field, inplace=True)
+
+    if normalize:
+        if is_mtx:
+            # `data` is only a file path here; scanpy cannot normalise it.
+            raise ValueError(
+                "Normalization requested but the data format provided is "
+                ".mtx. Please provide an AnnData object."
+            )
+        sc.pp.normalize_total(data, target_sum=1e4)
+        sc.pp.log1p(data)
+
+    model = celltypist.train(
+        data,
+        labels=labels,
+        genes=genes,
+        transpose_input=transpose_input,
+        C=C,
+        epochs=epochs,
+        max_iter=max_iter,
+        n_jobs=n_jobs,
+        use_SGD=use_SGD,
+        alpha=alpha,
+        use_GPU=use_GPU,
+        mini_batch=mini_batch,
+        batch_number=batch_number,
+        batch_size=batch_size,
+        balance_cell_type=balance_cell_type,
+        feature_selection=feature_selection,
+        top_genes=top_genes,
+    )
+
+    model_save_location = os.path.join(output_path, model_name + ".pkl")
+
+    # NOTE(review): the model object itself is pickled here; confirm that
+    # whatever later loads this .pkl expects that layout rather than the one
+    # produced by the model's own writer.
+    with open(model_save_location, "wb") as handle:
+        pickle.dump(model, handle)
+
+
+def predict(
+    input_path,
+    model,
+    output_path,
+    gene_file,
+    cell_file,
+    raw_counts_layer,
+    gene_symbols_field,
+    normalize,
+    transpose_input,
+    mode,
+    p_thres,
+    majority_voting,
+    over_clustering,
+    use_GPU,
+):
+    """
+    Annotate a dataset with an existing CellTypist model and write the
+    predicted labels to ``output_path`` as a tab-separated table.
+
+    ``input_path`` must end in ``.h5ad`` (read with scanpy) or ``.mtx`` (the
+    path itself is handed to CellTypist together with ``gene_file`` and
+    ``cell_file``). ``raw_counts_layer``, ``gene_symbols_field`` and
+    ``normalize`` operate on the loaded AnnData object, so they only apply
+    to ``.h5ad`` input. Underscores in ``mode`` are replaced by spaces
+    before it is passed on. Please refer to the CLI help for the meaning of
+    the remaining parameters, which are forwarded to ``celltypist.annotate``.
+
+    Returns:
+        Whatever ``predicted_labels.to_csv`` returns for a file target.
+
+    Raises:
+        ValueError: if the input extension is unsupported, if ``.mtx`` input
+            comes without ``gene_file`` or ``cell_file``, if
+            ``raw_counts_layer`` is ``"X"``, is not a layer of the AnnData
+            object or is combined with ``.mtx`` input, or if ``normalize``
+            is combined with ``.mtx`` input.
+    """
+    is_h5ad = input_path.endswith(".h5ad")
+    is_mtx = input_path.endswith(".mtx")
+
+    if is_h5ad:
+        print("Detected an AnnData object as input.")
+        data = sc.read_h5ad(input_path)
+    elif is_mtx:
+        print("Detected .mtx file as input.")
+        data = input_path
+    else:
+        raise ValueError("Invalid input file type detected.")
+
+    if is_mtx and gene_file is None:
+        raise ValueError("Missing a gene file for the provided .mtx data.")
+    if is_mtx and cell_file is None:
+        raise ValueError("Missing a cell file for the provided .mtx data.")
+
+    if raw_counts_layer:
+        if is_mtx:
+            raise ValueError(
+                f"Raw counts layer {raw_counts_layer} provided but "
+                "the data format is .mtx. Please provide an AnnData "
+                "object."
+            )
+        if raw_counts_layer == "X" or raw_counts_layer not in data.layers:
+            raise ValueError(
+                f"Raw counts layer {raw_counts_layer} should be "
+                "both different to 'X' and an existing layer in "
+                "AnnData"
+            )
+        print(f"Using raw counts layer: {raw_counts_layer}")
+        data.X = data.layers[raw_counts_layer]
+
+    # A gene_symbols_field that is not a column of data.var (or that is given
+    # with .mtx input) is silently ignored, as before.
+    if gene_symbols_field and is_h5ad and gene_symbols_field in data.var.columns:
+        # Keep the previous identifiers around before re-indexing by symbol.
+        data.var["old_index"] = data.var.index
+        data.var.set_index(gene_symbols_field, inplace=True)
+
+    if normalize:
+        if is_mtx:
+            # `data` is only a file path here; scanpy cannot normalise it.
+            raise ValueError(
+                "Normalization requested but the data format is .mtx. "
+                "Please provide an AnnData object."
+            )
+        sc.pp.normalize_total(data, target_sum=1e4)
+        sc.pp.log1p(data)
+
+    if mode:
+        # The CLI spells the modes with underscores ("best_match").
+        mode = mode.replace("_", " ")
+
+    if majority_voting and not over_clustering:
+        # Normalise an empty value to None when majority voting is enabled.
+        over_clustering = None
+
+    predictions = celltypist.annotate(
+        data,
+        model,
+        gene_file=gene_file,
+        cell_file=cell_file,
+        mode=mode,
+        p_thres=p_thres,
+        transpose_input=transpose_input,
+        majority_voting=majority_voting,
+        over_clustering=over_clustering,
+        use_GPU=use_GPU,
+    )
+
+    return predictions.predicted_labels.to_csv(output_path, sep="\t")
+
+
+def _str_to_bool(value):
+    """Convert a command-line token such as "true", "False" or "1" to a bool.
+
+    ``type=bool`` cannot be used with argparse for this purpose because
+    ``bool("False")`` is ``True``. The empty string maps to ``False``.
+
+    Raises:
+        argparse.ArgumentTypeError: if the token is not a recognised boolean.
+    """
+    if isinstance(value, bool):
+        return value
+    token = value.strip().lower()
+    if token in ("true", "t", "yes", "y", "1"):
+        return True
+    if token in ("false", "f", "no", "n", "0", ""):
+        return False
+    raise argparse.ArgumentTypeError(
+        f"Expected a boolean value (true/false), got {value!r}."
+    )
+
+
+def _add_predict_arguments(parser):
+    """Register the options of the ``predict`` action on ``parser``."""
+    parser.add_argument(
+        "--input_path",
+        type=str,
+        help=(
+            "Path to the input file. "
+            "Can be an AnnData object or a .mtx file."
+        ),
+        required=True,
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        help="Path to output file.",
+        required=True,
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Path to model file, stored in .pkl format.",
+        required=True,
+    )
+    parser.add_argument(
+        "--gene_file",
+        type=str,
+        help=(
+            "Path to the file which stores each gene "
+            "per line corresponding to the genes used "
+            "in the provided .mtx file."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--cell_file",
+        type=str,
+        help=(
+            "Path to the file which stores each cell "
+            "per line corresponding to the cells used "
+            "in the provided .mtx file."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--raw_counts_layer",
+        type=str,
+        help=(
+            "The name of the layer that stores the raw "
+            "counts. Uses default matrix if not present."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--gene_symbols_field",
+        type=str,
+        help=(
+            "The field in AnnData where the gene symbols "
+            "are stored, if not in index."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help=(
+            "If raw counts are provided in the AnnData "
+            "object, they need to be normalized."
+        ),
+        required=False,
+    )
+    parser.add_argument(
+        "--transpose_input",
+        action="store_true",
+        help=(
+            "If the provided matrix is in the "
+            "gene-by-cell format, please transpose the "
+            "input to cell-by-gene format"
+        ),
+        required=False,
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["best_match", "prob_match"],
+        help=(
+            "Choose the cell type with the largest score/probability "
+            "as the final prediction (`best_match`), or enable a "
+            "multi-label classification (`prob_match`), which assigns "
+            "0 (i.e., unassigned), 1, or >=2 cell type labels to each "
+            "query cell. [default: best_match]"
+        ),
+        required=False,
+        default="best_match",
+    )
+    parser.add_argument(
+        "--p_thres",
+        type=float,
+        help=(
+            "Probability threshold for assigning a cell "
+            "type in a multiclass problem, defaults "
+            "to 0.5."
+        ),
+        default=0.5,
+        required=False,
+    )
+    parser.add_argument(
+        "--majority_voting",
+        action="store_true",
+        help=(
+            "Refine the predicted labels by running the "
+            "majority voting classifier after "
+            "over-clustering."
+        ),
+        required=False,
+    )
+    parser.add_argument(
+        "--over_clustering",
+        type=str,
+        help=(
+            "If majority voting is set to True, specify "
+            "the type of over clustering that is to be "
+            "performed. This can be specified in the "
+            "AnnData or an input file specifying the "
+            "over-clustering per cell. If not present, "
+            "then the default heuristic over-clustering "
+            "based on input data will be used."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--use_GPU",
+        action="store_true",
+        help=(
+            "Whether to use GPU for over clustering "
+            "on the basis of `rapids-singlecell`."
+        ),
+        required=False,
+    )
+
+
+def _add_train_arguments(parser):
+    """Register the options of the ``train`` action on ``parser``."""
+
+    def add_bool(flag, help_text):
+        # Accepts "--flag true|false" as well as a bare "--flag" (-> True).
+        parser.add_argument(
+            flag,
+            type=_str_to_bool,
+            nargs="?",
+            const=True,
+            help=help_text,
+            required=False,
+            default=False,
+        )
+
+    parser.add_argument(
+        "--input_path",
+        type=str,
+        help=(
+            "Path to the input file. Can be an "
+            "AnnData object or a .mtx file."
+        ),
+        required=True,
+    )
+    parser.add_argument(
+        "--labels",
+        type=str,
+        help=(
+            "Path to the file that stores the per cell types "
+            "or the layer in the AnnData object where the "
+            "cell types are held."
+        ),
+        required=True,
+    )
+    parser.add_argument(
+        "--genes",
+        type=str,
+        help=(
+            "Path to the file containing one gene per line corresponding "
+            "to the genes in X, required for .mtx data."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="The name of the trained model, used for saving.",
+        required=True,
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        help="The location to where the trained model should be saved.",
+        required=True,
+    )
+    parser.add_argument(
+        "--transpose_input",
+        action="store_true",
+        help=(
+            "If the provided matrix is in the gene-by-cell format, "
+            "please transpose the input to cell-by-gene format."
+        ),
+        required=False,
+    )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help=(
+            "If raw counts are provided in the AnnData object, "
+            "they need to be normalized."
+        ),
+        required=False,
+    )
+    parser.add_argument(
+        "--gene_symbols_field",
+        type=str,
+        help=(
+            "The field in AnnData where the gene symbols are stored, "
+            "if not in index."
+        ),
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--raw_counts_layer",
+        type=str,
+        help=(
+            "The name of the layer that stores the raw counts. "
+            "Uses default matrix if not present"
+        ),
+        required=False,
+        default=None,
+    )
+    # --epochs and --max_iter declare defaults, so they must not be required;
+    # otherwise the defaults could never take effect.
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        help=(
+            "The number of epochs for which the model "
+            "needs to be trained."
+        ),
+        required=False,
+        default=10,
+    )
+    # NOTE: --solver is accepted for compatibility but is not forwarded,
+    # because train() has no corresponding parameter.
+    parser.add_argument(
+        "--solver",
+        type=str,
+        help=(
+            "Algorithm to use in the optimization problem for "
+            "traditional logistic classifier. Default is based "
+            "on the size of the input data."
+        ),
+        required=False,
+    )
+    parser.add_argument(
+        "--max_iter",
+        type=int,
+        help=(
+            "Maximum number of iterations before reaching "
+            "the minimum of the cost function."
+        ),
+        required=False,
+        default=100,
+    )
+    parser.add_argument(
+        "--C",
+        type=float,
+        help=(
+            "Inverse of L2 regularization strength for "
+            "traditional logistic classifier."
+        ),
+        required=False,
+        default=1.0,
+    )
+    parser.add_argument(
+        "--n_jobs",
+        type=int,
+        help="Number of CPUs used.",
+        required=False,
+        default=1,
+    )
+    add_bool(
+        "--use_SGD",
+        "Whether to implement SGD learning "
+        "for the logistic classifier.",
+    )
+    parser.add_argument(
+        "--alpha",
+        type=float,
+        help="L2 regularization strength for SGD logistic classifier.",
+        required=False,
+        default=0.0001,
+    )
+    add_bool(
+        "--use_GPU",
+        "Whether to use GPU for logistic classifier.",
+    )
+    add_bool(
+        "--mini_batch",
+        "Whether to implement mini-batch training "
+        "for the SGD logistic classifier.",
+    )
+    parser.add_argument(
+        "--batch_number",
+        type=int,
+        help=(
+            "The number of batches used for training in each epoch. "
+            "Each batch contains batch_size cells."
+        ),
+        required=False,
+        default=100,
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="The number of cells within each batch.",
+        required=False,
+        default=1000,
+    )
+    add_bool(
+        "--balance_cell_type",
+        "Whether to balance the cell type frequencies "
+        "in mini-batches during each epoch.",
+    )
+    add_bool(
+        "--feature_selection",
+        "Whether to perform two-pass data training where "
+        "the first round is used for selecting important "
+        "features/genes using SGD learning.",
+    )
+    parser.add_argument(
+        "--top_genes",
+        type=int,
+        help=(
+            "The number of top genes selected from each "
+            "class/cell-type based on their absolute "
+            "regression coefficients."
+        ),
+        required=False,
+        default=300,
+    )
+
+
+def main():
+    """
+    Command-line entry point.
+
+    Parsing happens in two stages: a first parser only understands
+    ``--action`` and ``--help`` and leaves everything else untouched; a
+    second, action-specific parser then handles the remaining arguments and
+    the matching function (``predict`` or ``train``) is called. When no
+    ``--action`` is given, ``predict`` is assumed. ``--help`` without an
+    action prints the top-level help and exits with status 0; with an
+    action it is passed on to the action-specific parser.
+    """
+    parser = argparse.ArgumentParser(
+        add_help=False,
+        description=textwrap.dedent(
+            """
+        CellTypist for automatic cell type annotation.
+        ----------------------------------------------
+        Welcome to the ODS CLI tool for cell type
+        annotation using CellTypist.
+        The CLI provides the user with the opportunity
+        to train a new model or predict using an
+        existing model.
+        Either of the methods can be activated by
+        providing the appropriate action variable.
+        For details of required/optional parameters
+        for each method, please refer to the
+        appropriate help function (-h).
+                                """
+        ),
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+
+    # RawTextHelpFormatter prints these strings verbatim, so they are built
+    # by implicit concatenation to avoid embedding source indentation.
+    choices_to = {
+        "train": (
+            "⚖️  Train a CellTypist model using a reference "
+            "dataset."
+        ),
+        "predict": (
+            "🖋️  Use an existing CellTypist model to predict "
+            "on a new dataset."
+        ),
+    }
+
+    parser.add_argument(
+        "--action",
+        type=str,
+        help="\n".join(f"{key}: {value}" for key, value in choices_to.items()),
+        choices=choices_to,
+    )
+
+    parser.add_argument(
+        "--help",
+        action="store_true",
+        help="List available options.",
+        required=False,
+    )
+
+    args, sub_args = parser.parse_known_args()
+    if args.help:
+        if args.action is None:
+            print(parser.format_help())
+            # Help was requested explicitly, so this is a successful exit.
+            sys.exit(0)
+        # Let the action-specific parser print its own help.
+        sub_args.append("--help")
+
+    action = "predict" if args.action is None else args.action
+
+    parser = argparse.ArgumentParser(
+        prog=f"{os.path.basename(sys.argv[0])} {action}"
+    )
+
+    if action == "predict":
+        _add_predict_arguments(parser)
+        args = parser.parse_args(sub_args)
+
+        predict(
+            args.input_path,
+            args.model,
+            args.output_path,
+            args.gene_file,
+            args.cell_file,
+            args.raw_counts_layer,
+            args.gene_symbols_field,
+            args.normalize,
+            args.transpose_input,
+            args.mode,
+            args.p_thres,
+            args.majority_voting,
+            args.over_clustering,
+            args.use_GPU,
+        )
+    else:
+        _add_train_arguments(parser)
+        args = parser.parse_args(sub_args)
+
+        train(
+            args.input_path,
+            args.labels,
+            args.genes,
+            args.model_name,
+            args.output_path,
+            args.transpose_input,
+            args.normalize,
+            args.raw_counts_layer,
+            args.gene_symbols_field,
+            args.epochs,
+            args.max_iter,
+            args.C,
+            args.n_jobs,
+            args.use_SGD,
+            args.alpha,
+            args.use_GPU,
+            args.mini_batch,
+            args.batch_number,
+            args.batch_size,
+            args.balance_cell_type,
+            args.feature_selection,
+            args.top_genes,
+        )
+
+
+if __name__ == "__main__":
+    main()