comparison kmd_hmdb_plot_generator.py @ 0:59c8bad5f6bc draft default tip

planemo upload for repository https://github.com/workflow4metabolomics/tools-metabolomics/blob/master/tools/kmd_hmdb_data_plot/ commit 7fa454b6a4268b89fe18043e8dd10f30a7b4c7ca
author workflow4metabolomics
date Tue, 29 Aug 2023 09:45:16 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:59c8bad5f6bc
1 #!/usr/bin/env python3
2
import csv
import itertools
import os
import sys

import click
import plotly.express
import plotly.graph_objects
11
12 __version__ = "1.0.0"
13
14
@click.group()
def cli():
    """Build interactive HTML scatter plots from CSV/TSV tables."""
18
19
@cli.command(
    help="Plot columns of a CSV/TSV table as an interactive HTML scatter plot."
)
@click.option(
    "--version",
    is_flag=True,
    default=False,
    help="Print the tool version and exit.",
)
@click.option(
    "--input",
    default="./test.tsv",
    help="Path to the comma- or tab-separated input table (with a header line).",
)
@click.option(
    "--output",
    default="./test.html",
    help="Path of the HTML file to write the plot to.",
)
@click.option(
    "--x-column",
    default=("nominal_mass",),
    multiple=True,
    help="Provide the column names (or 1-based indexes) for the X axis.",
)
@click.option(
    "--y-column",
    default=("kendricks_mass_defect",),
    multiple=True,
    help="Provide the column names (or 1-based indexes) for the Y axis.",
)
@click.option(
    "--annotation-column",
    multiple=True,
    default=(
        "metabolite_name",
        "chemical_formula",
    ),
    help="Provide the column names (or 1-based indexes) shown in the hover annotation.",
)
def plot(*args, **kwargs):
    """Read the input table, build the scatter figure and write it as HTML.

    Every X column is paired with every Y column (one trace per pair, pairs
    of a column with itself are dropped), and the annotation columns are
    joined into each point's hover text.
    """
    if kwargs.pop("version"):
        click.echo(__version__)
        # sys.exit works even when the ``site`` module (which provides the
        # ``exit`` builtin) is not loaded, e.g. under ``python -S``.
        sys.exit(0)

    input_path = kwargs.pop("input")
    data = read_input(input_path, kwargs)
    fig = build_fig(*data)
    build_html_plot(fig, kwargs.get("output"))
67
68
def _read_table(path: str):
    """Return ``(header, rows)`` for the delimited file at *path*.

    ``rows`` is a list of ``(line_number, fields)`` pairs; blank lines are
    skipped. Raises ValueError when the file has no header line.
    """
    sep = detect_sep(path)
    # The csv module expects files opened with newline="" so that quoted
    # fields containing line breaks are parsed correctly.
    with open(path, newline="") as csv_file:
        reader = csv.reader(csv_file, delimiter=sep)
        header = next(reader, None)
        if header is None:
            raise ValueError(f"The file '{path}' is empty.")
        # csv yields [] for blank lines (e.g. a trailing newline); skip them
        # instead of failing later with an IndexError.
        rows = [(reader.line_num, line) for line in reader if line]
    return header, rows


def read_input(path: str, kwargs: dict):
    """Read a delimited table and extract the series to plot.

    Args:
        path: Path to a comma- or tab-separated file whose first line is a
            header.
        kwargs: Options mapping; reads ``"x_column"``, ``"y_column"`` and
            ``"annotation_column"`` (iterables of column names or 1-based
            indexes).

    Returns:
        ``(x_lists, y_lists, annotations, trace_names)``: ``x_lists[i]`` and
        ``y_lists[i]`` hold the float values of the i-th (x, y) column pair,
        ``annotations`` holds one ``<br>``-joined hover text per data row and
        ``trace_names`` one ``f(x) = y`` label per pair.

    Raises:
        ValueError: if the path does not exist, the file is empty, or a row
            has a missing or non-numeric plotted cell.
    """
    if not os.path.exists(path):
        raise ValueError(f"The path '{path}' does not exist.")
    first_line, rows = _read_table(path)

    x_index, y_index, _, _ = get_indexes_names(
        first_line,
        list(kwargs.get("x_column")),
        list(kwargs.get("y_column")),
    )
    # Label traces with the header names rather than the raw selectors,
    # which may be 1-based indexes.
    x_column = [first_line[index] for index in x_index]
    y_column = [first_line[index] for index in y_index]
    trace_names = [
        f"f({x_name}) = {y_name}" for x_name, y_name in zip(x_column, y_column)
    ]

    annotation_indexes = get_index_of(
        first_line, list(kwargs["annotation_column"])
    )
    hover_names = [first_line[index] for index in annotation_indexes]

    x_lists = [[] for _ in x_index]
    y_lists = [[] for _ in y_index]
    annotations = []
    for line_number, line in rows:
        try:
            for x_list, y_list, x_i, y_i in zip(
                x_lists, y_lists, x_index, y_index
            ):
                x_list.append(float(line[x_i]))
                y_list.append(float(line[y_i]))
            annotations.append("<br>".join(
                f"{name}: {line[index]}"
                for name, index in zip(hover_names, annotation_indexes)
            ))
        except (ValueError, IndexError) as error:
            # Point the user at the offending row instead of a bare
            # "could not convert string to float".
            raise ValueError(
                f"Invalid row at line {line_number} of '{path}': {error}"
            ) from error
    return x_lists, y_lists, annotations, trace_names
