comparison light_curve.py @ 0:f40d05521dca draft default tip

planemo upload for repository https://github.com/esg-epfl-apc/tools-astro/tree/main/tools commit de01e3c02a26cd6353a6b9b6f8d1be44de8ccd54
author astroteam
date Fri, 25 Apr 2025 19:33:20 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:f40d05521dca
1 import os
2 import sys
3
4 import matplotlib.pyplot as plt
5 import numpy as np
6
# Conversion factor from a distance in Mpc to a time in hours (the
# light-travel time across 1 Mpc); get_counts multiplies the delay column
# by comoving_distance * Mpc_in_h to express delays in hours.
Mpc_in_h = 2.8590868063e10
8
9
def weighted_quantile(
    values, quantiles, sample_weight=None, values_sorted=False, old_style=False
):
    """Weighted analogue of numpy.percentile for a one-dimensional sample.

    NOTE: quantiles are fractions in [0, 1], not percentages.

    :param values: array-like with the data
    :param quantiles: array-like with the requested quantiles
    :param sample_weight: array-like with one weight per entry of `values`;
        unit weights are used when omitted
    :param values_sorted: bool, set to True to skip sorting when `values`
        is already in ascending order
    :param old_style: bool, if True the cumulative positions are rescaled to
        span exactly [0, 1], for consistency with numpy.percentile
    :return: numpy.array with the computed quantiles
    """
    data = np.array(values)
    probs = np.array(quantiles)
    if sample_weight is None:
        wts = np.ones(len(data))
    else:
        wts = np.array(sample_weight)
    assert np.all(probs >= 0) and np.all(
        probs <= 1
    ), "quantiles should be in [0, 1]"

    if not values_sorted:
        order = np.argsort(data)
        data, wts = data[order], wts[order]

    # Place every sample at the midpoint of its own weight interval on the
    # cumulative-weight axis.
    positions = np.cumsum(wts) - 0.5 * wts
    if old_style:
        # Stretch so the first sample sits at 0 and the last at 1.
        positions = positions - positions[0]
        positions = positions / positions[-1]
    else:
        positions = positions / np.sum(wts)
    return np.interp(probs, positions, data)
44
45
def get_distance_Mpc(mc_file):
    """Read the source distance in Mpc from the header of a data file.

    Header lines are those starting with "#". The first header line that
    contains the whitespace-separated token "T/Mpc" is used, and the value is
    taken from the token two positions after it (presumably a layout like
    ``T/Mpc = <value>`` -- confirm against the program writing the file).

    :param mc_file: path of the data file
    :return: the distance as a float
    :raises ValueError: if no header line contains "T/Mpc", or the value
        after it is missing or not a number
    """
    with open(mc_file) as f_mc_lines:
        for line in f_mc_lines:
            if not line.startswith("#"):
                continue
            cols = line.split()
            # list.index() raises ValueError instead of returning -1, so test
            # membership first and keep scanning the remaining header lines.
            if "T/Mpc" not in cols:
                continue
            idx = cols.index("T/Mpc")
            if idx + 2 >= len(cols):
                raise ValueError(
                    'Unexpected mc file format: no value after "T/Mpc"'
                )
            return float(cols[idx + 2])
    raise ValueError('Unexpected mc file format: "T/Mpc" not found')
56
57
# Default settings shared by filter_data, get_counts, light_curve and
# make_plot; each of them copies this dict and overrides entries with its
# keyword arguments. The command-line parser uses the type of each default
# to convert "key=value" arguments.
default_params = {
    # 1-based column of the particle weights; a value below 1 gives unit weights
    "weight_col": 2,
    # column of the source energy used for the add_alpha re-weighting
    # NOTE(review): used as-is (not decremented) in get_counts -- confirm indexing
    "E_src_col": 8,
    # 1-based column of the time delay
    "delay_col": 6,
    # source distance in Mpc; when empty it is read from the data file header
    "comoving_distance": None,
    "n_points": 100,
    # use logarithmic x and y axes in the plot
    "logscale": False,
    # extra text appended to the output file name
    "suffix": "",
    # rows are kept when Emin <= column 0 <= Emax
    "Emin": 1e6,
    "Emax": 1e20,
    # rows are kept when column 2 <= psf
    "psf": 1.0,
    # drop rows whose delay is exactly zero
    "cut_0": False,
    # minimal width of a light-curve bin, in the same units as the delays
    "rounding_error": 0.0007,  # 1./600,
    # minimal number of particles in a light-curve bin
    "min_n_particles": 100,
    # extension of the saved figure
    "out_ext": "png",
    "max_t": 0.05,  # use negative for percentile
    # call plt.show() after saving
    "show": False,
    "add_alpha": 0.0,  # multiply weights by E_src^{-add_alpha}
    # 'frac_time': 2000/3600,  # calculate fraction of flux coming within this time in hours
    # 0: silent, 1: report the saved file, 2: also print parameters and counts
    "verbose": 2,
    "format": ".2g",  # float format
}
80
81
def filter_data(data, **kwargs):
    """Select the rows of `data` that pass the energy, psf and delay cuts.

    A row is kept when its column 0 lies in [Emin, Emax] and its column 2
    does not exceed psf; with cut_0 set, rows whose delay column is exactly
    zero are dropped as well.

    :param data: two-dimensional numpy.array, one particle per row
    :param kwargs: overrides for entries of default_params
    :return: numpy.array with the selected rows
    """
    cfg = {**default_params, **kwargs}

    first_col = data[:, 0]
    keep = (first_col <= cfg["Emax"]) & (first_col >= cfg["Emin"])
    keep &= data[:, 2] <= cfg["psf"]
    if cfg["cut_0"]:
        delay_idx = cfg["delay_col"] - 1
        keep &= data[:, delay_idx] != 0.0
    return data[keep]
