Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/utils/plotters.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/utils/plotters.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/utils/plotters.py (Version 9.0)


  1 from dataclasses import dataclass                 
  2 from typing import Tuple                          
  3                                                   
  4 import numpy as np                                
  5 from matplotlib import pyplot as plt              
  6 from scipy.optimize import curve_fit              
  7                                                   
  8 from core.constants import N_CELLS_Z, N_CELLS_    
  9     ML_SIM_HISTOGRAM_COLOR, FULL_SIM_GAUSSIAN_    
 10 from utils.observables import LongitudinalProf    
 11                                                   
 12 plt.rcParams.update({"font.size": 22})            
 13                                                   
 14                                                   
 15 @dataclass                                        
 16 class Plotter:                                    
 17     """ An abstract class defining interface o    
 18                                                   
 19     Do not use this class directly. Use Profil    
 20                                                   
 21     Attributes:                                   
 22         _particle_energy: An integer which is     
 23         _particle_angle: An integer which is a    
 24         _geometry: A string which is a name of    
 25                                                   
 26     """                                           
 27     _particle_energy: int                         
 28     _particle_angle: int                          
 29     _geometry: str                                
 30                                                   
 31     def plot_and_save(self):                      
 32         pass                                      
 33                                                   
 34                                                   
 35 def _gaussian(x: np.ndarray, a: float, mu: flo    
 36     """ Computes a value of a Gaussian.           
 37                                                   
 38     Args:                                         
 39         x: An argument of a function.             
 40         a: A scaling parameter.                   
 41         mu: A mean.                               
 42         sigma: A variance.                        
 43                                                   
 44     Returns:                                      
 45         A value of a function for given argume    
 46                                                   
 47     """                                           
 48     return a * np.exp(-((x - mu)**2 / (2 * sig    
 49                                                   
 50                                                   
 51 def _best_fit(data: np.ndarray,                   
 52               bins: np.ndarray,                   
 53               hist: bool = False) -> Tuple[np.    
 54     """ Finds estimated shape of a Gaussian us    
 55                                                   
 56     Args:                                         
 57         data: A numpy array with values of obs    
 58         bins: A numpy array specifying histogr    
 59         hist: If histogram is calculated. Then    
 60                                                   
 61     Returns:                                      
 62         A tuple of two lists. Xs and Ys of pre    
 63                                                   
 64     """                                           
 65     # Calculate histogram.                        
 66     if not hist:                                  
 67         hist, _ = np.histogram(data, bins)        
 68     else:                                         
 69         hist = data                               
 70                                                   
 71     # Choose only those bins which are nonzero    
 72     # hence we are interested in its first ele    
 73     indices = hist.nonzero()[0]                   
 74                                                   
 75     # Based on previously chosen nonzero bin,     
 76     # fitting procedure. Len(bins) == len(hist    
 77     bins_middles = (bins[:-1] + bins[1:]) / 2     
 78     xs = bins_middles[indices]                    
 79     ys_bar = hist[indices]                        
 80                                                   
 81     # Set initial parameters for curve fitter.    
 82     a0 = np.max(ys_bar)                           
 83     mu0 = np.mean(xs)                             
 84     sigma0 = np.var(xs)                           
 85                                                   
 86     # Fit a Gaussian to the prepared data.        
 87     (a, mu, sigma), _ = curve_fit(f=_gaussian,    
 88                                   xdata=xs,       
 89                                   ydata=ys_bar    
 90                                   p0=[a0, mu0,    
 91                                   method="trf"    
 92                                   maxfev=1000)    
 93                                                   
 94     # Calculate values of an approximation in     
 95     ys = _gaussian(xs, a, mu, sigma)              
 96     return xs, ys                                 
 97                                                   
 98                                                   
 99 @dataclass                                        
100 class ProfilePlotter(Plotter):                    
101     """ Plotter responsible for preparing plot    
102                                                   
103     Attributes:                                   
104         _full_simulation: A numpy array repres    
105         _ml_simulation: A numpy array represen    
106         _plot_gaussian: A boolean. Decides whe    
107             a fitted gaussian.                    
108         _profile_type: An enum. A profile can     
109                                                   
110     """                                           
111     _full_simulation: Profile                     
112     _ml_simulation: Profile                       
113     _plot_gaussian: bool = False                  
114                                                   
115     def __post_init__(self):                      
116         # Check if profiles are either both lo    
117         full_simulation_type = type(self._full    
118         ml_generation_type = type(self._ml_sim    
119         assert full_simulation_type == ml_gene    
120                                                   
121                                                   
122         # Set an attribute with profile type.     
123         if full_simulation_type == Longitudina    
124             self._profile_type = ProfileType.L    
125         else:                                     
126             self._profile_type = ProfileType.L    
127                                                   
128     def _plot_and_save_customizable_histogram(    
129             self,                                 
130             full_simulation: np.ndarray,          
131             ml_simulation: np.ndarray,            
132             bins: np.ndarray,                     
133             xlabel: str,                          
134             observable_name: str,                 
135             plot_profile: bool = False,           
136             y_log_scale: bool = False) -> None    
137         """ Prepares and saves a histogram for    
138                                                   
139         Args:                                     
140             full_simulation: A numpy array of     
141             ml_simulation: A numpy array of ob    
142             bins: A numpy array specifying his    
143             xlabel: A string. Name of x-axis o    
144             observable_name: A string. Name of    
145             plot_profile: A boolean. If set to    
146                 defined by the number of layer    
147                 need to create a data repeatin    
148                 only while plotting profiles n    
149             y_log_scale: A boolean. Used log s    
150                                                   
151         Returns:                                  
152             None.                                 
153                                                   
154         """                                       
155         fig, axes = plt.subplots(2,               
156                                  1,               
157                                  figsize=(15,     
158                                  clear=True,      
159                                  sharex="all")    
160                                                   
161         # Plot histograms.                        
162         if plot_profile:                          
163             # We already have the bins (layers    
164             # therefore directly plotting a st    
165             axes[0].step(bins[:-1],               
166                          full_simulation,         
167                          label="FullSim",         
168                          color=FULL_SIM_HISTOG    
169             axes[0].step(bins[:-1],               
170                          ml_simulation,           
171                          label="MLSim",           
172                          color=ML_SIM_HISTOGRA    
173             axes[0].vlines(x=bins[0],             
174                            ymin=0,                
175                            ymax=full_simulatio    
176                            color=FULL_SIM_HIST    
177             axes[0].vlines(x=bins[-2],            
178                            ymin=0,                
179                            ymax=full_simulatio    
180                            color=FULL_SIM_HIST    
181             axes[0].vlines(x=bins[0],             
182                            ymin=0,                
183                            ymax=ml_simulation[    
184                            color=ML_SIM_HISTOG    
185             axes[0].vlines(x=bins[-2],            
186                            ymin=0,                
187                            ymax=ml_simulation[    
188                            color=ML_SIM_HISTOG    
189             axes[0].set_ylim(0, None)             
190                                                   
191             # For using it later for the ratio    
192             energy_full_sim, energy_ml_sim = f    
193         else:                                     
194             energy_full_sim, _, _ = axes[0].hi    
195                 x=full_simulation,                
196                 bins=bins,                        
197                 label="FullSim",                  
198                 histtype=HISTOGRAM_TYPE,          
199                 color=FULL_SIM_HISTOGRAM_COLOR    
200             energy_ml_sim, _, _ = axes[0].hist    
201                                                   
202                                                   
203                                                   
204                                                   
205                                                   
206         # Plot Gaussians if needed.               
207         if self._plot_gaussian:                   
208             if plot_profile:                      
209                 (xs_full_sim, ys_full_sim) = _    
210                                                   
211                                                   
212                 (xs_ml_sim, ys_ml_sim) = _best    
213                                                   
214                                                   
215             else:                                 
216                 (xs_full_sim, ys_full_sim) = _    
217                 (xs_ml_sim, ys_ml_sim) = _best    
218             axes[0].plot(xs_full_sim,             
219                          ys_full_sim,             
220                          color=FULL_SIM_GAUSSI    
221                          label="FullSim")         
222             axes[0].plot(xs_ml_sim,               
223                          ys_ml_sim,               
224                          color=ML_SIM_GAUSSIAN    
225                          label="MLSim")           
226                                                   
227         if y_log_scale:                           
228             axes[0].set_yscale("log")             
229         axes[0].legend(loc="best")                
230         axes[0].set_xlabel(xlabel)                
231         axes[0].set_ylabel("Energy [Mev]")        
232         axes[0].set_title(                        
233             f" $e^-$, {self._particle_energy}     
234         )                                         
235                                                   
236         # Calculate ratios.                       
237         ratio = np.divide(energy_ml_sim,          
238                           energy_full_sim,        
239                           out=np.ones_like(ene    
240                           where=(energy_full_s    
241         # Since len(bins) == 1 + data, we calc    
242         bins_middles = (bins[:-1] + bins[1:])     
243         axes[1].plot(bins_middles, ratio, "-o"    
244         axes[1].set_xlabel(xlabel)                
245         axes[1].set_ylabel("MLSim/FullSim")       
246         axes[1].axhline(y=1, color="black")       
247         plt.savefig(                              
248             f"{VALID_DIR}/{observable_name}_Ge    
249             + f"Angle_{self._particle_angle}.p    
250         plt.clf()                                 
251                                                   
252     def _plot_profile(self) -> None:              
253         """ Plots profile of an observable.       
254                                                   
255         Returns:                                  
256             None.                                 
257                                                   
258         """                                       
259         full_simulation_profile = self._full_s    
260         ml_simulation_profile = self._ml_simul    
261         if self._profile_type == ProfileType.L    
262             # matplotlib will include the righ    
263             # hence extending by 1.               
264             bins = np.linspace(0, N_CELLS_Z, N    
265             observable_name = "LongProf"          
266             xlabel = "Layer index"                
267         else:                                     
268             bins = np.linspace(0, N_CELLS_R, N    
269             observable_name = "LatProf"           
270             xlabel = "R index"                    
271         self._plot_and_save_customizable_histo    
272                                                   
273                                                   
274                                                   
275                                                   
276                                                   
277                                                   
278     def _plot_first_moment(self) -> None:         
279         """ Plots and saves a first moment of     
280                                                   
281         Returns:                                  
282             None.                                 
283                                                   
284         """                                       
285         full_simulation_first_moment = self._f    
286         )                                         
287         ml_simulation_first_moment = self._ml_    
288         if self._profile_type == ProfileType.L    
289             xlabel = "$<\lambda> [mm]$"           
290             observable_name = "LongFirstMoment    
291             bins = np.linspace(0, 0.4 * N_CELL    
292         else:                                     
293             xlabel = "$<r> [mm]$"                 
294             observable_name = "LatFirstMoment"    
295             bins = np.linspace(0, 0.75 * N_CEL    
296                                                   
297         self._plot_and_save_customizable_histo    
298             full_simulation_first_moment, ml_s    
299             xlabel, observable_name)              
300                                                   
301     def _plot_second_moment(self) -> None:        
302         """ Plots and saves a second moment of    
303                                                   
304         Returns:                                  
305             None.                                 
306                                                   
307         """                                       
308         full_simulation_second_moment = self._    
309         )                                         
310         ml_simulation_second_moment = self._ml    
311         if self._profile_type == ProfileType.L    
312             xlabel = "$<\lambda^{2}> [mm^{2}]$    
313             observable_name = "LongSecondMomen    
314             bins = np.linspace(0, pow(N_CELLS_    
315         else:                                     
316             xlabel = "$<r^{2}> [mm^{2}]$"         
317             observable_name = "LatSecondMoment    
318             bins = np.linspace(0, pow(N_CELLS_    
319                                                   
320         self._plot_and_save_customizable_histo    
321             full_simulation_second_moment, ml_    
322             xlabel, observable_name)              
323                                                   
324     def plot_and_save(self) -> None:              
325         """ Main plotting function.               
326                                                   
327         Calls private methods and prints the i    
328                                                   
329         Returns:                                  
330             None.                                 
331                                                   
332         """                                       
333         if self._profile_type == ProfileType.L    
334             profile_type_name = "longitudinal"    
335         else:                                     
336             profile_type_name = "lateral"         
337         print(f"Plotting the {profile_type_nam    
338         self._plot_profile()                      
339         print(f"Plotting the first moment of {    
340         self._plot_first_moment()                 
341         print(f"Plotting the second moment of     
342         self._plot_second_moment()                
343                                                   
344                                                   
345 @dataclass                                        
346 class EnergyPlotter(Plotter):                     
347     """ Plotter responsible for preparing plot    
348                                                   
349     Attributes:                                   
350         _full_simulation: A numpy array repres    
351         _ml_simulation: A numpy array represen    
352                                                   
353     """                                           
354     _full_simulation: Energy                      
355     _ml_simulation: Energy                        
356                                                   
357     def _plot_total_energy(self, y_log_scale=T    
358         """ Plots and saves a histogram with t    
359                                                   
360         Args:                                     
361             y_log_scale: A boolean. Used log s    
362                                                   
363         Returns:                                  
364             None.                                 
365                                                   
366         """                                       
367         full_simulation_total_energy = self._f    
368         )                                         
369         ml_simulation_total_energy = self._ml_    
370                                                   
371         plt.figure(figsize=(12, 8))               
372         bins = np.linspace(                       
373             np.min(full_simulation_total_energ    
374             np.min(full_simulation_total_energ    
375             np.max(full_simulation_total_energ    
376             np.max(full_simulation_total_energ    
377         plt.hist(x=full_simulation_total_energ    
378                  histtype=HISTOGRAM_TYPE,         
379                  label="FullSim",                 
380                  bins=bins,                       
381                  color=FULL_SIM_HISTOGRAM_COLO    
382         plt.hist(x=ml_simulation_total_energy,    
383                  histtype=HISTOGRAM_TYPE,         
384                  label="MLSim",                   
385                  bins=bins,                       
386                  color=ML_SIM_HISTOGRAM_COLOR)    
387         plt.legend(loc="upper left")              
388         if y_log_scale:                           
389             plt.yscale("log")                     
390         plt.xlabel("Energy [MeV]")                
391         plt.ylabel("# events")                    
392         plt.title(                                
393             f" $e^-$, {self._particle_energy}     
394         )                                         
395         plt.savefig(                              
396             f"{VALID_DIR}/E_tot_Geo_{self._geo    
397         )                                         
398         plt.clf()                                 
399                                                   
400     def _plot_cell_energy(self) -> None:          
401         """ Plots and saves a histogram with n    
402         calorimeter with particular energy det    
403                                                   
404         Returns:                                  
405             None.                                 
406                                                   
407         """                                       
408         full_simulation_cell_energy = self._fu    
409         ml_simulation_cell_energy = self._ml_s    
410                                                   
411         log_full_simulation_cell_energy = np.l    
412             full_simulation_cell_energy,          
413             out=np.zeros_like(full_simulation_    
414             where=(full_simulation_cell_energy    
415         log_ml_simulation_cell_energy = np.log    
416             ml_simulation_cell_energy,            
417             out=np.zeros_like(ml_simulation_ce    
418             where=(ml_simulation_cell_energy !    
419         plt.figure(figsize=(12, 8))               
420         bins = np.linspace(-4, 1, 1000)           
421         plt.hist(x=log_full_simulation_cell_en    
422                  bins=bins,                       
423                  histtype=HISTOGRAM_TYPE,         
424                  label="FullSim",                 
425                  color=FULL_SIM_HISTOGRAM_COLO    
426         plt.hist(x=log_ml_simulation_cell_ener    
427                  bins=bins,                       
428                  histtype=HISTOGRAM_TYPE,         
429                  label="MLSim",                   
430                  color=ML_SIM_HISTOGRAM_COLOR)    
431         plt.xlabel("log10(E/MeV)")                
432         plt.ylim(bottom=1)                        
433         plt.yscale("log")                         
434         plt.ylim(bottom=1)                        
435         plt.ylabel("# entries")                   
436         plt.title(                                
437             f" $e^-$, {self._particle_energy}     
438         )                                         
439         plt.grid(True)                            
440         plt.legend(loc="upper left")              
441         plt.savefig(                              
442             f"{VALID_DIR}/E_cell_Geo_{self._ge    
443         )                                         
444         plt.clf()                                 
445                                                   
446     def _plot_energy_per_layer(self):             
447         """ Plots and saves N_CELLS_Z histogra    
448                                                   
449         Returns:                                  
450             None.                                 
451                                                   
452         """                                       
453         full_simulation_energy_per_layer = sel    
454         )                                         
455         ml_simulation_energy_per_layer = self.    
456         )                                         
457                                                   
458         number_of_plots_in_row = 9                
459         number_of_plots_in_column = 5             
460                                                   
461         bins = np.linspace(np.min(full_simulat    
462                            np.max(full_simulat    
463                                                   
464         fig, ax = plt.subplots(number_of_plots    
465                                number_of_plots    
466                                figsize=(20, 15    
467                                sharex="all",      
468                                sharey="all",      
469                                constrained_lay    
470                                                   
471         for layer_nb in range(N_CELLS_Z):         
472             i = layer_nb // number_of_plots_in    
473             j = layer_nb % number_of_plots_in_    
474                                                   
475             ax[i][j].hist(full_simulation_ener    
476                           histtype=HISTOGRAM_T    
477                           label="FullSim",        
478                           bins=bins,              
479                           color=FULL_SIM_HISTO    
480             ax[i][j].hist(ml_simulation_energy    
481                           histtype=HISTOGRAM_T    
482                           label="MLSim",          
483                           bins=bins,              
484                           color=ML_SIM_HISTOGR    
485             ax[i][j].set_title(f"Layer {layer_    
486             ax[i][j].set_yscale("log")            
487             ax[i][j].tick_params(axis='both',     
488                                                   
489         fig.supxlabel("Energy [MeV]", fontsize    
490         fig.supylabel("# entries", fontsize=14    
491         fig.suptitle(                             
492             f" $e^-$, {self._particle_energy}     
493         )                                         
494                                                   
495         # Take legend from one plot and make i    
496         handles, labels = ax[0][0].get_legend_    
497         fig.legend(handles, labels, bbox_to_an    
498                                                   
499         plt.savefig(                              
500             f"{VALID_DIR}/E_layer_Geo_{self._g    
501             bbox_inches="tight")                  
502         plt.clf()                                 
503                                                   
504     def plot_and_save(self):                      
505         """ Main plotting function.               
506                                                   
507         Calls private methods and prints the i    
508                                                   
509         Returns:                                  
510             None.                                 
511                                                   
512         """                                       
513         print("Plotting total energy...")         
514         self._plot_total_energy()                 
515         print("Plotting cell energy...")          
516         self._plot_cell_energy()                  
517         print("Plotting energy per layer...")     
518         self._plot_energy_per_layer()