# view celltypist_CLI.py @ 1:df005630100e draft default tip
#
# planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 4dd95e3e6c059cc9df47f5d08e4f3f8b618830f1-dirty
# author ebi-gxa
# date Wed, 19 Feb 2025 12:10:17 +0000
# parents a7d6985ba791
# children
# line wrap: on
# line source

import argparse
import os
import pickle
import sys
import textwrap

import celltypist
import scanpy as sc


def train(
    input_path,
    labels,
    genes,
    model_name,
    output_path,
    transpose_input,
    normalize,
    raw_counts_layer,
    gene_symbols_field,
    epochs,
    max_iter,
    C,
    n_jobs,
    use_SGD,
    alpha,
    use_GPU,
    mini_batch,
    batch_number,
    batch_size,
    balance_cell_type,
    feature_selection,
    top_genes,
):
    """
    Train a CellTypist model and save it as a pickle file.

    An ``.h5ad`` input is read with scanpy and may be pre-processed here
    (raw-counts layer selection, gene-symbol re-indexing, normalization). A
    ``.mtx`` input is not loaded; its path is handed to ``celltypist.train``.
    Please refer to the CLI help for the meaning of the training options.

    Parameters
    ----------
    input_path : str
        Path ending in ``.h5ad`` or ``.mtx``.
    labels, genes
        Forwarded to ``celltypist.train``. ``genes`` is mandatory for
        ``.mtx`` input.
    model_name : str
        File name (without extension) of the saved model.
    output_path : str
        Directory in which ``<model_name>.pkl`` is written.
    transpose_input
        Forwarded to ``celltypist.train``.
    normalize : bool
        If true, apply ``normalize_total(target_sum=1e4)`` followed by
        ``log1p`` to the AnnData object before training.
    raw_counts_layer : str or None
        Name of an AnnData layer that is copied into ``X`` before training.
    gene_symbols_field : str or None
        Column of ``data.var`` to use as the new index. The previous index is
        kept in ``data.var["old_index"]``. Ignored when the column is absent.
    epochs, max_iter, C, n_jobs, use_SGD, alpha, use_GPU, mini_batch, \
batch_number, batch_size, balance_cell_type, feature_selection, top_genes
        Forwarded to ``celltypist.train`` unchanged.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the input extension is unsupported, a ``.mtx`` input lacks a gene
        file, ``raw_counts_layer`` is ``"X"``/unknown or is given for
        ``.mtx`` input, or ``normalize`` is requested for ``.mtx`` input.
    """
    is_h5ad = input_path.endswith(".h5ad")
    is_mtx = input_path.endswith(".mtx")

    if is_h5ad:
        print("Detected an AnnData object as input.")
        data = sc.read_h5ad(input_path)
    elif is_mtx:
        print("Detected .mtx file as input.")
        # The matrix is not loaded here; celltypist receives the path itself.
        data = input_path
    else:
        raise ValueError("Invalid input file type detected.")

    if is_mtx and genes is None:
        raise ValueError("Missing a gene file for the provided data.")

    if raw_counts_layer:
        if is_mtx:
            raise ValueError(
                f"Raw counts layer {raw_counts_layer} provided but "
                "the data format provided is .mtx. Please provide an "
                "AnnData object."
            )
        if raw_counts_layer == "X" or raw_counts_layer not in data.layers:
            raise ValueError(
                f"Raw counts layer {raw_counts_layer} "
                "should be either different to 'X' or an "
                "existing layer in AnnData"
            )
        print(f"Using raw counts layer: {raw_counts_layer}")
        data.X = data.layers[raw_counts_layer]

    if (
        gene_symbols_field
        and is_h5ad
        and gene_symbols_field in data.var.columns
    ):
        # Keep the original identifiers so they are not lost on re-indexing.
        data.var["old_index"] = data.var.index
        data.var.set_index(gene_symbols_field, inplace=True)

    if normalize:
        if is_mtx:
            # ``data`` is only a file path here, so scanpy cannot process it.
            raise ValueError(
                "Normalization requires an AnnData (.h5ad) input; "
                ".mtx data must be normalized beforehand."
            )
        sc.pp.normalize_total(data, target_sum=1e4)
        sc.pp.log1p(data)

    model = celltypist.train(
        data,
        labels=labels,
        genes=genes,
        transpose_input=transpose_input,
        C=C,
        epochs=epochs,
        max_iter=max_iter,
        n_jobs=n_jobs,
        use_SGD=use_SGD,
        alpha=alpha,
        use_GPU=use_GPU,
        mini_batch=mini_batch,
        batch_number=batch_number,
        batch_size=batch_size,
        balance_cell_type=balance_cell_type,
        feature_selection=feature_selection,
        top_genes=top_genes,
    )

    model_save_location = os.path.join(output_path, model_name + ".pkl")

    # Use a context manager so the file is flushed and closed deterministically.
    with open(model_save_location, "wb") as handle:
        pickle.dump(model, handle)


