diff kmd_hmdb_plot_generator.py @ 0:59c8bad5f6bc draft default tip

planemo upload for repository https://github.com/workflow4metabolomics/tools-metabolomics/blob/master/tools/kmd_hmdb_data_plot/ commit 7fa454b6a4268b89fe18043e8dd10f30a7b4c7ca
author workflow4metabolomics
date Tue, 29 Aug 2023 09:45:16 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/kmd_hmdb_plot_generator.py	Tue Aug 29 09:45:16 2023 +0000
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+
+import csv
+import itertools
+import os
+
+import click
+
+import plotly.express
+import plotly.graph_objects
+
+__version__ = "1.0.0"
+
+
+@click.group()
+def cli():
+    pass
+
+
+@cli.command(help="")
+@click.option(
+    "--version",
+    is_flag=True,
+    default=False,
+)
+@click.option(
+    "--input",
+    default="./test.tsv",
+    help="Provide the mz-ratio."
+)
+@click.option(
+    "--output",
+    default="./test.html",
+    help="Provide the database."
+)
+@click.option(
+    "--x-column",
+    default=["nominal_mass"],
+    multiple=True,
+    help="Provide the column names for the X axis.",
+)
+@click.option(
+    "--y-column",
+    default=["kendricks_mass_defect"],
+    multiple=True,
+    help="Provide the column names for the Y axis.",
+)
+@click.option(
+    "--annotation-column",
+    multiple=True,
+    default=[
+        "metabolite_name",
+        "chemical_formula",
+    ],
+    help="Provide the columns name for the annotation."
+)
+def plot(*args, **kwargs):
+
+    if kwargs.pop("version"):
+        print(__version__)
+        exit(0)
+
+    input_path = kwargs.pop("input")
+    data = read_input(input_path, kwargs)
+    fig = build_fig(*data)
+    build_html_plot(fig, kwargs.get("output"))
+
+
+def read_input(path: str, kwargs: {}):
+    if not os.path.exists(path):
+        raise ValueError(f"The path '{path}' does not exist.")
+    sep = detect_sep(path)
+    with open(path) as csv_file:
+        line_generator = csv.reader(csv_file, delimiter=sep)
+        first_line = next(line_generator)
+        all_lines = list(line_generator)
+        hover_names = (
+            "metabolite_name",
+            "chemical_formula",
+        )
+        annotation_indexes = get_index_of(first_line, hover_names)
+        (
+            x_index,
+            y_index,
+            x_column,
+            y_column,
+        ) = get_indexes_names(
+            first_line,
+            list(kwargs.get("x_column")),
+            list(kwargs.get("y_column")),
+        )
+        x_lists = [[] for i in range(len(x_index))]
+        y_lists = [[] for i in range(len(y_index))]
+        x_column = list(map(first_line.__getitem__, x_index))
+        y_column = list(map(first_line.__getitem__, y_index))
+        trace_names = [
+            f"f({x_column[i]}) = {y_column[i]}"
+            for i in range(len(x_index))
+        ]
+        hover_names = kwargs["annotation_column"]
+        annotation_indexes = [
+            get_index_of(first_line, column)[0]
+            for column in hover_names
+        ]
+        hover_names = list(map(first_line.__getitem__, annotation_indexes))
+        annotations = list()
+        for line in all_lines:
+            for i in range(len(x_index)):
+                x_lists[i].append(float(line[x_index[i]]))
+                y_lists[i].append(float(line[y_index[i]]))
+            annotations.append("<br>".join(
+                f"{hover_names[hover_index]}: {line[index]}"
+                for hover_index, index in enumerate(annotation_indexes)
+            ))
+    return x_lists, y_lists, annotations, trace_names
+
+
+def get_indexes_names(first_line, x_column, y_column):
+    x_column, y_column = map(list, zip(*itertools.product(x_column, y_column)))
+    x_index = get_index_of(first_line, x_column)
+    y_index = get_index_of(first_line, y_column)
+    for i in range(len(x_index))[::-1]:
+        if x_index[i] == y_index[i]:
+            del x_index[i], x_column[i], y_index[i], y_column[i],
+    return (
+        x_index,
+        y_index,
+        x_column,
+        y_column,
+    )
+
+
+def get_index_of(first_line, column):
+    if isinstance(column, (tuple, list)):
+        return [get_index_of(first_line, x)[0] for x in list(column)]
+    try:
+        return [int(column) - 1]
+    except ValueError:
+        return [first_line.index(column)]
+
+
+def build_fig(x_lists, y_lists, annotations, trace_names):
+    fig = plotly.express.scatter()
+    for i in range(len(x_lists)):
+        fig.add_trace(
+            plotly.graph_objects.Scatter(
+                name=trace_names[i],
+                x=x_lists[i],
+                y=y_lists[i],
+                hovertext=annotations,
+                mode="markers",
+            )
+        )
+    return fig
+
+
+def detect_sep(tabular_file: str) -> str:
+    with open(tabular_file, "r") as file:
+        first_line = file.readline()
+    if len(first_line.split(',')) > len(first_line.split('\t')):
+        return ','
+    return '\t'
+
+
+def build_html_plot(fig, output: str):
+    return plotly.offline.plot(
+        fig,
+        filename=output,
+        auto_open=False,
+    )
+
+
+if __name__ == "__main__":
+    cli()