Mercurial > repos > chemteam > biomd_rmsd_clustering
comparison rmsd_clustering.py @ 0:ee1f38eb220e draft
"planemo upload for repository https://github.com/galaxycomputationalchemistry/galaxy-tools-compchem/ commit 1b23e024af45cc0999d9142d07de6897d4189ec2"
author | chemteam |
---|---|
date | Mon, 24 Aug 2020 06:08:17 -0400 |
parents | |
children | b9c46dbe9605 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:ee1f38eb220e |
---|---|
1 import argparse | |
2 import json | |
3 | |
4 import matplotlib.pyplot as plt | |
5 | |
6 import numpy as np | |
7 | |
8 from scipy.cluster.hierarchy import cophenet, dendrogram, linkage | |
9 from scipy.spatial.distance import pdist | |
10 | |
11 | |
def json_to_np(fname, start=None, end=None):
    """
    Load a JSON file and return it as a numpy array sliced on its third axis.

    The parsed JSON is converted to an array and indexed as
    ``[:, :, start:end]``, so it must describe an array with at least three
    dimensions. The shape of the sliced array is printed to stdout.

    :param fname: path of the JSON file to read
    :param start: first index to keep along the third axis (None: beginning)
    :param end: index at which to stop along the third axis (None: the end)
    :return: the sliced numpy array
    """
    with open(fname) as f:
        # Convert and slice once, then reuse the result for both the
        # shape report and the return value.
        tensor = np.array(json.load(f))[:, :, start:end]
    print(tensor.shape)
    return tensor
20 | |
21 | |
def flatten_tensor(tensor, normalize=True):
    """
    Collapse a 3D tensor into a 2D matrix by averaging over its third axis.

    When ``normalize`` is true, every slice along the third axis is first
    divided by the mean of that slice (taken over the first two axes);
    otherwise the values are averaged unchanged.

    :param tensor: 3D array to reduce
    :param normalize: whether to rescale each third-axis slice by its mean
    :return: 2D array of averages over the third axis
    """
    if normalize:
        # One scaling factor per position on the third axis
        scale = np.mean(tensor, axis=(0, 1))
    else:
        scale = 1
    rescaled = tensor/scale
    return np.mean(rescaled, axis=2)
28 | |
29 | |
def get_cluster_linkage_array(mat, clustering_method='average'):
    """
    Build the hierarchical clustering linkage array for ``mat``.

    The cophenetic correlation coefficient of the resulting linkage, measured
    against the pairwise distances of the rows of ``mat``, is printed to
    stdout.

    :param mat: input matrix handed to scipy's ``linkage`` and ``pdist``
    :param clustering_method: linkage method name understood by scipy
    :return: the linkage array produced by ``linkage``
    """
    # NOTE(review): mat is passed through unchanged; scipy interprets a 2D
    # input as one observation vector per row - confirm this is intended
    # when mat is itself a square distance matrix.
    linkage_array = linkage(mat, clustering_method)
    row_distances = pdist(mat)
    # Only the correlation coefficient is reported; the cophenetic distance
    # vector returned alongside it is not needed.
    coefficient, _ = cophenet(linkage_array, row_distances)
    print('Cophenetic correlation coefficient: {}'.format(coefficient))
    return linkage_array
35 | |
36 | |
def plot_dist_mat(mat, output, cmap='plasma'):
    """
    Plot distance matrix as heatmap and save it as a PNG image

    :param mat: 2D matrix to draw with ``pcolormesh``
    :param output: path (or file object) the PNG is written to
    :param cmap: name of the matplotlib colormap used for the heatmap
    """
    fig, ax = plt.subplots(1)
    try:
        mesh = ax.pcolormesh(mat, cmap=cmap)
        ax.set_xlabel('Trajectory number')
        ax.set_ylabel('Trajectory number')
        fig.colorbar(mesh, ax=ax)
        fig.savefig(output, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
48 | |
49 | |
def plot_dendrogram(Z, output):
    """
    Plot the dendrogram of a linkage array and save it as a PNG image

    :param Z: linkage array to draw with scipy's ``dendrogram``
    :param output: path (or file object) the PNG is written to
    """
    fig = plt.figure(figsize=(25, 10))
    try:
        plt.title('Hierarchical Clustering Dendrogram')
        plt.xlabel('Trajectory index')
        plt.ylabel('distance')
        dendrogram(
            Z,
            leaf_rotation=90.,  # rotates the x axis labels
            leaf_font_size=8.,  # font size for the x axis labels
        )
        fig.savefig(output, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
61 | |
62 | |
def main(argv=None):
    """
    Command-line entry point: build a matrix, cluster it and write outputs.

    The matrix comes either from a JSON 3D tensor (``--json``, averaged over
    its third axis) or from a tabular 2D matrix (``--mat``). The cluster
    linkage array is always written to ``--Z``; the flattened matrix, the
    heatmap and the dendrogram are written only when their paths are given.

    :param argv: list of argument strings to parse; None (the default)
                 makes argparse read ``sys.argv[1:]``
    :raises SystemExit: with status 1 when neither --json nor --mat is given
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--json', help='JSON input file (for 3D tensor).')
    parser.add_argument('--mat', help='Input tabular file (for 2D matrix).')
    parser.add_argument('--outp-mat', help='Tabular output file.')
    parser.add_argument('--Z', required=True,
                        help='File for cluster linkage array.')
    parser.add_argument('--dendrogram',
                        help="Path to the output dendrogram file")
    parser.add_argument('--heatmap',
                        help="Path to the output distance matrix file")
    parser.add_argument('--clustering-method', default='average',
                        choices=['single', 'complete', 'average',
                                 'centroid', 'median', 'ward', 'weighted'],
                        help="Method to use for clustering.")
    # Multi-line help texts rely on implicit string concatenation, so each
    # continued fragment must end with a space.
    parser.add_argument('--cmap', type=str, default='plasma',
                        help="Matplotlib colormap to use "
                             "for plotting distance matrix.")
    parser.add_argument('--start', type=int,
                        help="First trajectory frame to "
                             "calculate distance matrix")
    parser.add_argument('--end', type=int,
                        help="Last trajectory frame to "
                             "calculate distance matrix")
    parser.add_argument('--normalize', action="store_true",
                        help="Normalize the RMSD variation over "
                             "the trajectories before averaging.")
    args = parser.parse_args(argv)

    print(args)
    if args.json:
        tensor = json_to_np(args.json, args.start, args.end)
        mat = flatten_tensor(tensor, args.normalize)
        # --outp-mat is optional: only write the matrix when a path is given
        if args.outp_mat:
            np.savetxt(args.outp_mat, mat)
    elif args.mat:
        mat = np.loadtxt(args.mat)
    else:
        print("Either --json or --mat must be specified.")
        # Raise directly rather than calling exit(), which only exists
        # when the site module has been loaded.
        raise SystemExit(1)

    Z = get_cluster_linkage_array(mat, args.clustering_method)
    np.savetxt(args.Z, Z)

    if args.heatmap:
        plot_dist_mat(mat, args.heatmap, args.cmap)

    if args.dendrogram:
        plot_dendrogram(Z, args.dendrogram)
111 | |
112 | |
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()