view main.py @ 4:e7b4afedc471 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
author bgruening
date Tue, 11 Feb 2025 10:14:12 +0000
parents 33d53eb476fd
children 49b4ee0d0965
line wrap: on
line source

"""
Tabular data prediction using TabPFN
"""
import argparse
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (
    average_precision_score,
    precision_recall_curve,
    r2_score,
    root_mean_squared_error,
)
from sklearn.preprocessing import label_binarize
from tabpfn import TabPFNClassifier, TabPFNRegressor


def separate_features_labels(data):
    """
    Load a tab-separated table and split off its last column.

    `data` is handed directly to `pandas.read_csv` with a tab separator.
    Returns a `(features, labels)` pair: `features` is a DataFrame holding
    every column except the last one, `labels` is the last column as a Series.
    """
    table = pd.read_csv(data, sep="\t")
    predictors, target = table.iloc[:, :-1], table.iloc[:, -1]
    return predictors, target


def classification_plot(y_true, y_scores):
    """
    Draw precision-recall curves and save them to "output_plot.png".

    y_true: true labels of the evaluated samples.
    y_scores: 2-D array of per-class scores, indexed as y_scores[:, i].

    When y_true holds exactly two distinct values, a single curve is drawn
    from column 1 of y_scores. Otherwise one curve per class found in y_true
    is drawn, plus a micro-averaged curve over all flattened entries.

    NOTE(review): column i of y_scores is assumed to correspond to the i-th
    sorted distinct value of y_true -- confirm against the caller.
    """
    plt.figure(figsize=(8, 6))
    observed_classes = np.unique(y_true)
    if len(observed_classes) == 2:
        # Binary case: column 1 is used as the score of the positive class.
        positive_scores = y_scores[:, 1]
        precision, recall, _ = precision_recall_curve(y_true, positive_scores)
        average_precision = average_precision_score(y_true, positive_scores)
        curve_label = f"Precision-Recall Curve (AP={average_precision:.2f})"
        plt.plot(recall, precision, label=curve_label)
        plt.title("Precision-Recall Curve (binary classification)")
    else:
        # One indicator column per class that occurs in y_true.
        onehot = label_binarize(y_true, classes=observed_classes)
        for idx in range(onehot.shape[1]):
            truth_col = onehot[:, idx]
            score_col = y_scores[:, idx]
            precision, recall, _ = precision_recall_curve(truth_col, score_col)
            ap_score = average_precision_score(truth_col, score_col)
            plt.plot(
                recall, precision, label=f"Class {idx} (AP = {ap_score:.2f})"
            )
        # Micro-average: every (sample, class) entry is one binary decision.
        precision, recall, _ = precision_recall_curve(
            onehot.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall, precision, linestyle="--", color="black", label="Micro-average"
        )
        plt.title("Precision-Recall Curve (Multiclass Classification)")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot.png")


def regression_plot(xval, yval, title, xlabel, ylabel):
    """
    Scatter yval against xval with a "y = x" reference line and save the
    figure to "output_plot.png".

    xval: values for the horizontal axis (the caller passes true values).
    yval: values for the vertical axis (the caller passes predictions).
    title, xlabel, ylabel: text for the figure title and axis labels.
    """
    plt.figure(figsize=(8, 6))
    plt.scatter(xval, yval, alpha=0.8)

    # The reference diagonal must cover the range of the plotted values;
    # spanning 0..len(xval)-1 ties it to the sample count instead and puts
    # the line in the wrong place whenever the two ranges differ.
    x_arr = np.asarray(xval)
    y_arr = np.asarray(yval)
    if x_arr.size and y_arr.size:
        lower = min(np.nanmin(x_arr), np.nanmin(y_arr))
        upper = max(np.nanmax(x_arr), np.nanmax(y_arr))
        plt.plot(
            [lower, upper],
            [lower, upper],
            color="red",
            linestyle="--",
            label="y = x",
        )

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    # The legend only lists artists that already exist, so it has to be
    # created after the labelled diagonal has been drawn.
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot.png")


def train_evaluate(args):
    """
    Fit TabPFN on the training table, predict the test table, write results.

    `args` is a dict with the keys "train_data", "test_data",
    "testhaslabels" and "selected_task". When "testhaslabels" equals
    "haslabels" the last test column is treated as labels and an evaluation
    plot is produced; when "selected_task" equals "Classification" a
    classifier is used, otherwise a regressor. Predictions are appended to
    the test features as a "predicted_labels" column and written as a
    tab-separated file named "output_predicted_data".
    """
    # Training table: the last column is the target.
    train_x, train_y = separate_features_labels(args["train_data"])

    # Test table: labels are only split off when the caller says they exist.
    if args["testhaslabels"] == "haslabels":
        test_x, test_y = separate_features_labels(args["test_data"])
    else:
        test_x = pd.read_csv(args["test_data"], sep="\t")
        test_y = []
    has_ground_truth = len(test_y) > 0

    started = time.time()
    if args["selected_task"] == "Classification":
        model = TabPFNClassifier()
        model.fit(train_x, train_y)
        y_eval = model.predict(test_x)
        class_probabilities = model.predict_proba(test_x)
        if has_ground_truth:
            classification_plot(test_y, class_probabilities)
    else:
        model = TabPFNRegressor()
        model.fit(train_x, train_y)
        y_eval = model.predict(test_x)
        if has_ground_truth:
            score = root_mean_squared_error(test_y, y_eval)
            r2_metric_score = r2_score(test_y, y_eval)
            plot_title = f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}"
            regression_plot(
                test_y, y_eval, plot_title, "True values", "Predicted values"
            )
    elapsed = time.time() - started
    print(
        "Time taken by TabPFN for training and prediction: {} seconds".format(
            elapsed
        )
    )

    # Store predictions next to the features they were computed from.
    test_x["predicted_labels"] = y_eval
    test_x.to_csv("output_predicted_data", sep="\t", index=None)


if __name__ == "__main__":
    # Command-line entry point: every option below is mandatory.
    cli_parser = argparse.ArgumentParser()
    option_table = (
        (("-trdata", "--train_data"), "Train data"),
        (("-tedata", "--test_data"), "Test data"),
        (("-testhaslabels", "--testhaslabels"), "if test data contain labels"),
        (("-selectedtask", "--selected_task"), "Type of machine learning task"),
    )
    for option_flags, option_help in option_table:
        cli_parser.add_argument(*option_flags, required=True, help=option_help)
    # Turn the parsed namespace into a plain dict and run the pipeline.
    args = vars(cli_parser.parse_args())
    train_evaluate(args)