comparison rmsd_clustering.py @ 0:ee1f38eb220e draft

"planemo upload for repository https://github.com/galaxycomputationalchemistry/galaxy-tools-compchem/ commit 1b23e024af45cc0999d9142d07de6897d4189ec2"
author chemteam
date Mon, 24 Aug 2020 06:08:17 -0400
parents
children b9c46dbe9605
comparison
equal deleted inserted replaced
-1:000000000000 0:ee1f38eb220e
1 import argparse
2 import json
3
4 import matplotlib.pyplot as plt
5
6 import numpy as np
7
8 from scipy.cluster.hierarchy import cophenet, dendrogram, linkage
9 from scipy.spatial.distance import pdist
10
11
def json_to_np(fname, start=None, end=None):
    """
    Load json file and convert to numpy array

    The decoded JSON content is turned into a numpy array and sliced
    with ``start:end`` along its third axis, so it must convert to an
    array with at least three dimensions. The shape of the sliced array
    is printed to stdout.

    :param fname: path of the JSON file to read
    :param start: first index to keep along the third axis
                  (``None`` starts at the beginning)
    :param end: index at which to stop along the third axis, exclusive
                (``None`` runs to the end)
    :return: the sliced numpy array
    """
    with open(fname) as f:
        k = json.load(f)
    # Convert and slice once, then reuse the result for both the shape
    # printout and the return value instead of building the array twice.
    arr = np.array(k)[:, :, start:end]
    print(arr.shape)
    return arr
20
21
def flatten_tensor(tensor, normalize=True):
    """
    Flatten tensor to a 2D matrix along the time axis

    The third axis (axis 2) is averaged away. With ``normalize`` set,
    every slice along that axis is first divided by its own mean, taken
    over axes 0 and 1, so that each slice contributes with equal weight.

    :param tensor: 3D array or array-like (e.g. nested lists)
    :param normalize: divide each axis-2 slice by its mean before
                      averaging
    :return: 2D numpy array holding the mean over axis 2
    """
    # Accept any array-like: a plain nested list cannot be divided by the
    # scalar 1 used when normalization is switched off.
    tensor = np.asarray(tensor)
    # NOTE: a slice whose mean is 0 produces inf/nan entries here, since
    # numpy division by zero does not raise.
    av = np.mean(tensor, axis=(0, 1)) if normalize else 1
    return np.mean(tensor / av, axis=2)
28
29
def get_cluster_linkage_array(mat, clustering_method='average'):
    """
    Compute the hierarchical clustering linkage array for a matrix

    The cophenetic correlation coefficient of the clustering is printed
    to stdout as a quality measure.

    :param mat: 2D matrix handed to scipy's ``linkage`` and ``pdist``
    :param clustering_method: linkage method name understood by scipy
    :return: the linkage array produced by ``linkage``
    """
    # NOTE(review): a 2D input is interpreted by scipy as one observation
    # vector per row, not as a precomputed distance matrix - confirm this
    # is intended when mat already holds pairwise distances.
    Z = linkage(mat, clustering_method)
    # Only the correlation coefficient is reported; the cophenetic
    # distances returned alongside it are not needed.
    c, _ = cophenet(Z, pdist(mat))
    print('Cophenetic correlation coefficient: {}'.format(c))
    return Z
35
36
def plot_dist_mat(mat, output, cmap='plasma'):
    """
    Plot distance matrix as heatmap

    :param mat: 2D matrix to draw with ``pcolormesh``
    :param output: path (or file object) the PNG image is written to
    :param cmap: name of the matplotlib colormap
    """
    fig, ax = plt.subplots(1)
    try:
        p = ax.pcolormesh(mat, cmap=cmap)
        # Label through the Axes object rather than pyplot's implicit
        # "current axes" so the plot does not depend on global state.
        ax.set_xlabel('Trajectory number')
        ax.set_ylabel('Trajectory number')
        fig.colorbar(p, ax=ax)
        fig.savefig(output, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
48
49
def plot_dendrogram(Z, output):
    """
    Plot a hierarchical clustering dendrogram

    :param Z: linkage array as accepted by scipy's ``dendrogram``
    :param output: path (or file object) the PNG image is written to
    """
    fig = plt.figure(figsize=(25, 10))
    try:
        ax = fig.gca()
        ax.set_title('Hierarchical Clustering Dendrogram')
        ax.set_xlabel('Trajectory index')
        ax.set_ylabel('distance')
        dendrogram(
            Z,
            ax=ax,  # draw on this figure, not pyplot's current axes
            leaf_rotation=90.,  # rotates the x axis labels
            leaf_font_size=8.,  # font size for the x axis labels
        )
        fig.savefig(output, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
61
62
def main(argv=None):
    """
    Command-line entry point

    Reads either a JSON tensor (``--json``), which is flattened to a 2D
    matrix, or a ready-made 2D matrix (``--mat``); computes the cluster
    linkage array, writes it to ``--Z`` and optionally renders a heatmap
    and a dendrogram.

    :param argv: list of command-line arguments; ``None`` makes argparse
                 fall back to ``sys.argv[1:]``
    :raises SystemExit: when neither ``--json`` nor ``--mat`` is given
                        (exit status 1), or on argparse errors
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--json', help='JSON input file (for 3D tensor).')
    parser.add_argument('--mat', help='Input tabular file (for 2D matrix).')
    parser.add_argument('--outp-mat', help='Tabular output file.')
    parser.add_argument('--Z', required=True,
                        help='File for cluster linkage array.')
    parser.add_argument('--dendrogram',
                        help="Path to the output dendrogram file")
    parser.add_argument('--heatmap',
                        help="Path to the output distance matrix file")
    parser.add_argument('--clustering-method', default='average',
                        choices=['single', 'complete', 'average',
                                 'centroid', 'median', 'ward', 'weighted'],
                        help="Method to use for clustering.")
    # Adjacent string literals are concatenated verbatim, so each
    # continued help text needs an explicit trailing space.
    parser.add_argument('--cmap', type=str, default='plasma',
                        help="Matplotlib colormap to use "
                             "for plotting distance matrix.")
    parser.add_argument('--start', type=int,
                        help="First trajectory frame to "
                             "calculate distance matrix")
    parser.add_argument('--end', type=int,
                        help="Last trajectory frame to "
                             "calculate distance matrix")
    parser.add_argument('--normalize', action="store_true",
                        help="Normalize the RMSD variation over "
                             "the trajectories before averaging.")
    args = parser.parse_args(argv)

    print(args)
    if args.json:
        tensor = json_to_np(args.json, args.start, args.end)
        mat = flatten_tensor(tensor, args.normalize)
        # --outp-mat is optional; only write the flattened matrix when a
        # destination was actually supplied.
        if args.outp_mat:
            np.savetxt(args.outp_mat, mat)
    elif args.mat:
        mat = np.loadtxt(args.mat)
    else:
        # A string payload is printed to stderr and gives exit status 1,
        # without relying on the site-provided exit() helper.
        raise SystemExit("Either --json or --mat must be specified.")

    Z = get_cluster_linkage_array(mat, args.clustering_method)
    np.savetxt(args.Z, Z)

    if args.heatmap:
        plot_dist_mat(mat, args.heatmap, args.cmap)

    if args.dendrogram:
        plot_dendrogram(Z, args.dendrogram)
111
112
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()