Mercurial > repos > astroteam > crbeam_astro_tool
comparison light_curve.py @ 0:f40d05521dca draft default tip
planemo upload for repository https://github.com/esg-epfl-apc/tools-astro/tree/main/tools commit de01e3c02a26cd6353a6b9b6f8d1be44de8ccd54
author | astroteam |
---|---|
date | Fri, 25 Apr 2025 19:33:20 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:f40d05521dca |
---|---|
1 import os | |
2 import sys | |
3 | |
4 import matplotlib.pyplot as plt | |
5 import numpy as np | |
6 | |
# Light-travel time of one megaparsec expressed in hours
# (~3.0857e22 m / c / 3600 s); used to convert dimensionless
# delay-per-distance columns into hours (see get_counts).
Mpc_in_h = 2.8590868063e10
8 | |
9 | |
def weighted_quantile(
    values, quantiles, sample_weight=None, values_sorted=False, old_style=False
):
    """Compute quantiles of *values* with optional per-sample weights.

    Very close to numpy.percentile, but supports weights.
    NOTE: quantiles should be in [0, 1]!

    :param values: array-like with data
    :param quantiles: array-like with the quantiles to compute, each in [0, 1]
    :param sample_weight: array-like of the same length as `values`;
        defaults to uniform weights
    :param values_sorted: bool, if True, skip sorting of the input arrays
        (caller guarantees `values` is already ascending)
    :param old_style: if True, rescale the weighted positions to span
        exactly [0, 1] so the output is consistent with numpy.percentile
    :return: numpy.array with computed quantiles
    :raises ValueError: if any requested quantile lies outside [0, 1]
    """
    values = np.array(values)
    quantiles = np.array(quantiles)
    if sample_weight is None:
        sample_weight = np.ones(len(values))
    sample_weight = np.array(sample_weight)
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if np.any(quantiles < 0) or np.any(quantiles > 1):
        raise ValueError("quantiles should be in [0, 1]")

    if not values_sorted:
        sorter = np.argsort(values)
        values = values[sorter]
        sample_weight = sample_weight[sorter]

    # Midpoint rule: each sample sits at the center of its weight interval.
    weighted_quantiles = np.cumsum(sample_weight) - 0.5 * sample_weight
    if old_style:
        # To be consistent with np.percentile: stretch to exactly [0, 1]
        weighted_quantiles -= weighted_quantiles[0]
        weighted_quantiles /= weighted_quantiles[-1]
    else:
        weighted_quantiles /= np.sum(sample_weight)
    return np.interp(quantiles, weighted_quantiles, values)
44 | |
45 | |
def get_distance_Mpc(mc_file):
    """Read the source distance in Mpc from a data file's comment header.

    Header lines start with "#"; the distance is the token two positions
    after "T/Mpc" (i.e. the value in "... T/Mpc = <value> ...").

    Bug fix: the original called ``cols.index("T/Mpc")`` unconditionally,
    which raises ValueError on the first header line *without* the token
    instead of scanning the remaining lines; ``idx >= 0`` was dead code
    since ``list.index`` never returns a negative index.

    :param mc_file: path to the data file
    :return: source distance in Mpc as a float
    :raises ValueError: if no header line contains the "T/Mpc" token
    """
    with open(mc_file) as f_mc_lines:
        for line in f_mc_lines:
            if not line.startswith("#"):
                continue
            cols = line.split()
            if "T/Mpc" in cols:
                idx = cols.index("T/Mpc")
                # skip the "=" sign that separates token and value
                return float(cols[idx + 2])
    raise ValueError('Unexpected mc file format: "T/Mpc" not found')
56 | |
57 | |
# Default values for every tunable parameter; the functions below copy this
# dict and apply their **kwargs overrides on top of it.
default_params = dict(
    weight_col=2,               # 1-based column holding particle weights
    E_src_col=8,                # column holding the source energy
    delay_col=6,                # 1-based column holding the time delay
    comoving_distance=None,     # Mpc; read from the file header when unset
    n_points=100,
    logscale=False,
    suffix="",                  # extra tag appended to the output file name
    Emin=1e6,
    Emax=1e20,
    psf=1.0,                    # point-spread-function cut (angular)
    cut_0=False,                # drop events with exactly zero delay
    rounding_error=0.0007,      # 1./600
    min_n_particles=100,        # minimum events per light-curve bin
    out_ext="png",
    max_t=0.05,                 # use negative for percentile
    show=False,
    add_alpha=0.0,              # multiply weights by E_src^{-add_alpha}
    # frac_time=2000/3600,      # fraction of flux coming within this time in hours
    verbose=2,
    format=".2g",               # float format
)
80 | |
81 | |
def filter_data(data, **kwargs):
    """Apply energy, PSF and optional zero-delay cuts to the event table.

    Column conventions (0-based): column 0 is the event energy, column 2 is
    the angular separation compared against the ``psf`` cut; the delay
    column index comes from the 1-based ``delay_col`` parameter.

    :param data: 2-D numpy array of events (one row per event)
    :param kwargs: overrides for ``default_params``
        (Emin, Emax, psf, cut_0, delay_col)
    :return: the filtered array
    """
    params = dict(default_params, **kwargs)

    # Energy window and PSF cut, combined into a single boolean mask.
    energy = data[:, 0]
    mask = (energy >= params["Emin"]) & (energy <= params["Emax"])
    mask &= data[:, 2] <= params["psf"]
    selected = data[mask]

    if params["cut_0"]:
        # Optionally drop events whose delay is exactly zero.
        delay_idx = params["delay_col"] - 1
        selected = selected[selected[:, delay_idx] != 0.0]
    return selected
