Mercurial > repos > muon-spectroscopy-computational-project > larch_plot
comparison larch_plot.py @ 4:35d24102cefd draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 3fe6078868efd0fcea0fb5eea8dcd4b152d9c0a8
| author | muon-spectroscopy-computational-project |
|---|---|
| date | Thu, 11 Apr 2024 09:02:24 +0000 |
| parents | 5b993aff09e3 |
| children |
comparison
equal
deleted
inserted
replaced
| 3:5b993aff09e3 | 4:35d24102cefd |
|---|---|
| 1 import json | 1 import json |
| 2 import sys | 2 import sys |
| 3 | 3 |
| 4 from common import read_groups | 4 from common import read_groups |
| 5 | |
| 6 from larch.symboltable import Group | |
| 5 | 7 |
| 6 import matplotlib | 8 import matplotlib |
| 7 import matplotlib.pyplot as plt | 9 import matplotlib.pyplot as plt |
| 8 | 10 |
| 9 import numpy as np | 11 import numpy as np |
| 10 | 12 |
| 11 | 13 |
| 12 AXIS_LABELS = { | 14 AXIS_LABELS = { |
| 15 "energy": "Energy (eV)", | |
| 16 "distance": "r (ang)", | |
| 17 "sample": "Sample", | |
| 13 "flat": r"x$\mu$(E), flattened", | 18 "flat": r"x$\mu$(E), flattened", |
| 14 "dmude": r"d(x$\mu$(E))/dE, normalised", | 19 "dmude": r"d(x$\mu$(E))/dE, normalised", |
| 15 "chir_mag": r"|$\chi$(r)|", | 20 "chir_mag": r"|$\chi$(r)|", |
| 16 "energy": "Energy (eV)", | 21 "e0": "Edge Energy (eV)", |
| 17 "distance": "r (ang)", | |
| 18 } | 22 } |
| 23 | |
| 24 | |
| 25 def sample_plot(groups: "list[Group]", y_variable: str): | |
| 26 x = [get_label(group) for group in groups] | |
| 27 y = [getattr(group, y_variable) for group in groups] | |
| 28 plt.scatter(x, y) | |
| 19 | 29 |
| 20 | 30 |
| 21 def main(dat_files: "list[str]", plot_settings: "list[dict]"): | 31 def main(dat_files: "list[str]", plot_settings: "list[dict]"): |
| 22 groups = list(read_groups(dat_files)) | 32 groups = list(read_groups(dat_files)) |
| 23 | 33 |
| 24 for i, settings in enumerate(plot_settings): | 34 for i, settings in enumerate(plot_settings): |
| 25 data_list = [] | 35 data_list = [] |
| 26 x_variable = "energy" | |
| 27 y_variable = settings["variable"]["variable"] | 36 y_variable = settings["variable"]["variable"] |
| 28 x_min = settings["variable"]["x_limit_min"] | |
| 29 x_max = settings["variable"]["x_limit_max"] | |
| 30 y_min = settings["variable"]["y_limit_min"] | |
| 31 y_max = settings["variable"]["y_limit_max"] | |
| 32 plot_path = f"plots/{i}_{y_variable}.png" | 37 plot_path = f"plots/{i}_{y_variable}.png" |
| 33 plt.figure() | 38 plt.figure() |
| 34 | 39 |
| 35 for group in groups: | 40 if y_variable == "e0": |
| 36 params = group.athena_params | 41 x_variable = "sample" |
| 37 annotation = getattr(params, "annotation", None) | 42 sample_plot(groups, y_variable) |
| 38 file = getattr(params, "file", None) | 43 else: |
| 39 params_id = getattr(params, "id", None) | 44 x_variable = "energy" |
| 40 label = annotation or file or params_id | 45 x_min = settings["variable"]["x_limit_min"] |
| 41 if y_variable == "chir_mag": | 46 x_max = settings["variable"]["x_limit_max"] |
| 42 x_variable = "distance" | 47 y_min = settings["variable"]["y_limit_min"] |
| 43 x = group.r | 48 y_max = settings["variable"]["y_limit_max"] |
| 44 else: | 49 for group in groups: |
| 45 x = group.energy | 50 label = get_label(group) |
| 51 if y_variable == "chir_mag": | |
| 52 x_variable = "distance" | |
| 53 x = group.r | |
| 54 else: | |
| 55 x = group.energy | |
| 46 | 56 |
| 47 y = getattr(group, y_variable) | 57 y = getattr(group, y_variable) |
| 48 if x_min is None and x_max is None: | 58 if x_min is None and x_max is None: |
| 49 plt.plot(x, y, label=label) | 59 plt.plot(x, y, label=label) |
| 50 else: | 60 else: |
| 51 data_list.append({"x": x, "y": y, "label": label}) | 61 data_list.append({"x": x, "y": y, "label": label}) |
| 52 | 62 |
| 53 if x_min is not None or x_max is not None: | 63 if x_min is not None or x_max is not None: |
| 54 for data in data_list: | 64 for data in data_list: |
| 55 index_min = None | 65 index_min = None |
| 56 index_max = None | 66 index_max = None |
| 57 x = data["x"] | 67 x = data["x"] |
| 58 if x_min is not None: | 68 if x_min is not None: |
| 59 index_min = max(np.searchsorted(x, x_min) - 1, 0) | 69 index_min = max(np.searchsorted(x, x_min) - 1, 0) |
| 60 if x_max is not None: | 70 if x_max is not None: |
| 61 index_max = min(np.searchsorted(x, x_max) + 1, len(x)) | 71 index_max = min(np.searchsorted(x, x_max) + 1, len(x)) |
| 62 plt.plot( | 72 plt.plot( |
| 63 x[index_min:index_max], | 73 x[index_min:index_max], |
| 64 data["y"][index_min:index_max], | 74 data["y"][index_min:index_max], |
| 65 label=data["label"], | 75 label=data["label"], |
| 66 ) | 76 ) |
| 67 | 77 |
| 68 plt.xlim(x_min, x_max) | 78 plt.xlim(x_min, x_max) |
| 69 plt.ylim(y_min, y_max) | 79 plt.ylim(y_min, y_max) |
| 70 | 80 |
| 71 save_plot(x_variable, y_variable, plot_path) | 81 save_plot(x_variable, y_variable, plot_path) |
| 82 | |
| 83 | |
| 84 def get_label(group: Group) -> str: | |
| 85 params = group.athena_params | |
| 86 annotation = getattr(params, "annotation", None) | |
| 87 file = getattr(params, "file", None) | |
| 88 params_id = getattr(params, "id", None) | |
| 89 label = annotation or file or params_id | |
| 90 return label | |
| 72 | 91 |
| 73 | 92 |
| 74 def save_plot(x_type: str, y_type: str, plot_path: str): | 93 def save_plot(x_type: str, y_type: str, plot_path: str): |
| 75 plt.grid(color="r", linestyle=":", linewidth=1) | 94 plt.grid(color="r", linestyle=":", linewidth=1) |
| 76 plt.xlabel(AXIS_LABELS[x_type]) | 95 plt.xlabel(AXIS_LABELS[x_type]) |
