Mercurial > repos > muon-spectroscopy-computational-project > larch_athena
comparison larch_athena.py @ 3:82e9dd980916 draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit d4c7e090dc5c94395d7e1574845ac2c76f2e4f5f
| author | muon-spectroscopy-computational-project |
|---|---|
| date | Fri, 22 Mar 2024 14:23:27 +0000 |
| parents | a1e26990131c |
| children | a0d3b0fe0fa3 |
comparison
equal
deleted
inserted
replaced
| 2:a1e26990131c | 3:82e9dd980916 |
|---|---|
| 3 import os | 3 import os |
| 4 import re | 4 import re |
| 5 import sys | 5 import sys |
| 6 | 6 |
| 7 from common import ( | 7 from common import ( |
| 8 pre_edge_with_defaults, read_all_groups, read_group, xftf_with_defaults | 8 pre_edge_with_defaults, |
| | 9 read_all_groups, |
| | 10 read_group, |
| | 11 xftf_with_defaults, |
| 9 ) | 12 ) |
| 10 | 13 |
| 11 from larch.io import ( | 14 from larch.io import ( |
| 12 create_athena, | 15 create_athena, |
| 13 h5group, | 16 h5group, |
| 43 merge_inputs: bool, | 46 merge_inputs: bool, |
| 44 is_zipped: bool, | 47 is_zipped: bool, |
| 45 ) -> "dict[str, Group]": | 48 ) -> "dict[str, Group]": |
| 46 if merge_inputs: | 49 if merge_inputs: |
| 47 out_group = self.merge_files( | 50 out_group = self.merge_files( |
| 48 dat_files=dat_file, is_zipped=is_zipped | 51 dat_files=dat_file, |
| | 52 is_zipped=is_zipped, |
| 49 ) | 53 ) |
| 50 return {"out": out_group} | 54 return {"out": out_group} |
| 51 else: | 55 else: |
| 52 return self.load_single_file( | 56 return self.load_single_file( |
| 53 filepath=dat_file, is_zipped=is_zipped | 57 filepath=dat_file, |
| | 58 is_zipped=is_zipped, |
| 54 ) | 59 ) |
| 55 | 60 |
| 56 def merge_files( | 61 def merge_files( |
| 57 self, | 62 self, |
| 58 dat_files: str, | 63 dat_files: str, |
| 256 do_rebin: bool, | 261 do_rebin: bool, |
| 257 do_pre_edge: bool, | 262 do_pre_edge: bool, |
| 258 pre_edge_settings: dict, | 263 pre_edge_settings: dict, |
| 259 do_xftf: bool, | 264 do_xftf: bool, |
| 260 xftf_settings: dict, | 265 xftf_settings: dict, |
| 261 plot_graph: bool, | 266 plot_graph: list, |
| 262 annotation: str, | 267 annotation: str, |
| 263 path_key: str = "out", | 268 path_key: str = "out", |
| 264 ): | 269 ): |
| 265 if do_calibrate: | 270 if do_calibrate: |
| 266 print(f"Calibrating energy with {calibrate_settings}") | 271 print(f"Calibrating energy with {calibrate_settings}") |
| 285 | 290 |
| 286 if do_xftf: | 291 if do_xftf: |
| 287 xftf_with_defaults(xas_data, xftf_settings) | 292 xftf_with_defaults(xas_data, xftf_settings) |
| 288 | 293 |
| 289 if plot_graph: | 294 if plot_graph: |
| 290 plot_edge_fits(f"edge/{path_key}.png", xas_data) | 295 plot_graphs( |
| 291 plot_flattened(f"flat/{path_key}.png", xas_data) | 296 plot_path=f"plot/{path_key}.png", |
| 292 plot_derivative(f"derivative/{path_key}.png", xas_data) | 297 xas_data=xas_data, |
| | 298 plot_keys=plot_graph, |
| | 299 ) |
| 293 | 300 |
| 294 xas_project = create_athena(f"prj/{path_key}.prj") | 301 xas_project = create_athena(f"prj/{path_key}.prj") |
| 295 xas_project.add_group(xas_data) | 302 xas_project.add_group(xas_data) |
| 296 if annotation: | 303 if annotation: |
| 297 group = next(iter(xas_project.groups.values())) | 304 group = next(iter(xas_project.groups.values())) |
| 300 | 307 |
| 301 # Ensure that we do not run out of memory when running on large zips | 308 # Ensure that we do not run out of memory when running on large zips |
| 302 gc.collect() | 309 gc.collect() |
| 303 | 310 |
| 304 | 311 |
| 305 def plot_derivative(plot_path: str, xafs_group: Group): | 312 def plot_graphs( |
| 306 plt.figure() | 313 plot_path: str, |
| 307 plt.plot(xafs_group.energy, xafs_group.dmude) | 314 xas_data: Group, |
| 308 plt.grid(color="r", linestyle=":", linewidth=1) | 315 plot_keys: list, |
| 309 plt.xlabel("Energy (eV)") | 316 ) -> None: |
| 310 plt.ylabel("Derivative normalised to x$\mu$(E)") # noqa: W605 | 317 nrows = len(plot_keys) |
| 311 plt.savefig(plot_path, format="png") | 318 index = 1 |
| 312 plt.close("all") | 319 plt.figure(figsize=(6.4, nrows * 4.8)) |
| 313 | 320 if "edge" in plot_keys: |
| 314 | 321 plt.subplot(nrows, 1, index) |
| 315 def plot_edge_fits(plot_path: str, xafs_group: Group): | 322 plt.plot(xas_data.energy, xas_data.pre_edge, "g", label="pre-edge") |
| 316 plt.figure() | 323 plt.plot(xas_data.energy, xas_data.post_edge, "r", label="post-edge") |
| 317 plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge") | 324 plt.plot(xas_data.energy, xas_data.mu, "b", label="fit data") |
| 318 plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge") | 325 plt.grid(color="r", linestyle=":", linewidth=1) |
| 319 plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data") | 326 plt.xlabel("Energy (eV)") |
| 320 plt.grid(color="r", linestyle=":", linewidth=1) | 327 plt.ylabel("x$\mu$(E)") # noqa: W605 |
| 321 plt.xlabel("Energy (eV)") | 328 plt.title("Pre-edge and post_edge fitting to $\mu$") # noqa: W605 |
| 322 plt.ylabel("x$\mu$(E)") # noqa: W605 | 329 plt.legend() |
| 323 plt.title("pre-edge and post_edge fitting to $\mu$") # noqa: W605 | 330 index += 1 |
| 324 plt.legend() | 331 |
| 325 plt.savefig(plot_path, format="png") | 332 if "flat" in plot_keys: |
| 326 plt.close("all") | 333 plt.subplot(nrows, 1, index) |
| 327 | 334 plt.plot(xas_data.energy, xas_data.flat) |
| 328 | 335 plt.grid(color="r", linestyle=":", linewidth=1) |
| 329 def plot_flattened(plot_path: str, xafs_group: Group): | 336 plt.xlabel("Energy (eV)") |
| 330 plt.figure() | 337 plt.ylabel("Flattened x$\mu$(E)") # noqa: W605 |
| 331 plt.plot(xafs_group.energy, xafs_group.flat) | 338 index += 1 |
| 332 plt.grid(color="r", linestyle=":", linewidth=1) | 339 |
| 333 plt.xlabel("Energy (eV)") | 340 if "dmude" in plot_keys: |
| 334 plt.ylabel("normalised x$\mu$(E)") # noqa: W605 | 341 plt.subplot(nrows, 1, index) |
| | 342 plt.plot(xas_data.energy, xas_data.dmude) |
| | 343 plt.grid(color="r", linestyle=":", linewidth=1) |
| | 344 plt.xlabel("Energy (eV)") |
| | 345 plt.ylabel("Derivative normalised to x$\mu$(E)") # noqa: W605 |
| | 346 index += 1 |
| | 347 |
| | 348 plt.tight_layout(rect=(0, 0, 0.88, 1)) |
| 335 plt.savefig(plot_path, format="png") | 349 plt.savefig(plot_path, format="png") |
| 336 plt.close("all") | 350 plt.close("all") |
| 337 | 351 |
| 338 | 352 |
| 339 if __name__ == "__main__": | 353 if __name__ == "__main__": |