97 | |
98 | |
def get_counts(rotated_mc_file, **kwargs):
    """Load an event file and return per-event delays (hours) and weights.

    Events are filtered via :func:`filter_data`, the delay column is scaled
    from Mpc units to hours, and both outputs are sorted by delay.

    :param rotated_mc_file: path to the rotated MC event file
    :param kwargs: overrides for ``default_params``
    :return: tuple ``(delay, weights)`` of equal-length numpy arrays,
        sorted by increasing delay
    """
    params = default_params.copy()
    params.update(kwargs)
    data = np.loadtxt(rotated_mc_file)
    data_filtered = filter_data(data, **kwargs)
    verbose = params["verbose"]
    if verbose > 1:
        print(len(data_filtered), "of", len(data), "has passed the filter")

    weight_col = params["weight_col"] - 1
    col = params["delay_col"] - 1
    comoving_distance = params["comoving_distance"]
    if not comoving_distance:
        # fall back to the distance recorded in the file header
        comoving_distance = get_distance_Mpc(rotated_mc_file)

    x_scale = comoving_distance * Mpc_in_h  # convert to hours

    delay = data_filtered[:, col] * x_scale

    idxs = np.argsort(delay)
    delay = delay[idxs]

    if weight_col >= 0:
        assert weight_col < data_filtered.shape[1]
        weights = data_filtered[idxs, weight_col]
    else:
        # negative weight_col disables weighting
        weights = np.ones(len(idxs))

    add_alpha = params["add_alpha"]
    if add_alpha != 0:
        # Spectral reweighting: weights *= (E_src / <E_src>)^{-add_alpha}.
        # Bug fix: E_src must be reordered with the same sort as delay and
        # weights; the original read it in unsorted row order, multiplying
        # the wrong events' weights.
        # NOTE(review): E_src_col is used here as a 0-based index, unlike
        # weight_col/delay_col which are 1-based — confirm intended column.
        E_src = data_filtered[idxs, params["E_src_col"]]
        av_Esrc = np.exp(np.log(E_src).mean())  # geometric mean
        weights *= np.power(E_src / av_Esrc, -add_alpha)

    return delay, weights
134 | |
135 | |
def light_curve(delay, weights, **kwargs):
    """Build an adaptively binned light curve from sorted event delays.

    Bins are grown so that each contains at least ``min_n_particles``
    events and spans at least ``rounding_error`` hours (the timing
    resolution of the data); binning stops once the bin start exceeds
    ``max_t`` (a negative ``max_t`` is interpreted as a percentile of the
    delay distribution, e.g. -90 -> 90th percentile).

    :param delay: numpy array of event delays in hours, sorted ascending
        (as returned by ``get_counts``)
    :param weights: numpy array of per-event weights, aligned with ``delay``
    :param kwargs: overrides for ``default_params``
        (min_n_particles, rounding_error, max_t)
    :return: list ``[t, f, N]`` of numpy arrays: bin centers (weighted mean
        delay), flux estimates (weight sum / bin width), and event counts
    """
    params = default_params.copy()
    params.update(kwargs)
    min_n_particles = params["min_n_particles"]
    min_bin_size = params["rounding_error"]
    max_t = params["max_t"]

    if max_t < 0:
        # negative max_t selects the (-max_t) percentile of the delays
        max_t = weighted_quantile(
            delay, [-0.01 * max_t], sample_weight=weights)[0]

    f = []  # flux per bin (sum of weights / bin width)
    t = []  # bin center: weighted mean delay
    N = []  # number of events per bin

    bin_idx = 0
    if delay[0] < min_bin_size:
        # Special first bin: gather all events below the timing resolution,
        # padded to min_n_particles if there are too few of them.
        bin_idx = np.where(delay < min_bin_size)[0][-1]
        if bin_idx + 1 < min_n_particles:
            bin_idx = min_n_particles
        wsum = np.sum(weights[:bin_idx])
        _t = np.sum(delay[:bin_idx] * weights[:bin_idx]) / wsum
        # keep the bin center at least half a resolution element above zero
        _t = max(_t, 0.5 * min_bin_size)
        t.append(_t)
        bin_size = max(delay[bin_idx] - delay[0], min_bin_size)
        f.append(wsum / bin_size)
        N.append(bin_idx)

    while True:
        # Left edge: midpoint between the last event of the previous bin
        # and the first event of this one (or the first event itself).
        xmin = (
            0.5 * (delay[bin_idx - 1] + delay[bin_idx])
            if bin_idx > 0
            else delay[bin_idx]
        )
        if xmin > max_t:
            break
        # Find the first event, at least min_n_particles ahead, that lies
        # beyond xmin + min_bin_size: that event starts the next bin.
        bin_idx2 = np.where(
            delay[bin_idx + min_n_particles:] > xmin + min_bin_size)[0]
        if len(bin_idx2) == 0:
            break
        bin_idx2 = bin_idx2[0] + bin_idx + min_n_particles
        _delay = delay[bin_idx:bin_idx2]
        _weights = weights[bin_idx:bin_idx2]
        wsum = _weights.sum()
        t.append(np.sum(_delay * _weights) / wsum)
        # Right edge: midpoint between the bin's last event and the next one.
        xmax = 0.5 * (delay[bin_idx2 - 1] + delay[bin_idx2])
        f.append(wsum / (xmax - xmin))
        N.append(bin_idx2 - bin_idx)
        bin_idx = bin_idx2

    return [np.array(x) for x in (t, f, N)]