97
98
def get_counts(rotated_mc_file, **kwargs):
    """Load a data file and return the sorted time delays with their weights.

    The file is read with numpy.loadtxt, filtered with filter_data, and the
    delay column is scaled by comoving_distance * Mpc_in_h to convert it to
    hours. When comoving_distance is not given, it is read from the file
    header with get_distance_Mpc.

    :param rotated_mc_file: path of the data file
    :param kwargs: overrides for entries of default_params
    :return: tuple (delay, weights) of numpy.arrays, both ordered by
        ascending delay
    """
    params = default_params.copy()
    params.update(kwargs)
    # ndmin=2 keeps a single-row file two-dimensional, so the column slicing
    # below (and in filter_data) still works.
    data = np.loadtxt(rotated_mc_file, ndmin=2)
    data_filtered = filter_data(data, **kwargs)
    verbose = params["verbose"]
    if verbose > 1:
        print(len(data_filtered), "of", len(data), "has passed the filter")

    # Column parameters are 1-based; convert to array indices.
    weight_col = params["weight_col"] - 1
    col = params["delay_col"] - 1
    comoving_distance = params["comoving_distance"]
    if not comoving_distance:
        comoving_distance = get_distance_Mpc(rotated_mc_file)

    x_scale = comoving_distance * Mpc_in_h  # convert to hours

    delay = data_filtered[:, col] * x_scale

    idxs = np.argsort(delay)
    delay = delay[idxs]

    if weight_col >= 0:
        assert weight_col < data_filtered.shape[1]
        weights = data_filtered[idxs, weight_col]
    else:
        weights = np.ones(len(idxs))

    add_alpha = params["add_alpha"]
    if add_alpha != 0:
        # Index with idxs so every source energy stays attached to the same
        # particle as its (already delay-sorted) weight.
        # NOTE(review): E_src_col is used without the -1 applied to the other
        # column parameters -- confirm whether that is intended.
        E_src = data_filtered[idxs, params["E_src_col"]]
        # Normalise by the geometric mean of the source energies.
        av_Esrc = np.exp(np.log(E_src).mean())
        weights *= np.power(E_src / av_Esrc, -add_alpha)

    return delay, weights
134
135
def light_curve(delay, weights, **kwargs):
    """Bin weighted time delays into a light curve with adaptive bin widths.

    Every bin holds at least min_n_particles particles and is at least
    rounding_error wide. Binning stops once the lower edge of the next bin
    exceeds max_t, or when the remaining particles cannot fill another bin;
    those trailing particles are not included in the result.

    :param delay: numpy.array of time delays in ascending order
    :param weights: numpy.array of weights in the same order as `delay`
    :param kwargs: overrides for entries of default_params; a negative max_t
        is read as a percentile, i.e. the time limit becomes the weighted
        (-max_t / 100) quantile of `delay`
    :return: list of three numpy.arrays [t, f, N]: weighted mean delay of
        each bin, summed weight divided by the bin width, and number of
        particles in the bin
    """
    params = default_params.copy()
    params.update(kwargs)
    min_n_particles = params["min_n_particles"]
    min_bin_size = params["rounding_error"]
    max_t = params["max_t"]

    n_total = len(delay)
    if n_total == 0:
        # Nothing to bin: return an empty light curve instead of failing on
        # delay[0] below.
        return [np.array(x) for x in ([], [], [])]

    if max_t < 0:
        max_t = weighted_quantile(
            delay, [-0.01 * max_t], sample_weight=weights)[0]

    f = []
    t = []
    N = []

    bin_idx = 0
    if delay[0] < min_bin_size:
        # First bin: collects the delays below the minimal bin width. It
        # covers delay[:bin_idx]; delay[bin_idx] marks its upper edge and
        # belongs to the next bin.
        bin_idx = np.where(delay < min_bin_size)[0][-1]
        if bin_idx + 1 < min_n_particles:
            bin_idx = min_n_particles
        # With fewer than min_n_particles + 1 particles the first bin takes
        # the whole sample; clamp so the indexing below stays in range.
        bin_idx = min(bin_idx, n_total)
        wsum = np.sum(weights[:bin_idx])
        _t = np.sum(delay[:bin_idx] * weights[:bin_idx]) / wsum
        _t = max(_t, 0.5 * min_bin_size)
        t.append(_t)
        upper = delay[bin_idx] if bin_idx < n_total else delay[-1]
        bin_size = max(upper - delay[0], min_bin_size)
        f.append(wsum / bin_size)
        N.append(bin_idx)

    # bin_idx == n_total only when the first bin consumed every particle.
    while bin_idx < n_total:
        # Lower edge: halfway between the last particle of the previous bin
        # and the first particle of this one.
        xmin = (
            0.5 * (delay[bin_idx - 1] + delay[bin_idx])
            if bin_idx > 0
            else delay[bin_idx]
        )
        if xmin > max_t:
            break
        # First particle that lies beyond both the minimal particle count
        # and the minimal bin width; it starts the following bin.
        bin_idx2 = np.where(
            delay[bin_idx + min_n_particles:] > xmin + min_bin_size)[0]
        if len(bin_idx2) == 0:
            break
        bin_idx2 = bin_idx2[0] + bin_idx + min_n_particles
        _delay = delay[bin_idx:bin_idx2]
        _weights = weights[bin_idx:bin_idx2]
        wsum = _weights.sum()
        t.append(np.sum(_delay * _weights) / wsum)
        xmax = 0.5 * (delay[bin_idx2 - 1] + delay[bin_idx2])
        f.append(wsum / (xmax - xmin))
        N.append(bin_idx2 - bin_idx)
        bin_idx = bin_idx2

    return [np.array(x) for x in (t, f, N)]
