Mercurial > repos > recetox > matchms_formatter
comparison formatter.py @ 0:60f34912b3de draft
"planemo upload for repository https://github.com/RECETOX/galaxytools/tree/master/tools/matchms commit 4d2ac914c951166e386a94d8ebb8cb1becfac122"
author | recetox |
---|---|
date | Tue, 22 Mar 2022 16:08:45 +0000 |
parents | |
children | 574c6331e9db |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:60f34912b3de |
---|---|
1 import click | |
2 from pandas import DataFrame, read_csv | |
3 | |
4 | |
def create_long_table(data: DataFrame, value_id: str) -> DataFrame:
    """Reshape a compact (wide) matrix into long format.

    The matrix is transposed before melting. Its column labels therefore end
    up as the row index of the result, and its row labels become the values
    of a ``compound`` column. Every cell becomes one row whose value is
    stored in a column named ``value_id``. See ``DataFrame.melt``.

    Args:
        data (DataFrame): Wide table to reshape.
        value_id (str): Name of the column that receives the cell values.

    Returns:
        DataFrame: Long table indexed by the original column labels, with a
        ``compound`` column and a ``value_id`` column.
    """
    flipped = data.T
    # ignore_index=False keeps the (transposed) row labels as the index
    # instead of replacing them with a RangeIndex.
    return flipped.melt(
        var_name='compound',
        value_name=value_id,
        ignore_index=False,
    )
17 | |
18 | |
def join_df(x: DataFrame, y: DataFrame, on=None, how="inner") -> DataFrame:
    """Join two dataframes on their row index plus the given columns.

    Both tables are re-indexed by their current row index followed by the
    ``on`` columns, then joined index-on-index with ``DataFrame.join``. As a
    result, the ``on`` columns become index levels of the joined table.

    Args:
        x (DataFrame): Left table.
        y (DataFrame): Right table.
        on (str or iterable of str, optional): Column name(s) to join on in
            addition to the row index. Defaults to None, meaning the join
            uses the row index only.
        how (str, optional): Join method, see ``DataFrame.join``.
            Defaults to "inner".

    Returns:
        DataFrame: Joined dataframe.
    """
    # A None sentinel avoids the shared mutable-default pitfall of `on=[]`.
    # A bare string is treated as one column name rather than being
    # concatenated onto a list, which would raise TypeError.
    if on is None:
        keys = []
    elif isinstance(on, str):
        keys = [on]
    else:
        keys = list(on)

    df_x = x.set_index([x.index] + keys)
    df_y = y.set_index([y.index] + keys)
    return df_x.join(df_y, how=how)
35 | |
36 | |
def get_top_k_matches(data: DataFrame, k: int) -> DataFrame:
    """Keep the ``k`` highest-scoring rows for every first-level index label.

    Rows are grouped by index level 0, and within each group the ``k`` rows
    with the largest ``score`` are kept (see ``DataFrame.nlargest``). Group
    keys are not prepended to the result index.

    Args:
        data (DataFrame): Table with a ``score`` column.
        k (int): Number of best-scoring rows to keep per group.

    Returns:
        DataFrame: Table containing only the top ``k`` rows of each group.
    """
    def _best_of_group(group: DataFrame) -> DataFrame:
        return group.nlargest(k, ['score'])

    per_label = data.groupby(level=0, group_keys=False)
    return per_label.apply(_best_of_group)
48 | |
49 | |
def filter_thresholds(data: DataFrame, t_score: float, t_matches: float) -> DataFrame:
    """Keep only rows whose score and matches both exceed their thresholds.

    Both comparisons are strict, so a row sitting exactly on a threshold is
    dropped.

    Args:
        data (DataFrame): Table with ``score`` and ``matches`` columns.
        t_score (float): Score threshold.
        t_matches (float): Matches threshold.

    Returns:
        DataFrame: The subset of ``data`` passing both thresholds, in the
        original row order.
    """
    above_score = data['score'] > t_score
    above_matches = data['matches'] > t_matches
    return data[above_score & above_matches]