def predict(
    input_path,
    model,
    output_path,
    gene_file,
    cell_file,
    raw_counts_layer,
    gene_symbols_field,
    normalize,
    transpose_input,
    mode,
    p_thres,
    majority_voting,
    over_clustering,
    use_GPU,
):
    """
    Predict cell types for a dataset with an existing CellTypist model.

    An ``.h5ad`` input is read with scanpy and may be pre-processed here
    (raw-counts layer selection, gene-symbol re-indexing, normalization). A
    ``.mtx`` input is not loaded; its path is handed to
    ``celltypist.annotate``. The predicted labels are written to
    ``output_path`` as a tab-separated file. Please refer to the CLI help for
    the meaning of the prediction options.

    Parameters
    ----------
    input_path : str
        Path ending in ``.h5ad`` or ``.mtx``.
    model
        Forwarded to ``celltypist.annotate``.
        NOTE(review): per the CLI help this is a path to a ``.pkl`` file;
        presumably it is un-pickled by celltypist, so only trusted model
        files should be used - confirm.
    output_path : str
        Destination of the tab-separated predicted labels.
    gene_file, cell_file
        Forwarded to ``celltypist.annotate``. Both are mandatory for ``.mtx``
        input.
    raw_counts_layer : str or None
        Name of an AnnData layer that is copied into ``X`` before annotation.
    gene_symbols_field : str or None
        Column of ``data.var`` to use as the new index. The previous index is
        kept in ``data.var["old_index"]``. Ignored when the column is absent.
    normalize : bool
        If true, apply ``normalize_total(target_sum=1e4)`` followed by
        ``log1p`` to the AnnData object before annotation.
    mode : str or None
        Underscores are replaced by spaces before forwarding (for example
        ``best_match`` becomes ``best match``).
    transpose_input, p_thres, majority_voting, over_clustering, use_GPU
        Forwarded to ``celltypist.annotate``. A falsy ``over_clustering`` is
        replaced by ``None`` when ``majority_voting`` is enabled.

    Returns
    -------
    The return value of ``DataFrame.to_csv`` for the written labels.

    Raises
    ------
    ValueError
        If the input extension is unsupported, a ``.mtx`` input lacks a gene
        or cell file, ``raw_counts_layer`` is ``"X"``/unknown or is given for
        ``.mtx`` input, or ``normalize`` is requested for ``.mtx`` input.
    """
    is_h5ad = input_path.endswith(".h5ad")
    is_mtx = input_path.endswith(".mtx")

    if is_h5ad:
        print("Detected an AnnData object as input.")
        data = sc.read_h5ad(input_path)
    elif is_mtx:
        print("Detected .mtx file as input.")
        # The matrix is not loaded here; celltypist receives the path itself.
        data = input_path
    else:
        raise ValueError("Invalid input file type detected.")

    if is_mtx and gene_file is None:
        raise ValueError("Missing a gene file for the provided .mtx data.")
    if is_mtx and cell_file is None:
        raise ValueError("Missing a cell file for the provided .mtx data.")

    if raw_counts_layer:
        if is_mtx:
            raise ValueError(
                f"Raw counts layer {raw_counts_layer} provided but "
                "the data format is .mtx. Please provide an AnnData "
                "object."
            )
        if raw_counts_layer == "X" or raw_counts_layer not in data.layers:
            raise ValueError(
                f"Raw counts layer {raw_counts_layer} should be "
                "either different to 'X' or an existing layer in "
                "AnnData"
            )
        print(f"Using raw counts layer: {raw_counts_layer}")
        data.X = data.layers[raw_counts_layer]

    if (
        gene_symbols_field
        and is_h5ad
        and gene_symbols_field in data.var.columns
    ):
        # Keep the original identifiers so they are not lost on re-indexing.
        data.var["old_index"] = data.var.index
        data.var.set_index(gene_symbols_field, inplace=True)

    if normalize:
        if is_mtx:
            # ``data`` is only a file path here, so scanpy cannot process it.
            raise ValueError(
                "Normalization requires an AnnData (.h5ad) input; "
                ".mtx data must be normalized beforehand."
            )
        sc.pp.normalize_total(data, target_sum=1e4)
        sc.pp.log1p(data)

    if mode:
        # The CLI spells modes with underscores; forward them with spaces.
        mode = mode.replace("_", " ")

    if majority_voting and not over_clustering:
        # Normalize any falsy value (e.g. an empty string) to None.
        over_clustering = None

    predictions = celltypist.annotate(
        data,
        model,
        gene_file=gene_file,
        cell_file=cell_file,
        mode=mode,
        p_thres=p_thres,
        transpose_input=transpose_input,
        majority_voting=majority_voting,
        over_clustering=over_clustering,
        use_GPU=use_GPU,
    )

    return predictions.predicted_labels.to_csv(output_path, sep="\t")