116
117
def get_indexes_names(first_line, x_column, y_column):
    """Pair every X column with every Y column, dropping self-pairs.

    Args:
        first_line: The header row (list of column names).
        x_column: Iterable of X column selectors (names or 1-based indexes).
        y_column: Iterable of Y column selectors (names or 1-based indexes).

    Returns:
        ``(x_index, y_index, x_column, y_column)``: four parallel lists with
        one entry per kept (x, y) pair — the resolved 0-based indexes and
        the selectors as given. Empty lists when either input is empty.
    """
    pairs = list(itertools.product(x_column, y_column))
    x_names = [x_name for x_name, _ in pairs]
    y_names = [y_name for _, y_name in pairs]
    x_positions = get_index_of(first_line, x_names)
    y_positions = get_index_of(first_line, y_names)
    # A column plotted against itself carries no information. Comparing the
    # resolved indexes makes "2" and the name of column 2 count as the same.
    kept = [
        i
        for i, (x_i, y_i) in enumerate(zip(x_positions, y_positions))
        if x_i != y_i
    ]
    return (
        [x_positions[i] for i in kept],
        [y_positions[i] for i in kept],
        [x_names[i] for i in kept],
        [y_names[i] for i in kept],
    )
131
132
def get_index_of(first_line, column):
    """Resolve column selector(s) to 0-based indexes into *first_line*.

    A selector is either a 1-based column index (anything ``int()``
    accepts) or a header name. A tuple/list of selectors is resolved
    element-wise.

    Args:
        first_line: The header row (list of column names).
        column: A selector, or a tuple/list of selectors.

    Returns:
        A list of 0-based indexes: one element for a single selector, one
        per element for a tuple/list.

    Raises:
        ValueError: if an index is outside ``1..len(first_line)`` and is not
            itself a header name, or if a name is not in the header.
    """
    if isinstance(column, (tuple, list)):
        return [get_index_of(first_line, name)[0] for name in column]
    try:
        position = int(column)
    except ValueError:
        position = None
    if position is not None and 1 <= position <= len(first_line):
        return [position - 1]
    # Out-of-range numbers may still be literal header names (e.g. "2023").
    if column in first_line:
        return [first_line.index(column)]
    if position is not None:
        # Without this check "0" or negative values would silently pick
        # columns counted from the end of the row.
        raise ValueError(
            f"Column index {position} is out of range "
            f"(expected 1 to {len(first_line)})."
        )
    raise ValueError(f"Column {column!r} not found in header: {first_line!r}.")
140
141
def build_fig(x_lists, y_lists, annotations, trace_names):
    """Create a plotly figure holding one marker trace per (x, y) series.

    ``x_lists[k]``, ``y_lists[k]`` and ``trace_names[k]`` describe trace k;
    the same ``annotations`` sequence is attached to every trace as hover
    text (presumably one entry per data row — see read_input).
    """
    figure = plotly.express.scatter()
    for position, x_values in enumerate(x_lists):
        trace = plotly.graph_objects.Scatter(
            name=trace_names[position],
            x=x_values,
            y=y_lists[position],
            hovertext=annotations,
            mode="markers",
        )
        figure.add_trace(trace)
    return figure
155
156
def detect_sep(tabular_file: str) -> str:
    """Guess the field delimiter of *tabular_file* from its first line.

    Returns a comma when the first line contains strictly more commas than
    tab characters, and a tab otherwise (ties and empty files included).
    """
    with open(tabular_file, "r") as handle:
        header = handle.readline()
    comma_count = header.count(',')
    tab_count = header.count('\t')
    return ',' if comma_count > tab_count else '\t'
163
164
def build_html_plot(fig, output: str):
    """Write *fig* to *output* as a standalone HTML page.

    Args:
        fig: The plotly figure to export.
        output: Destination path; a leading ``~`` is expanded.

    Returns:
        The path of the written file (``output`` with ``~`` expanded).
    """
    path = os.path.expanduser(output)
    # Write to exactly the requested path: ``plotly.offline.plot`` appends
    # ".html" to file names that lack it, so a destination such as a
    # ``.dat`` dataset path would stay empty while the plot landed in a
    # sibling file. The plotly.js bundle is embedded so the page is
    # self-contained.
    fig.write_html(path, include_plotlyjs=True, auto_open=False)
    return path
171
172
if __name__ == "__main__":
    # Running the file as a script hands control to the click group.
    cli()