187
188
def make_plot(path, label="verbose", **kwargs):
    """Plot the light curve of a data file and save the figure.

    The delays and weights come from get_counts and are binned with
    light_curve. The output file name is built from the input path, the
    delay column, the energy range, psf, max_t and the suffix.

    :param path: path of the data file ("~" is expanded)
    :param label: legend text; the default "verbose" is replaced by a label
        quoting the weighted median delay, an empty value disables the legend
    :param kwargs: overrides for entries of default_params
    :return: name of the saved figure file
    """
    params = {**default_params, **kwargs}
    path = os.path.expanduser(path)

    fmt = params["format"]
    verbose = params["verbose"]
    logscale = params["logscale"]
    max_t = params["max_t"]
    Emin = params["Emin"]
    Emax = params["Emax"]
    psf = params["psf"]
    col = params["delay_col"] - 1

    delay, weights = get_counts(path, **kwargs)
    t, f, _ = light_curve(delay, weights, **kwargs)

    suffix = params["suffix"]
    if params["cut_0"]:
        suffix += "_cut0"
    if logscale:
        suffix += "_log"

    out_name = (
        f"{path}_f{col + 1}_E{Emin:{fmt}}-{Emax:{fmt}}TeV_th{psf}_r{max_t}{suffix}."
    )

    if verbose > 1:
        for key, val in params.items():
            print(f"\t{key} = {val}")

    # A negative max_t is a percentile; it is evaluated together with the
    # median, but only the median is used below.
    requested = [0.5, -0.01 * max_t] if max_t < 0 else [0.5]
    median = weighted_quantile(delay, requested, sample_weight=weights)[0]

    plt.xlabel("t, [h]")
    plt.ylabel("dN/dt [a.u.]")

    if label == "verbose":
        label = f"50% with delay<{median:{fmt}} h"
    plot_args = {"label": label} if label else {}

    plt.plot(t, f, "g-", **plot_args)

    if logscale:
        plt.xscale("log")
        plt.yscale("log")

    if label:
        plt.legend(loc="upper right")

    file_name = out_name + params["out_ext"].strip()
    plt.savefig(file_name)
    if verbose > 0:
        print("saved to", file_name)
    if params["show"]:
        plt.show()

    return file_name
258
259
if __name__ == "__main__":
    # Command line: python light_curve.py data_file [key=value ...]
    # where every key must be an entry of default_params.

    def usage_exit(reason=""):
        """Print `reason` and a usage summary to stderr, then exit with status 1."""
        print(reason, file=sys.stderr)
        print(
            "usage: python",
            sys.argv[0],
            "data_file",
            *["{}={}".format(k, v) for k, v in default_params.items()],
            file=sys.stderr,
        )
        sys.exit(1)

    def _parse_value(name, text):
        """Convert the command-line string `text` to the type of default_params[name]."""
        default = default_params[name]
        if isinstance(default, bool):
            # Only "false" (in any case) and "0" are read as False.
            return not (text.lower() == "false" or text == "0")
        # A None default carries no type to convert with (type(None) cannot be
        # called with an argument). The only such entry, comoving_distance, is
        # used as a numeric factor, so parse these values as float.
        constr = float if default is None else type(default)
        try:
            return constr(text)
        except ValueError:
            usage_exit("invalid value for " + name + ": " + text)

    if len(sys.argv) < 2:
        usage_exit()

    kwargs = {}
    for par in sys.argv[2:]:
        # Split on the first "=" only, so a value may itself contain "=".
        kv = par.split("=", 1)
        if len(kv) != 2:
            usage_exit("invalid argument " + par)
        if kv[0] not in default_params:
            usage_exit("unknown parameter " + kv[0])
        kwargs[kv[0]] = _parse_value(kv[0], kv[1])

    make_plot(sys.argv[1], **kwargs)