187 | |
188 | |
def make_plot(path, label="verbose", **kwargs):
    """Compute and save the light-curve plot for one event file.

    :param path: path to the rotated MC event file (``~`` is expanded)
    :param label: legend label; the special value "verbose" generates an
        automatic "50% with delay<..." label, a falsy value hides the legend
    :param kwargs: overrides for ``default_params``
    :return: the name of the image file that was written
    """
    params = dict(default_params)
    params.update(kwargs)
    path = os.path.expanduser(path)

    fmt = params["format"]
    col = params["delay_col"] - 1
    logscale = params["logscale"]
    verbose = params["verbose"]
    Emax = params["Emax"]
    Emin = params["Emin"]
    psf = params["psf"]
    cut0 = params["cut_0"]

    delay, weights = get_counts(path, **kwargs)
    t, f, _ = light_curve(delay, weights, **kwargs)

    # Encode the cut configuration into the output file name.
    suffix = params["suffix"]
    if cut0:
        suffix += "_cut0"
    if logscale:
        suffix += "_log"

    max_t = params["max_t"]
    out_name = (
        f"{path}_f{col + 1}_E{Emin:{fmt}}-{Emax:{fmt}}TeV_th{psf}_r{max_t}{suffix}."
    )

    if verbose > 1:
        for key, val in params.items():
            print(f"\t{key} = {val}")

    # A negative max_t means "percentile": resolve it together with the median.
    if max_t < 0:
        median, max_t = weighted_quantile(
            delay, [0.5, -0.01 * max_t], sample_weight=weights
        )
    else:
        median = weighted_quantile(delay, [0.5], sample_weight=weights)[0]

    plt.xlabel("t, [h]")
    plt.ylabel("dN/dt [a.u.]")

    if label == "verbose":
        label = f"50% with delay<{median:{fmt}} h"
    plot_args = {"label": label} if label else {}

    plt.plot(t, f, "g-", **plot_args)

    if logscale:
        plt.xscale("log")
        plt.yscale("log")

    if label:
        plt.legend(loc="upper right")

    file_name = out_name + params["out_ext"].strip()
    plt.savefig(file_name)
    if verbose > 0:
        print("saved to", file_name)
    if params["show"]:
        plt.show()

    return file_name
258 | |
259 | |
if __name__ == "__main__":

    def usage_exit(reason=""):
        """Print *reason* plus a usage summary to stderr and exit with 1."""
        print(reason, file=sys.stderr)
        print(
            "usage: python",
            sys.argv[0],
            "data_file",
            *["{}={}".format(k, v) for k, v in default_params.items()],
            file=sys.stderr,
        )
        # sys.exit instead of the site-provided exit(), which may be absent
        sys.exit(1)

    if len(sys.argv) < 2:
        usage_exit()

    # Parse key=value overrides, converting each to the type of its default.
    kwargs = {}
    for par in sys.argv[2:]:
        kv = par.split("=")
        if len(kv) != 2:
            usage_exit("invalid argument " + par)
        key, raw = kv
        if key not in default_params:
            usage_exit("unknown parameter " + key)
        default = default_params[key]
        if isinstance(default, bool):
            # false iff "false" (any case) or "0"; everything else is true
            value = raw.lower() not in ("false", "0")
        elif default is None:
            # Bug fix: parameters whose default is None (comoving_distance)
            # crashed here, since type(None)(raw) raises TypeError.
            # They are distances, so parse as float.
            value = float(raw)
        else:
            value = type(default)(raw)
        kwargs[key] = value

    make_plot(sys.argv[1], **kwargs)