def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``0`` to bool.

    ``type=bool`` must not be used with argparse because ``bool("False")`` is
    ``True``; this parser interprets the text instead.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (e.g. True or False), got {value!r}."
    )


def _build_action_parser():
    """Create the first-stage parser that only knows --action and --help."""
    parser = argparse.ArgumentParser(
        add_help=False,
        description=textwrap.dedent(
            """
        CellTypist for automatic cell type annotation.
        ----------------------------------------------
        Welcome to the ODS CLI tool for cell type
        annotation using CellTypist.
        The CLI provides the user with the opportunity
        to train a new model or predict using an
        existing model.
        Either of the methods can be activated by
        providing the appropriate action variable.
        For details of required/optional parameters
        for each method, please refer to the
        appropriate help function (-h).
        """
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # RawTextHelpFormatter prints these verbatim, so keep them free of
    # line-continuation whitespace.
    choices_to = {
        "train": "⚖️  Train a CellTypist model using a reference dataset.",
        "predict": (
            "🖋️  Use an existing CellTypist model to predict "
            "on a new dataset."
        ),
    }

    parser.add_argument(
        "--action",
        type=str,
        help="\n".join(f"{key}: {value}" for key, value in choices_to.items()),
        choices=choices_to,
    )

    parser.add_argument(
        "--help",
        action="store_true",
        help="List available options.",
        required=False,
    )
    return parser


def _add_predict_arguments(parser):
    """Register every option understood by the ``predict`` action."""
    parser.add_argument(
        "--input_path",
        type=str,
        help=(
            "Path to the input file. "
            "Can be an AnnData object or a .mtx file."
        ),
        required=True,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to output file.",
        required=True,
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Path to model file, stored in .pkl format.",
        required=True,
    )
    parser.add_argument(
        "--gene_file",
        type=str,
        help=(
            "Path to the file which stores each gene per line "
            "corresponding to the genes used in the provided .mtx file."
        ),
        default=None,
    )
    parser.add_argument(
        "--cell_file",
        type=str,
        help=(
            "Path to the file which stores each cell per line "
            "corresponding to the cells used in the provided .mtx file."
        ),
        default=None,
    )
    parser.add_argument(
        "--raw_counts_layer",
        type=str,
        help=(
            "The name of the layer that stores the raw counts. "
            "Uses default matrix if not present."
        ),
        default=None,
    )
    parser.add_argument(
        "--gene_symbols_field",
        type=str,
        help=(
            "The field in AnnData where the gene symbols are stored, "
            "if not in index."
        ),
        default=None,
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help=(
            "If raw counts are provided in the AnnData object, "
            "they need to be normalized."
        ),
    )
    parser.add_argument(
        "--transpose_input",
        action="store_true",
        help=(
            "If the provided matrix is in the gene-by-cell format, "
            "please transpose the input to cell-by-gene format"
        ),
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["best_match", "prob_match"],
        help=(
            "Choose the cell type with the largest score/probability "
            "as the final prediction (`best_match`), or enable a "
            "multi-label classification (`prob_match`), which assigns "
            "0 (i.e., unassigned), 1, or >=2 cell type labels to each "
            "query cell. [default: best_match]"
        ),
        default="best_match",
    )
    parser.add_argument(
        "--p_thres",
        type=float,
        help=(
            "Probability threshold for assigning a cell type in a "
            "multiclass problem, defaults to 0.5."
        ),
        default=0.5,
    )
    parser.add_argument(
        "--majority_voting",
        action="store_true",
        help=(
            "Refine the predicted labels by running the majority voting "
            "classifier after over-clustering."
        ),
    )
    parser.add_argument(
        "--over_clustering",
        type=str,
        help=(
            "If majority voting is set to True, specify the type of over "
            "clustering that is to be performed. This can be specified in "
            "the AnnData or an input file specifying the over-clustering "
            "per cell. If not present, then the default heuristic "
            "over-clustering based on input data will be used."
        ),
        default=None,
    )
    parser.add_argument(
        "--use_GPU",
        action="store_true",
        help=(
            "Whether to use GPU for over clustering "
            "on the basis of `rapids-singlecell`."
        ),
    )


def _add_train_arguments(parser):
    """Register every option understood by the ``train`` action."""

    def add_bool(flag, help_text):
        # Accept ``--flag``, ``--flag True`` and ``--flag False``; the bare
        # flag means True and an omitted flag means False.
        parser.add_argument(
            flag,
            type=_str2bool,
            nargs="?",
            const=True,
            default=False,
            help=help_text,
        )

    parser.add_argument(
        "--input_path",
        type=str,
        help="Path to the input file. Can be an AnnData object or a .mtx file.",
        required=True,
    )
    parser.add_argument(
        "--labels",
        type=str,
        help=(
            "Path to the file that stores the per cell types or the layer "
            "in the AnnData object where the cell types are held."
        ),
        required=True,
    )
    parser.add_argument(
        "--genes",
        type=str,
        help=(
            "Path to the file containing one gene per line corresponding "
            "to the genes in X, required for .mtx data."
        ),
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="The name of the trained model, used for saving.",
        required=True,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="The location to where the trained model should be saved.",
        required=True,
    )
    parser.add_argument(
        "--transpose_input",
        action="store_true",
        help=(
            "If the provided matrix is in the gene-by-cell format, "
            "please transpose the input to cell-by-gene format."
        ),
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help=(
            "If raw counts are provided in the AnnData object, "
            "they need to be normalized."
        ),
    )
    parser.add_argument(
        "--gene_symbols_field",
        type=str,
        help=(
            "The field in AnnData where the gene symbols are stored, "
            "if not in index."
        ),
        default=None,
    )
    parser.add_argument(
        "--raw_counts_layer",
        type=str,
        help=(
            "The name of the layer that stores the raw counts. "
            "Uses default matrix if not present"
        ),
        default=None,
    )
    # --epochs and --max_iter used to combine required=True with a default,
    # which made the default unreachable. They are now optional so that the
    # declared defaults actually apply.
    parser.add_argument(
        "--epochs",
        type=int,
        help="The number of epochs for which the model needs to be trained.",
        default=10,
    )
    # NOTE(review): --solver is accepted but never forwarded to train(),
    # whose signature has no matching parameter.
    parser.add_argument(
        "--solver",
        type=str,
        help=(
            "Algorithm to use in the optimization problem for traditional "
            "logistic classifier. Default is based on the size of the "
            "input data."
        ),
    )
    parser.add_argument(
        "--max_iter",
        type=int,
        help=(
            "Maximum number of iterations before reaching "
            "the minimum of the cost function."
        ),
        default=100,
    )
    parser.add_argument(
        "--C",
        type=float,
        help=(
            "Inverse of L2 regularization strength for "
            "traditional logistic classifier."
        ),
        default=1.0,
    )
    parser.add_argument(
        "--n_jobs",
        type=int,
        help="Number of CPUs used.",
        default=1,
    )
    add_bool(
        "--use_SGD",
        "Whether to implement SGD learning for the logistic classifier.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        help="L2 regularization strength for SGD logistic classifier.",
        default=0.0001,
    )
    add_bool("--use_GPU", "Whether to use GPU for logistic classifier.")
    add_bool(
        "--mini_batch",
        "Whether to implement mini-batch training "
        "for the SGD logistic classifier.",
    )
    parser.add_argument(
        "--batch_number",
        type=int,
        help=(
            "The number of batches used for training in each epoch. "
            "Each batch contains batch_size cells."
        ),
        default=100,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="The number of cells within each batch.",
        default=1000,
    )
    add_bool(
        "--balance_cell_type",
        "Whether to balance the cell type frequencies "
        "in mini-batches during each epoch.",
    )
    add_bool(
        "--feature_selection",
        "Whether to perform two-pass data training where the first round "
        "is used for selecting important features/genes using SGD learning.",
    )
    parser.add_argument(
        "--top_genes",
        type=int,
        help=(
            "The number of top genes selected from each class/cell-type "
            "based on their absolute regression coefficients."
        ),
        default=300,
    )


def main():
    """
    Command-line entry point: dispatch to ``train`` or ``predict``.

    Parsing happens in two stages. A first parser extracts ``--action`` and
    ``--help`` from ``sys.argv`` and leaves everything else untouched; a
    second, action-specific parser then validates the remaining arguments.
    When ``--action`` is omitted the ``predict`` action is assumed.

    ``--help`` without an action prints the general help and exits with
    status 1; with an action it is forwarded to the action-specific parser.
    """
    parser = _build_action_parser()
    args, sub_args = parser.parse_known_args()
    if args.help:
        if args.action is None:
            print(parser.format_help())
            sys.exit(1)
        # Let the action-specific parser print its own help.
        sub_args.append("--help")

    action = "predict" if args.action is None else args.action

    parser = argparse.ArgumentParser(
        prog=f"{os.path.basename(sys.argv[0])} {action}"
    )

    if action == "predict":
        _add_predict_arguments(parser)
        args = parser.parse_args(sub_args)

        predict(
            input_path=args.input_path,
            model=args.model,
            output_path=args.output_path,
            gene_file=args.gene_file,
            cell_file=args.cell_file,
            raw_counts_layer=args.raw_counts_layer,
            gene_symbols_field=args.gene_symbols_field,
            normalize=args.normalize,
            transpose_input=args.transpose_input,
            mode=args.mode,
            p_thres=args.p_thres,
            majority_voting=args.majority_voting,
            over_clustering=args.over_clustering,
            use_GPU=args.use_GPU,
        )
    else:
        _add_train_arguments(parser)
        args = parser.parse_args(sub_args)

        train(
            input_path=args.input_path,
            labels=args.labels,
            genes=args.genes,
            model_name=args.model_name,
            output_path=args.output_path,
            transpose_input=args.transpose_input,
            normalize=args.normalize,
            raw_counts_layer=args.raw_counts_layer,
            gene_symbols_field=args.gene_symbols_field,
            epochs=args.epochs,
            max_iter=args.max_iter,
            C=args.C,
            n_jobs=args.n_jobs,
            use_SGD=args.use_SGD,
            alpha=args.alpha,
            use_GPU=args.use_GPU,
            mini_batch=args.mini_batch,
            batch_number=args.batch_number,
            batch_size=args.batch_size,
            balance_cell_type=args.balance_cell_type,
            feature_selection=args.feature_selection,
            top_genes=args.top_genes,
        )


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()