diff light_curve.py @ 0:f40d05521dca draft default tip

planemo upload for repository https://github.com/esg-epfl-apc/tools-astro/tree/main/tools commit de01e3c02a26cd6353a6b9b6f8d1be44de8ccd54
author astroteam
date Fri, 25 Apr 2025 19:33:20 +0000
parents
children
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/light_curve.py	Fri Apr 25 19:33:20 2025 +0000
@@ -0,0 +1,292 @@
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+
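+# Light-travel time across 1 Mpc expressed in hours
+# (~3.0857e22 m / c ≈ 1.03e14 s ≈ 2.86e10 h); used below to convert
+# the delay column of the Monte Carlo file into hours.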
+Mpc_in_h = 2.8590868063e10
+
+
+def weighted_quantile(
+    values, quantiles, sample_weight=None, values_sorted=False, old_style=False
+):
+    """Very close to numpy.percentile, but supports weights.
+    NOTE: quantiles should be in [0, 1]!
+    :param values: numpy.array with data
+    :param quantiles: array-like with many quantiles needed
+    :param sample_weight: array-like of the same length as `array`
+    :param values_sorted: bool, if True, then will avoid sorting of initial array
+    :param old_style: if True, will correct output to be consistent with numpy.percentile.
+    :return: numpy.array with computed quantiles.
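+
+    Example (illustrative; equal weights reduce to the unweighted case):
+        weighted_quantile([1, 2, 3, 4], [0.5], sample_weight=[1, 1, 1, 1])
+        # -> 2.5, the weighted median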
+    """
+    values = np.array(values)
+    quantiles = np.array(quantiles)
+    if sample_weight is None:
+        sample_weight = np.ones(len(values))
+    sample_weight = np.array(sample_weight)
+    assert np.all(quantiles >= 0) and np.all(
+        quantiles <= 1
+    ), "quantiles should be in [0, 1]"
+
+    if not values_sorted:
+        sorter = np.argsort(values)
+        values = values[sorter]
+        sample_weight = sample_weight[sorter]
+
+    weighted_quantiles = np.cumsum(sample_weight) - 0.5 * sample_weight
+    if old_style:
+        # To be consistent with np.percentile
+        weighted_quantiles -= weighted_quantiles[0]
+        weighted_quantiles /= weighted_quantiles[-1]
+    else:
+        weighted_quantiles /= np.sum(sample_weight)
+    return np.interp(quantiles, weighted_quantiles, values)
+
+
+def get_distance_Mpc(mc_file):
+    # reading the source distance in Mpc from the data file
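+    # The header is assumed to contain a token sequence like
+    # "... T/Mpc = <value> ...": the value is read two tokens after the key.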
+    with open(mc_file) as f_mc_lines:
+        for line in f_mc_lines:
+            if line.startswith("#"):
+                cols = line.split()
+                if "T/Mpc" in cols:
+                    idx = cols.index("T/Mpc")
+                    return float(cols[idx + 2])
+    raise ValueError('Unexpected mc file format: "T/Mpc" not found')
+
+
+default_params = {
+    "weight_col": 2,
+    "E_src_col": 8,
+    "delay_col": 6,
+    "comoving_distance": None,
+    "n_points": 100,
+    "logscale": False,
+    "suffix": "",
+    "Emin": 1e6,
+    "Emax": 1e20,
+    "psf": 1.0,
+    "cut_0": False,
+    "rounding_error": 0.0007,  # 1./600,
+    "min_n_particles": 100,
+    "out_ext": "png",
+    "max_t": 0.05,  # use negative for percentile
+    "show": False,
+    "add_alpha": 0.0,  # multiply weights be E_src^{-add_alpha}
+    # 'frac_time': 2000/3600,  # calculate fraction of flux coming within this time in hours
+    "verbose": 2,
+    "format": ".2g",  # float format
+}
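+
+# Any of the defaults above can be overridden per call via keyword arguments
+# (or on the command line, see __main__ below), e.g. a hypothetical call:
+#   make_plot("events_rotated.dat", Emin=1e9, logscale=True, max_t=-90)
+# "events_rotated.dat" is a placeholder file name, not one shipped with the tool.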
+
+
+def filter_data(data, **kwargs):
+    params = default_params.copy()
+    params.update(kwargs)
+
+    Emax = params["Emax"]
+    Emin = params["Emin"]
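+    # column 0 is the particle energy; keep only events within [Emin, Emax]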
+    data = data[data[:, 0] <= Emax]
+    data = data[data[:, 0] >= Emin]
+    psf = params["psf"]
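+    # column 2 is the angular separation from the source direction;
+    # keep only events inside the `psf` acceptance cut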
+    data = data[data[:, 2] <= psf]
+    cut0 = params["cut_0"]
+    if cut0:
+        col = params["delay_col"] - 1
+        data = data[data[:, col] != 0.0]
+    return data
+
+
+def get_counts(rotated_mc_file, **kwargs):
+    params = default_params.copy()
+    params.update(kwargs)
+    data = np.loadtxt(rotated_mc_file)
+    data_filtered = filter_data(data, **kwargs)
+    verbose = params["verbose"]
+    if verbose > 1:
+        print(len(data_filtered), "of", len(data), "events passed the filter")
+
+    weight_col = params["weight_col"] - 1
+    col = params["delay_col"] - 1
+    comoving_distance = params["comoving_distance"]
+    if not comoving_distance:
+        comoving_distance = get_distance_Mpc(rotated_mc_file)
+
+    x_scale = comoving_distance * Mpc_in_h  # convert to hours
+
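+    # the delay column is stored in units of the source distance, so scaling by
+    # the comoving distance (Mpc) and Mpc_in_h gives the arrival delay in hours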
+    delay = data_filtered[:, col] * x_scale
+
+    idxs = np.argsort(delay)
+    delay = delay[idxs]
+
+    if weight_col >= 0:
+        assert weight_col < data_filtered.shape[1]
+        weights = data_filtered[idxs, weight_col]
+    else:
+        weights = np.ones(len(idxs))
+
+    add_alpha = params["add_alpha"]
+    if add_alpha != 0:
+        # keep the same delay ordering as the weights;
+        # E_src_col is 1-based, like weight_col and delay_col
+        E_src = data_filtered[idxs, params["E_src_col"] - 1]
+        av_Esrc = np.exp(np.log(E_src).mean())
+        weights *= np.power(E_src / av_Esrc, -add_alpha)
+
+    return delay, weights
+
+
+def light_curve(delay, weights, **kwargs):
+    params = default_params.copy()
+    params.update(kwargs)
+    min_n_particles = params["min_n_particles"]
+    min_bin_size = params["rounding_error"]
+    max_t = params["max_t"]
+
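+    # a negative max_t is read as a percentile of the delay distribution,
+    # e.g. max_t = -95 limits the light curve to the 95% weighted quantile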
+    if max_t < 0:
+        max_t = weighted_quantile(
+            delay, [-0.01 * max_t], sample_weight=weights)[0]
+
+    f = []
+    t = []
+    N = []
+
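+    # first bin: events whose delay is below the rounding error of the stored
+    # delays are merged into one initial bin (padded to min_n_particles if needed)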
+    bin_idx = 0
+    if delay[0] < min_bin_size:
+        bin_idx = np.where(delay < min_bin_size)[0][-1]
+        if bin_idx + 1 < min_n_particles:
+            bin_idx = min_n_particles
+        wsum = np.sum(weights[:bin_idx])
+        _t = np.sum(delay[:bin_idx] * weights[:bin_idx]) / wsum
+        _t = max(_t, 0.5 * min_bin_size)
+        t.append(_t)
+        bin_size = max(delay[bin_idx] - delay[0], min_bin_size)
+        f.append(wsum / bin_size)
+        N.append(bin_idx)
+
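+    # main loop: grow each bin until it spans at least min_bin_size hours and
+    # holds at least min_n_particles events; f is the summed weight per unit
+    # time, t the weighted mean arrival time of the bin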
+    while True:
+        xmin = (
+            0.5 * (delay[bin_idx - 1] + delay[bin_idx])
+            if bin_idx > 0
+            else delay[bin_idx]
+        )
+        if xmin > max_t:
+            break
+        bin_idx2 = np.where(
+            delay[bin_idx + min_n_particles:] > xmin + min_bin_size)[0]
+        if len(bin_idx2) == 0:
+            break
+        bin_idx2 = bin_idx2[0] + bin_idx + min_n_particles
+        _delay = delay[bin_idx:bin_idx2]
+        _weights = weights[bin_idx:bin_idx2]
+        wsum = _weights.sum()
+        t.append(np.sum(_delay * _weights) / wsum)
+        xmax = 0.5 * (delay[bin_idx2 - 1] + delay[bin_idx2])
+        f.append(wsum / (xmax - xmin))
+        N.append(bin_idx2 - bin_idx)
+        bin_idx = bin_idx2
+
+    return [np.array(x) for x in (t, f, N)]
+
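+# A minimal usage sketch (file name hypothetical, not part of this tool):
+#   delay, weights = get_counts("events_rotated.dat", Emin=1e9, psf=0.5)
+#   t, f, N = light_curve(delay, weights, max_t=-90)
+# t and f can then be plotted directly, which is what make_plot() below does.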
+
+def make_plot(path, label="verbose", **kwargs):
+    params = default_params.copy()
+    params.update(kwargs)
+    path = os.path.expanduser(path)
+    fmt = params["format"]
+    col = params["delay_col"] - 1
+    logscale = params["logscale"]
+    verbose = params["verbose"]
+    Emax = params["Emax"]
+    Emin = params["Emin"]
+    psf = params["psf"]
+    cut0 = params["cut_0"]
+
+    delay, weights = get_counts(path, **kwargs)
+    t, f, _ = light_curve(delay, weights, **kwargs)
+
+    suffix = params["suffix"]
+
+    if cut0:
+        suffix += "_cut0"
+    if logscale:
+        suffix += "_log"
+
+    max_t = params["max_t"]
+    out_name = (
+        f"{path}_f{col + 1}_E{Emin:{fmt}}-{Emax:{fmt}}TeV_th{psf}_r{max_t}{suffix}."
+    )
+
+    x_label = "t, [h]"
+    y_label = "dN/dt [a.u.]"
+
+    if verbose > 1:
+        for key, val in params.items():
+            print(f"\t{key} = {val}")
+
+    if max_t < 0:
+        median, max_t = weighted_quantile(
+            delay, [0.5, -0.01 * max_t], sample_weight=weights
+        )
+    else:
+        median = weighted_quantile(delay, [0.5], sample_weight=weights)[0]
+
+    # the histogram of the data
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+
+    if label == "verbose":
+        label = f"50% with delay<{median:{fmt}} h"
+    plot_args = {}
+    if label:
+        plot_args["label"] = label
+
+    plt.plot(t, f, "g-", **plot_args)
+
+    if logscale:
+        plt.xscale("log")
+        plt.yscale("log")
+
+    if label:
+        plt.legend(loc="upper right")
+
+    file_name = out_name + params["out_ext"].strip()
+    plt.savefig(file_name)
+    if verbose > 0:
+        print("saved to", file_name)
+    if params["show"]:
+        plt.show()
+
+    return file_name
+
+
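+# Command-line usage: the first argument is the event file, followed by
+# key=value overrides of default_params, e.g. (hypothetical file name):
+#   python light_curve.py events_rotated.dat Emin=1e9 logscale=True max_t=-90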
+if __name__ == "__main__":
+
+    def usage_exit(reason=""):
+        print(reason, file=sys.stderr)
+        print(
+            "usage: python",
+            sys.argv[0],
+            "data_file",
+            *["{}={}".format(k, v) for k, v in default_params.items()],
+            file=sys.stderr,
+        )
+        sys.exit(1)
+
+    if len(sys.argv) < 2:
+        usage_exit()
+
+    kwargs = {}
+    for par in sys.argv[2:]:
+        kv = par.split("=")
+        if len(kv) != 2:
+            usage_exit("invalid argument " + par)
+        if kv[0] not in default_params:
+            usage_exit("unknown parameter " + kv[0])
+        # fall back to float for parameters whose default is None
+        # (e.g. comoving_distance)
+        default = default_params[kv[0]]
+        constr = type(default) if default is not None else float
+        value = kv[1]
+        if constr == bool:
+            # "false" or "0" (case-insensitive) -> False, anything else -> True
+            value = value.lower() not in ("false", "0")
+        else:
+            value = constr(value)
+        kwargs[kv[0]] = value
+
+    make_plot(sys.argv[1], **kwargs)