64 | |
65 | |
def load_data(scores_filename: str, matches_filename: str) -> DataFrame:
    """Load the scores and matches tables and join them in long format.

    Both files are read with their delimiter sniffed automatically and their
    first column used as the row index. They are then reshaped with
    ``create_long_table`` and inner-joined on their index plus the
    ``compound`` column.

    Args:
        scores_filename (str): Path to the scores table.
        matches_filename (str): Path to the matches table.

    Returns:
        DataFrame: Joined long-format table with ``matches`` and ``score``
        columns.
    """
    # sep=None asks pandas to sniff the delimiter, which only the python
    # engine supports. Requesting that engine explicitly avoids the
    # ParserWarning pandas emits when it silently falls back from the C
    # engine; parsing is otherwise unchanged.
    matches = read_csv(matches_filename, sep=None, engine='python', index_col=0)
    scores = read_csv(scores_filename, sep=None, engine='python', index_col=0)

    scores_long = create_long_table(scores, 'score')
    matches_long = create_long_table(matches, 'matches')

    return join_df(matches_long, scores_long, on=['compound'], how='inner')
84 | |
85 | |
@click.group()
@click.option('--sf', 'scores_filename', type=click.Path(exists=True), required=True)
@click.option('--mf', 'matches_filename', type=click.Path(exists=True), required=True)
@click.option('--o', 'output_filename', type=click.Path(writable=True), required=True)
@click.pass_context
def cli(ctx, scores_filename, matches_filename, output_filename):
    # Group entry point: load the joined scores/matches table once and share
    # it with the subcommands through the context object. output_filename is
    # not used here; it is one of the group params click also passes to the
    # result callback (write_output). Comments are used rather than a
    # docstring so the CLI --help text stays unchanged.
    shared = ctx.ensure_object(dict)
    shared['data'] = load_data(scores_filename, matches_filename)
95 | |
96 | |
@cli.command()
@click.option('--st', 'scores_threshold', type=float, required=True)
@click.option('--mt', 'matches_threshold', type=float, required=True)
@click.pass_context
def get_thresholded_data(ctx, scores_threshold, matches_threshold):
    # Subcommand: filter the shared table by the score and matches
    # thresholds. The returned frame becomes the group's result value.
    data = ctx.obj['data']
    return filter_thresholds(data, scores_threshold, matches_threshold)
104 | |
105 | |
@cli.command()
@click.option('--k', 'k', type=int, required=True)
@click.pass_context
def get_top_k_data(ctx, k):
    # Subcommand: keep the k best-scoring rows per first-level index label.
    # The returned frame becomes the group's result value.
    data = ctx.obj['data']
    return get_top_k_matches(data, k)
112 | |
113 | |
def _sniff_delimiter(filename: str) -> str:
    """Return the field delimiter detected in the first line of ``filename``.

    This uses ``csv.Sniffer`` on the first line, which is what
    ``read_csv(sep=None)`` does with the python engine. It avoids reaching
    into pandas' private reader internals.
    """
    import csv

    with open(filename, newline='', encoding='utf-8') as handle:
        first_line = handle.readline()
    return csv.Sniffer().sniff(first_line).delimiter


# Click 8.0 renamed `resultcallback` to `result_callback`, and Click 8.1
# removed the old name. Resolve whichever one this click version provides.
_result_callback = getattr(cli, 'result_callback', None) or cli.resultcallback


@_result_callback()
def write_output(result: DataFrame, scores_filename, matches_filename, output_filename):
    """Write a subcommand's result table to ``output_filename``.

    The output reuses the delimiter sniffed from the scores input file. The
    row index is flattened into columns, with the unnamed first level
    renamed to ``query`` and ``compound`` renamed to ``reference``.

    Args:
        result (DataFrame): Table returned by the invoked subcommand.
        scores_filename: Path to the scores table (used to pick the delimiter).
        matches_filename: Path to the matches table (unused here).
        output_filename: Destination path for the output table.
    """
    # Previously read_csv(..., iterator=True) was used here, which left the
    # file handle open and relied on the private `_engine` attribute.
    sep = _sniff_delimiter(scores_filename)

    table = result.reset_index().rename(columns={'level_0': 'query', 'compound': 'reference'})
    table.to_csv(output_filename, sep=sep, index=False)
121 | |
122 | |
if __name__ == '__main__':
    # Start click with an empty dict as the shared context object; the group
    # callback stores the loaded table in it.
    cli(obj=dict())