Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/utils/plotters.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/utils/plotters.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/utils/plotters.py (Version 11.1)


  1 from dataclasses import dataclass                   1 from dataclasses import dataclass
  2 from typing import Tuple                            2 from typing import Tuple
  3                                                     3 
  4 import numpy as np                                  4 import numpy as np
  5 from matplotlib import pyplot as plt                5 from matplotlib import pyplot as plt
  6 from scipy.optimize import curve_fit                6 from scipy.optimize import curve_fit
  7                                                     7 
  8 from core.constants import N_CELLS_Z, N_CELLS_      8 from core.constants import N_CELLS_Z, N_CELLS_R, VALID_DIR, SIZE_Z, SIZE_R, HISTOGRAM_TYPE, FULL_SIM_HISTOGRAM_COLOR, \
  9     ML_SIM_HISTOGRAM_COLOR, FULL_SIM_GAUSSIAN_      9     ML_SIM_HISTOGRAM_COLOR, FULL_SIM_GAUSSIAN_COLOR, ML_SIM_GAUSSIAN_COLOR
 10 from utils.observables import LongitudinalProf     10 from utils.observables import LongitudinalProfile, ProfileType, Profile, Energy
 11                                                    11 
 12 plt.rcParams.update({"font.size": 22})             12 plt.rcParams.update({"font.size": 22})
 13                                                    13 
 14                                                    14 
 15 @dataclass                                         15 @dataclass
 16 class Plotter:                                     16 class Plotter:
 17     """ An abstract class defining interface o     17     """ An abstract class defining interface of all plotters.
 18                                                    18 
 19     Do not use this class directly. Use Profil     19     Do not use this class directly. Use ProfilePlotter or EnergyPlotter instead.
 20                                                    20 
 21     Attributes:                                    21     Attributes:
 22         _particle_energy: An integer which is      22         _particle_energy: An integer which is energy of the primary particle in GeV units.
 23         _particle_angle: An integer which is a     23         _particle_angle: An integer which is an angle of the primary particle in degrees.
 24         _geometry: A string which is a name of     24         _geometry: A string which is a name of the calorimeter geometry (e.g. SiW, SciPb).
 25                                                    25 
 26     """                                            26     """
 27     _particle_energy: int                          27     _particle_energy: int
 28     _particle_angle: int                           28     _particle_angle: int
 29     _geometry: str                                 29     _geometry: str
 30                                                    30 
 31     def plot_and_save(self):                       31     def plot_and_save(self):
 32         pass                                       32         pass
 33                                                    33 
 34                                                    34 
 35 def _gaussian(x: np.ndarray, a: float, mu: flo     35 def _gaussian(x: np.ndarray, a: float, mu: float, sigma: float) -> np.ndarray:
 36     """ Computes a value of a Gaussian.            36     """ Computes a value of a Gaussian.
 37                                                    37 
 38     Args:                                          38     Args:
 39         x: An argument of a function.              39         x: An argument of a function.
 40         a: A scaling parameter.                    40         a: A scaling parameter.
 41         mu: A mean.                                41         mu: A mean.
 42         sigma: A variance.                         42         sigma: A variance.
 43                                                    43 
 44     Returns:                                       44     Returns:
 45         A value of a function for given argume     45         A value of a function for given arguments.
 46                                                    46 
 47     """                                            47     """
 48     return a * np.exp(-((x - mu)**2 / (2 * sig     48     return a * np.exp(-((x - mu)**2 / (2 * sigma**2)))
 49                                                    49 
 50                                                    50 
 51 def _best_fit(data: np.ndarray,                    51 def _best_fit(data: np.ndarray,
 52               bins: np.ndarray,                    52               bins: np.ndarray,
 53               hist: bool = False) -> Tuple[np.     53               hist: bool = False) -> Tuple[np.ndarray, np.ndarray]:
 54     """ Finds estimated shape of a Gaussian us     54     """ Finds estimated shape of a Gaussian using Use non-linear least squares.
 55                                                    55 
 56     Args:                                          56     Args:
 57         data: A numpy array with values of obs     57         data: A numpy array with values of observables from multiple events.
 58         bins: A numpy array specifying histogr     58         bins: A numpy array specifying histogram bins.
 59         hist: If histogram is calculated. Then     59         hist: If histogram is calculated. Then data is the frequencies.
 60                                                    60 
 61     Returns:                                       61     Returns:
 62         A tuple of two lists. Xs and Ys of pre     62         A tuple of two lists. Xs and Ys of predicted curve.
 63                                                    63 
 64     """                                            64     """
 65     # Calculate histogram.                         65     # Calculate histogram.
 66     if not hist:                                   66     if not hist:
 67         hist, _ = np.histogram(data, bins)         67         hist, _ = np.histogram(data, bins)
 68     else:                                          68     else:
 69         hist = data                                69         hist = data
 70                                                    70 
 71     # Choose only those bins which are nonzero     71     # Choose only those bins which are nonzero. Nonzero() return a tuple of arrays. In this case it has a length = 1,
 72     # hence we are interested in its first ele     72     # hence we are interested in its first element.
 73     indices = hist.nonzero()[0]                    73     indices = hist.nonzero()[0]
 74                                                    74 
 75     # Based on previously chosen nonzero bin,      75     # Based on previously chosen nonzero bin, calculate position of xs and ys_bar (true values) which will be used in
 76     # fitting procedure. Len(bins) == len(hist     76     # fitting procedure. Len(bins) == len(hist + 1), so we choose middles of bins as xs.
 77     bins_middles = (bins[:-1] + bins[1:]) / 2      77     bins_middles = (bins[:-1] + bins[1:]) / 2
 78     xs = bins_middles[indices]                     78     xs = bins_middles[indices]
 79     ys_bar = hist[indices]                         79     ys_bar = hist[indices]
 80                                                    80 
 81     # Set initial parameters for curve fitter.     81     # Set initial parameters for curve fitter.
 82     a0 = np.max(ys_bar)                            82     a0 = np.max(ys_bar)
 83     mu0 = np.mean(xs)                              83     mu0 = np.mean(xs)
 84     sigma0 = np.var(xs)                            84     sigma0 = np.var(xs)
 85                                                    85 
 86     # Fit a Gaussian to the prepared data.         86     # Fit a Gaussian to the prepared data.
 87     (a, mu, sigma), _ = curve_fit(f=_gaussian,     87     (a, mu, sigma), _ = curve_fit(f=_gaussian,
 88                                   xdata=xs,        88                                   xdata=xs,
 89                                   ydata=ys_bar     89                                   ydata=ys_bar,
 90                                   p0=[a0, mu0,     90                                   p0=[a0, mu0, sigma0],
 91                                   method="trf"     91                                   method="trf",
 92                                   maxfev=1000)     92                                   maxfev=1000)
 93                                                    93 
 94     # Calculate values of an approximation in      94     # Calculate values of an approximation in given points and return values.
 95     ys = _gaussian(xs, a, mu, sigma)               95     ys = _gaussian(xs, a, mu, sigma)
 96     return xs, ys                                  96     return xs, ys
 97                                                    97 
 98                                                    98 
 99 @dataclass                                         99 @dataclass
100 class ProfilePlotter(Plotter):                    100 class ProfilePlotter(Plotter):
101     """ Plotter responsible for preparing plot    101     """ Plotter responsible for preparing plots of profiles and their first and second moments.
102                                                   102 
103     Attributes:                                   103     Attributes:
104         _full_simulation: A numpy array repres    104         _full_simulation: A numpy array representing a profile of data generated by Geant4.
105         _ml_simulation: A numpy array represen    105         _ml_simulation: A numpy array representing a profile of data generated by ML model.
106         _plot_gaussian: A boolean. Decides whe    106         _plot_gaussian: A boolean. Decides whether first and second moment should be plotted as a histogram or
107             a fitted gaussian.                    107             a fitted gaussian.
108         _profile_type: An enum. A profile can     108         _profile_type: An enum. A profile can be either lateral or longitudinal.
109                                                   109 
110     """                                           110     """
111     _full_simulation: Profile                     111     _full_simulation: Profile
112     _ml_simulation: Profile                       112     _ml_simulation: Profile
113     _plot_gaussian: bool = False                  113     _plot_gaussian: bool = False
114                                                   114 
115     def __post_init__(self):                      115     def __post_init__(self):
116         # Check if profiles are either both lo    116         # Check if profiles are either both longitudinal or lateral.
117         full_simulation_type = type(self._full    117         full_simulation_type = type(self._full_simulation)
118         ml_generation_type = type(self._ml_sim    118         ml_generation_type = type(self._ml_simulation)
119         assert full_simulation_type == ml_gene    119         assert full_simulation_type == ml_generation_type, "Both profiles within a ProfilePlotter must be the same " \
120                                                   120                                                            "type."
121                                                   121 
122         # Set an attribute with profile type.     122         # Set an attribute with profile type.
123         if full_simulation_type == Longitudina    123         if full_simulation_type == LongitudinalProfile:
124             self._profile_type = ProfileType.L    124             self._profile_type = ProfileType.LONGITUDINAL
125         else:                                     125         else:
126             self._profile_type = ProfileType.L    126             self._profile_type = ProfileType.LATERAL
127                                                   127 
128     def _plot_and_save_customizable_histogram(    128     def _plot_and_save_customizable_histogram(
129             self,                                 129             self,
130             full_simulation: np.ndarray,          130             full_simulation: np.ndarray,
131             ml_simulation: np.ndarray,            131             ml_simulation: np.ndarray,
132             bins: np.ndarray,                     132             bins: np.ndarray,
133             xlabel: str,                          133             xlabel: str,
134             observable_name: str,                 134             observable_name: str,
135             plot_profile: bool = False,           135             plot_profile: bool = False,
136             y_log_scale: bool = False) -> None    136             y_log_scale: bool = False) -> None:
137         """ Prepares and saves a histogram for    137         """ Prepares and saves a histogram for a given pair of observables.
138                                                   138 
139         Args:                                     139         Args:
140             full_simulation: A numpy array of     140             full_simulation: A numpy array of observables coming from full simulation.
141             ml_simulation: A numpy array of ob    141             ml_simulation: A numpy array of observables coming from ML simulation.
142             bins: A numpy array specifying his    142             bins: A numpy array specifying histogram bins.
143             xlabel: A string. Name of x-axis o    143             xlabel: A string. Name of x-axis on the plot.
144             observable_name: A string. Name of    144             observable_name: A string. Name of plotted observable.
145             plot_profile: A boolean. If set to    145             plot_profile: A boolean. If set to True, full_simulation and ml_simulation are histogram weights while x is
146                 defined by the number of layer    146                 defined by the number of layers. This means that in order to plot histogram (and gaussian), one first
147                 need to create a data repeatin    147                 need to create a data repeating each layer or R index appropriate number of times. Should be set to True
148                 only while plotting profiles n    148                 only while plotting profiles not first or second moments.
149             y_log_scale: A boolean. Used log s    149             y_log_scale: A boolean. Used log scale on y-axis is set to True.
150                                                   150 
151         Returns:                                  151         Returns:
152             None.                                 152             None.
153                                                   153 
154         """                                       154         """
155         fig, axes = plt.subplots(2,               155         fig, axes = plt.subplots(2,
156                                  1,               156                                  1,
157                                  figsize=(15,     157                                  figsize=(15, 10),
158                                  clear=True,      158                                  clear=True,
159                                  sharex="all")    159                                  sharex="all")
160                                                   160 
161         # Plot histograms.                        161         # Plot histograms.
162         if plot_profile:                          162         if plot_profile:
163             # We already have the bins (layers    163             # We already have the bins (layers) and freqencies (energies),
164             # therefore directly plotting a st    164             # therefore directly plotting a step plot + lines instead of a hist plot.
165             axes[0].step(bins[:-1],               165             axes[0].step(bins[:-1],
166                          full_simulation,         166                          full_simulation,
167                          label="FullSim",         167                          label="FullSim",
168                          color=FULL_SIM_HISTOG    168                          color=FULL_SIM_HISTOGRAM_COLOR)
169             axes[0].step(bins[:-1],               169             axes[0].step(bins[:-1],
170                          ml_simulation,           170                          ml_simulation,
171                          label="MLSim",           171                          label="MLSim",
172                          color=ML_SIM_HISTOGRA    172                          color=ML_SIM_HISTOGRAM_COLOR)
173             axes[0].vlines(x=bins[0],             173             axes[0].vlines(x=bins[0],
174                            ymin=0,                174                            ymin=0,
175                            ymax=full_simulatio    175                            ymax=full_simulation[0],
176                            color=FULL_SIM_HIST    176                            color=FULL_SIM_HISTOGRAM_COLOR)
177             axes[0].vlines(x=bins[-2],            177             axes[0].vlines(x=bins[-2],
178                            ymin=0,                178                            ymin=0,
179                            ymax=full_simulatio    179                            ymax=full_simulation[-1],
180                            color=FULL_SIM_HIST    180                            color=FULL_SIM_HISTOGRAM_COLOR)
181             axes[0].vlines(x=bins[0],             181             axes[0].vlines(x=bins[0],
182                            ymin=0,                182                            ymin=0,
183                            ymax=ml_simulation[    183                            ymax=ml_simulation[0],
184                            color=ML_SIM_HISTOG    184                            color=ML_SIM_HISTOGRAM_COLOR)
185             axes[0].vlines(x=bins[-2],            185             axes[0].vlines(x=bins[-2],
186                            ymin=0,                186                            ymin=0,
187                            ymax=ml_simulation[    187                            ymax=ml_simulation[-1],
188                            color=ML_SIM_HISTOG    188                            color=ML_SIM_HISTOGRAM_COLOR)
189             axes[0].set_ylim(0, None)             189             axes[0].set_ylim(0, None)
190                                                   190 
191             # For using it later for the ratio    191             # For using it later for the ratios.
192             energy_full_sim, energy_ml_sim = f    192             energy_full_sim, energy_ml_sim = full_simulation, ml_simulation
193         else:                                     193         else:
194             energy_full_sim, _, _ = axes[0].hi    194             energy_full_sim, _, _ = axes[0].hist(
195                 x=full_simulation,                195                 x=full_simulation,
196                 bins=bins,                        196                 bins=bins,
197                 label="FullSim",                  197                 label="FullSim",
198                 histtype=HISTOGRAM_TYPE,          198                 histtype=HISTOGRAM_TYPE,
199                 color=FULL_SIM_HISTOGRAM_COLOR    199                 color=FULL_SIM_HISTOGRAM_COLOR)
200             energy_ml_sim, _, _ = axes[0].hist    200             energy_ml_sim, _, _ = axes[0].hist(x=ml_simulation,
201                                                   201                                                bins=bins,
202                                                   202                                                label="MLSim",
203                                                   203                                                histtype=HISTOGRAM_TYPE,
204                                                   204                                                color=ML_SIM_HISTOGRAM_COLOR)
205                                                   205 
206         # Plot Gaussians if needed.               206         # Plot Gaussians if needed.
207         if self._plot_gaussian:                   207         if self._plot_gaussian:
208             if plot_profile:                      208             if plot_profile:
209                 (xs_full_sim, ys_full_sim) = _    209                 (xs_full_sim, ys_full_sim) = _best_fit(full_simulation,
210                                                   210                                                        bins,
211                                                   211                                                        hist=True)
212                 (xs_ml_sim, ys_ml_sim) = _best    212                 (xs_ml_sim, ys_ml_sim) = _best_fit(ml_simulation,
213                                                   213                                                    bins,
214                                                   214                                                    hist=True)
215             else:                                 215             else:
216                 (xs_full_sim, ys_full_sim) = _    216                 (xs_full_sim, ys_full_sim) = _best_fit(full_simulation, bins)
217                 (xs_ml_sim, ys_ml_sim) = _best    217                 (xs_ml_sim, ys_ml_sim) = _best_fit(ml_simulation, bins)
218             axes[0].plot(xs_full_sim,             218             axes[0].plot(xs_full_sim,
219                          ys_full_sim,             219                          ys_full_sim,
220                          color=FULL_SIM_GAUSSI    220                          color=FULL_SIM_GAUSSIAN_COLOR,
221                          label="FullSim")         221                          label="FullSim")
222             axes[0].plot(xs_ml_sim,               222             axes[0].plot(xs_ml_sim,
223                          ys_ml_sim,               223                          ys_ml_sim,
224                          color=ML_SIM_GAUSSIAN    224                          color=ML_SIM_GAUSSIAN_COLOR,
225                          label="MLSim")           225                          label="MLSim")
226                                                   226 
227         if y_log_scale:                           227         if y_log_scale:
228             axes[0].set_yscale("log")             228             axes[0].set_yscale("log")
229         axes[0].legend(loc="best")                229         axes[0].legend(loc="best")
230         axes[0].set_xlabel(xlabel)                230         axes[0].set_xlabel(xlabel)
231         axes[0].set_ylabel("Energy [Mev]")        231         axes[0].set_ylabel("Energy [Mev]")
232         axes[0].set_title(                        232         axes[0].set_title(
233             f" $e^-$, {self._particle_energy}     233             f" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry}"
234         )                                         234         )
235                                                   235 
236         # Calculate ratios.                       236         # Calculate ratios.
237         ratio = np.divide(energy_ml_sim,          237         ratio = np.divide(energy_ml_sim,
238                           energy_full_sim,        238                           energy_full_sim,
239                           out=np.ones_like(ene    239                           out=np.ones_like(energy_ml_sim),
240                           where=(energy_full_s    240                           where=(energy_full_sim != 0))
241         # Since len(bins) == 1 + data, we calc    241         # Since len(bins) == 1 + data, we calculate middles of bins as xs.
242         bins_middles = (bins[:-1] + bins[1:])     242         bins_middles = (bins[:-1] + bins[1:]) / 2
243         axes[1].plot(bins_middles, ratio, "-o"    243         axes[1].plot(bins_middles, ratio, "-o")
244         axes[1].set_xlabel(xlabel)                244         axes[1].set_xlabel(xlabel)
245         axes[1].set_ylabel("MLSim/FullSim")       245         axes[1].set_ylabel("MLSim/FullSim")
246         axes[1].axhline(y=1, color="black")       246         axes[1].axhline(y=1, color="black")
247         plt.savefig(                              247         plt.savefig(
248             f"{VALID_DIR}/{observable_name}_Ge    248             f"{VALID_DIR}/{observable_name}_Geo_{self._geometry}_E_{self._particle_energy}_"
249             + f"Angle_{self._particle_angle}.p    249             + f"Angle_{self._particle_angle}.png")
250         plt.clf()                                 250         plt.clf()
251                                                   251 
252     def _plot_profile(self) -> None:              252     def _plot_profile(self) -> None:
253         """ Plots profile of an observable.       253         """ Plots profile of an observable.
254                                                   254 
255         Returns:                                  255         Returns:
256             None.                                 256             None.
257                                                   257 
258         """                                       258         """
259         full_simulation_profile = self._full_s    259         full_simulation_profile = self._full_simulation.calc_profile()
260         ml_simulation_profile = self._ml_simul    260         ml_simulation_profile = self._ml_simulation.calc_profile()
261         if self._profile_type == ProfileType.L    261         if self._profile_type == ProfileType.LONGITUDINAL:
262             # matplotlib will include the righ    262             # matplotlib will include the right-limit for the last bar,
263             # hence extending by 1.               263             # hence extending by 1.
264             bins = np.linspace(0, N_CELLS_Z, N    264             bins = np.linspace(0, N_CELLS_Z, N_CELLS_Z + 1)
265             observable_name = "LongProf"          265             observable_name = "LongProf"
266             xlabel = "Layer index"                266             xlabel = "Layer index"
267         else:                                     267         else:
268             bins = np.linspace(0, N_CELLS_R, N    268             bins = np.linspace(0, N_CELLS_R, N_CELLS_R + 1)
269             observable_name = "LatProf"           269             observable_name = "LatProf"
270             xlabel = "R index"                    270             xlabel = "R index"
271         self._plot_and_save_customizable_histo    271         self._plot_and_save_customizable_histogram(full_simulation_profile,
272                                                   272                                                    ml_simulation_profile,
273                                                   273                                                    bins,
274                                                   274                                                    xlabel,
275                                                   275                                                    observable_name,
276                                                   276                                                    plot_profile=True)
277                                                   277 
278     def _plot_first_moment(self) -> None:         278     def _plot_first_moment(self) -> None:
279         """ Plots and saves a first moment of     279         """ Plots and saves a first moment of an observable's profile.
280                                                   280 
281         Returns:                                  281         Returns:
282             None.                                 282             None.
283                                                   283 
284         """                                       284         """
285         full_simulation_first_moment = self._f    285         full_simulation_first_moment = self._full_simulation.calc_first_moment(
286         )                                         286         )
287         ml_simulation_first_moment = self._ml_    287         ml_simulation_first_moment = self._ml_simulation.calc_first_moment()
288         if self._profile_type == ProfileType.L    288         if self._profile_type == ProfileType.LONGITUDINAL:
289             xlabel = "$<\lambda> [mm]$"           289             xlabel = "$<\lambda> [mm]$"
290             observable_name = "LongFirstMoment    290             observable_name = "LongFirstMoment"
291             bins = np.linspace(0, 0.4 * N_CELL    291             bins = np.linspace(0, 0.4 * N_CELLS_Z * SIZE_Z, 128)
292         else:                                     292         else:
293             xlabel = "$<r> [mm]$"                 293             xlabel = "$<r> [mm]$"
294             observable_name = "LatFirstMoment"    294             observable_name = "LatFirstMoment"
295             bins = np.linspace(0, 0.75 * N_CEL    295             bins = np.linspace(0, 0.75 * N_CELLS_R * SIZE_R, 128)
296                                                   296 
297         self._plot_and_save_customizable_histo    297         self._plot_and_save_customizable_histogram(
298             full_simulation_first_moment, ml_s    298             full_simulation_first_moment, ml_simulation_first_moment, bins,
299             xlabel, observable_name)              299             xlabel, observable_name)
300                                                   300 
301     def _plot_second_moment(self) -> None:        301     def _plot_second_moment(self) -> None:
302         """ Plots and saves a second moment of    302         """ Plots and saves a second moment of an observable's profile.
303                                                   303 
304         Returns:                                  304         Returns:
305             None.                                 305             None.
306                                                   306 
307         """                                       307         """
308         full_simulation_second_moment = self._    308         full_simulation_second_moment = self._full_simulation.calc_second_moment(
309         )                                         309         )
310         ml_simulation_second_moment = self._ml    310         ml_simulation_second_moment = self._ml_simulation.calc_second_moment()
311         if self._profile_type == ProfileType.L    311         if self._profile_type == ProfileType.LONGITUDINAL:
312             xlabel = "$<\lambda^{2}> [mm^{2}]$    312             xlabel = "$<\lambda^{2}> [mm^{2}]$"
313             observable_name = "LongSecondMomen    313             observable_name = "LongSecondMoment"
314             bins = np.linspace(0, pow(N_CELLS_    314             bins = np.linspace(0, pow(N_CELLS_Z * SIZE_Z, 2) / 35., 128)
315         else:                                     315         else:
316             xlabel = "$<r^{2}> [mm^{2}]$"         316             xlabel = "$<r^{2}> [mm^{2}]$"
317             observable_name = "LatSecondMoment    317             observable_name = "LatSecondMoment"
318             bins = np.linspace(0, pow(N_CELLS_    318             bins = np.linspace(0, pow(N_CELLS_R * SIZE_R, 2) / 8., 128)
319                                                   319 
320         self._plot_and_save_customizable_histo    320         self._plot_and_save_customizable_histogram(
321             full_simulation_second_moment, ml_    321             full_simulation_second_moment, ml_simulation_second_moment, bins,
322             xlabel, observable_name)              322             xlabel, observable_name)
323                                                   323 
324     def plot_and_save(self) -> None:              324     def plot_and_save(self) -> None:
325         """ Main plotting function.               325         """ Main plotting function.
326                                                   326 
327         Calls private methods and prints the i    327         Calls private methods and prints the information about progress.
328                                                   328 
329         Returns:                                  329         Returns:
330             None.                                 330             None.
331                                                   331 
332         """                                       332         """
333         if self._profile_type == ProfileType.L    333         if self._profile_type == ProfileType.LONGITUDINAL:
334             profile_type_name = "longitudinal"    334             profile_type_name = "longitudinal"
335         else:                                     335         else:
336             profile_type_name = "lateral"         336             profile_type_name = "lateral"
337         print(f"Plotting the {profile_type_nam    337         print(f"Plotting the {profile_type_name} profile...")
338         self._plot_profile()                      338         self._plot_profile()
339         print(f"Plotting the first moment of {    339         print(f"Plotting the first moment of {profile_type_name} profile...")
340         self._plot_first_moment()                 340         self._plot_first_moment()
341         print(f"Plotting the second moment of     341         print(f"Plotting the second moment of {profile_type_name} profile...")
342         self._plot_second_moment()                342         self._plot_second_moment()
343                                                   343 
344                                                   344 
345 @dataclass                                        345 @dataclass
346 class EnergyPlotter(Plotter):                     346 class EnergyPlotter(Plotter):
347     """ Plotter responsible for preparing plot    347     """ Plotter responsible for preparing plots of profiles and their first and second moments.
348                                                   348 
349     Attributes:                                   349     Attributes:
350         _full_simulation: A numpy array repres    350         _full_simulation: A numpy array representing a profile of data generated by Geant4.
351         _ml_simulation: A numpy array represen    351         _ml_simulation: A numpy array representing a profile of data generated by ML model.
352                                                   352 
353     """                                           353     """
354     _full_simulation: Energy                      354     _full_simulation: Energy
355     _ml_simulation: Energy                        355     _ml_simulation: Energy
356                                                   356 
357     def _plot_total_energy(self, y_log_scale=T    357     def _plot_total_energy(self, y_log_scale=True) -> None:
358         """ Plots and saves a histogram with t    358         """ Plots and saves a histogram with total energy detected in an event.
359                                                   359 
360         Args:                                     360         Args:
361             y_log_scale: A boolean. Used log s    361             y_log_scale: A boolean. Used log scale on y-axis is set to True.
362                                                   362 
363         Returns:                                  363         Returns:
364             None.                                 364             None.
365                                                   365 
366         """                                       366         """
367         full_simulation_total_energy = self._f    367         full_simulation_total_energy = self._full_simulation.calc_total_energy(
368         )                                         368         )
369         ml_simulation_total_energy = self._ml_    369         ml_simulation_total_energy = self._ml_simulation.calc_total_energy()
370                                                   370 
371         plt.figure(figsize=(12, 8))               371         plt.figure(figsize=(12, 8))
372         bins = np.linspace(                       372         bins = np.linspace(
373             np.min(full_simulation_total_energ    373             np.min(full_simulation_total_energy) -
374             np.min(full_simulation_total_energ    374             np.min(full_simulation_total_energy) * 0.05,
375             np.max(full_simulation_total_energ    375             np.max(full_simulation_total_energy) +
376             np.max(full_simulation_total_energ    376             np.max(full_simulation_total_energy) * 0.05, 50)
377         plt.hist(x=full_simulation_total_energ    377         plt.hist(x=full_simulation_total_energy,
378                  histtype=HISTOGRAM_TYPE,         378                  histtype=HISTOGRAM_TYPE,
379                  label="FullSim",                 379                  label="FullSim",
380                  bins=bins,                       380                  bins=bins,
381                  color=FULL_SIM_HISTOGRAM_COLO    381                  color=FULL_SIM_HISTOGRAM_COLOR)
382         plt.hist(x=ml_simulation_total_energy,    382         plt.hist(x=ml_simulation_total_energy,
383                  histtype=HISTOGRAM_TYPE,         383                  histtype=HISTOGRAM_TYPE,
384                  label="MLSim",                   384                  label="MLSim",
385                  bins=bins,                       385                  bins=bins,
386                  color=ML_SIM_HISTOGRAM_COLOR)    386                  color=ML_SIM_HISTOGRAM_COLOR)
387         plt.legend(loc="upper left")              387         plt.legend(loc="upper left")
388         if y_log_scale:                           388         if y_log_scale:
389             plt.yscale("log")                     389             plt.yscale("log")
390         plt.xlabel("Energy [MeV]")                390         plt.xlabel("Energy [MeV]")
391         plt.ylabel("# events")                    391         plt.ylabel("# events")
392         plt.title(                                392         plt.title(
393             f" $e^-$, {self._particle_energy}     393             f" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry} "
394         )                                         394         )
395         plt.savefig(                              395         plt.savefig(
396             f"{VALID_DIR}/E_tot_Geo_{self._geo    396             f"{VALID_DIR}/E_tot_Geo_{self._geometry}_E_{self._particle_energy}_Angle_{self._particle_angle}.png"
397         )                                         397         )
398         plt.clf()                                 398         plt.clf()
399                                                   399 
400     def _plot_cell_energy(self) -> None:          400     def _plot_cell_energy(self) -> None:
401         """ Plots and saves a histogram with n    401         """ Plots and saves a histogram with number of detector's cells across whole
402         calorimeter with particular energy det    402         calorimeter with particular energy detected.
403                                                   403 
404         Returns:                                  404         Returns:
405             None.                                 405             None.
406                                                   406 
407         """                                       407         """
408         full_simulation_cell_energy = self._fu    408         full_simulation_cell_energy = self._full_simulation.calc_cell_energy()
409         ml_simulation_cell_energy = self._ml_s    409         ml_simulation_cell_energy = self._ml_simulation.calc_cell_energy()
410                                                   410 
411         log_full_simulation_cell_energy = np.l    411         log_full_simulation_cell_energy = np.log10(
412             full_simulation_cell_energy,          412             full_simulation_cell_energy,
413             out=np.zeros_like(full_simulation_    413             out=np.zeros_like(full_simulation_cell_energy),
414             where=(full_simulation_cell_energy    414             where=(full_simulation_cell_energy != 0))
415         log_ml_simulation_cell_energy = np.log    415         log_ml_simulation_cell_energy = np.log10(
416             ml_simulation_cell_energy,            416             ml_simulation_cell_energy,
417             out=np.zeros_like(ml_simulation_ce    417             out=np.zeros_like(ml_simulation_cell_energy),
418             where=(ml_simulation_cell_energy !    418             where=(ml_simulation_cell_energy != 0))
419         plt.figure(figsize=(12, 8))               419         plt.figure(figsize=(12, 8))
420         bins = np.linspace(-4, 1, 1000)           420         bins = np.linspace(-4, 1, 1000)
421         plt.hist(x=log_full_simulation_cell_en    421         plt.hist(x=log_full_simulation_cell_energy,
422                  bins=bins,                       422                  bins=bins,
423                  histtype=HISTOGRAM_TYPE,         423                  histtype=HISTOGRAM_TYPE,
424                  label="FullSim",                 424                  label="FullSim",
425                  color=FULL_SIM_HISTOGRAM_COLO    425                  color=FULL_SIM_HISTOGRAM_COLOR)
426         plt.hist(x=log_ml_simulation_cell_ener    426         plt.hist(x=log_ml_simulation_cell_energy,
427                  bins=bins,                       427                  bins=bins,
428                  histtype=HISTOGRAM_TYPE,         428                  histtype=HISTOGRAM_TYPE,
429                  label="MLSim",                   429                  label="MLSim",
430                  color=ML_SIM_HISTOGRAM_COLOR)    430                  color=ML_SIM_HISTOGRAM_COLOR)
431         plt.xlabel("log10(E/MeV)")                431         plt.xlabel("log10(E/MeV)")
432         plt.ylim(bottom=1)                        432         plt.ylim(bottom=1)
433         plt.yscale("log")                         433         plt.yscale("log")
434         plt.ylim(bottom=1)                        434         plt.ylim(bottom=1)
435         plt.ylabel("# entries")                   435         plt.ylabel("# entries")
436         plt.title(                                436         plt.title(
437             f" $e^-$, {self._particle_energy}     437             f" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry} "
438         )                                         438         )
439         plt.grid(True)                            439         plt.grid(True)
440         plt.legend(loc="upper left")              440         plt.legend(loc="upper left")
441         plt.savefig(                              441         plt.savefig(
442             f"{VALID_DIR}/E_cell_Geo_{self._ge    442             f"{VALID_DIR}/E_cell_Geo_{self._geometry}_E_{self._particle_energy}_Angle_{self._particle_angle}.png"
443         )                                         443         )
444         plt.clf()                                 444         plt.clf()
445                                                   445 
446     def _plot_energy_per_layer(self):             446     def _plot_energy_per_layer(self):
447         """ Plots and saves N_CELLS_Z histogra    447         """ Plots and saves N_CELLS_Z histograms with total energy detected in particular layers.
448                                                   448 
449         Returns:                                  449         Returns:
450             None.                                 450             None.
451                                                   451 
452         """                                       452         """
453         full_simulation_energy_per_layer = sel    453         full_simulation_energy_per_layer = self._full_simulation.calc_energy_per_layer(
454         )                                         454         )
455         ml_simulation_energy_per_layer = self.    455         ml_simulation_energy_per_layer = self._ml_simulation.calc_energy_per_layer(
456         )                                         456         )
457                                                   457 
458         number_of_plots_in_row = 9                458         number_of_plots_in_row = 9
459         number_of_plots_in_column = 5             459         number_of_plots_in_column = 5
460                                                   460 
461         bins = np.linspace(np.min(full_simulat    461         bins = np.linspace(np.min(full_simulation_energy_per_layer - 10),
462                            np.max(full_simulat    462                            np.max(full_simulation_energy_per_layer + 10), 25)
463                                                   463 
464         fig, ax = plt.subplots(number_of_plots    464         fig, ax = plt.subplots(number_of_plots_in_column,
465                                number_of_plots    465                                number_of_plots_in_row,
466                                figsize=(20, 15    466                                figsize=(20, 15),
467                                sharex="all",      467                                sharex="all",
468                                sharey="all",      468                                sharey="all",
469                                constrained_lay    469                                constrained_layout=True)
470                                                   470 
471         for layer_nb in range(N_CELLS_Z):         471         for layer_nb in range(N_CELLS_Z):
472             i = layer_nb // number_of_plots_in    472             i = layer_nb // number_of_plots_in_row
473             j = layer_nb % number_of_plots_in_    473             j = layer_nb % number_of_plots_in_row
474                                                   474 
475             ax[i][j].hist(full_simulation_ener    475             ax[i][j].hist(full_simulation_energy_per_layer[:, layer_nb],
476                           histtype=HISTOGRAM_T    476                           histtype=HISTOGRAM_TYPE,
477                           label="FullSim",        477                           label="FullSim",
478                           bins=bins,              478                           bins=bins,
479                           color=FULL_SIM_HISTO    479                           color=FULL_SIM_HISTOGRAM_COLOR)
480             ax[i][j].hist(ml_simulation_energy    480             ax[i][j].hist(ml_simulation_energy_per_layer[:, layer_nb],
481                           histtype=HISTOGRAM_T    481                           histtype=HISTOGRAM_TYPE,
482                           label="MLSim",          482                           label="MLSim",
483                           bins=bins,              483                           bins=bins,
484                           color=ML_SIM_HISTOGR    484                           color=ML_SIM_HISTOGRAM_COLOR)
485             ax[i][j].set_title(f"Layer {layer_    485             ax[i][j].set_title(f"Layer {layer_nb}", fontsize=13)
486             ax[i][j].set_yscale("log")            486             ax[i][j].set_yscale("log")
487             ax[i][j].tick_params(axis='both',     487             ax[i][j].tick_params(axis='both', which='major', labelsize=10)
488                                                   488 
489         fig.supxlabel("Energy [MeV]", fontsize    489         fig.supxlabel("Energy [MeV]", fontsize=14)
490         fig.supylabel("# entries", fontsize=14    490         fig.supylabel("# entries", fontsize=14)
491         fig.suptitle(                             491         fig.suptitle(
492             f" $e^-$, {self._particle_energy}     492             f" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry} "
493         )                                         493         )
494                                                   494 
495         # Take legend from one plot and make i    495         # Take legend from one plot and make it a global legend.
496         handles, labels = ax[0][0].get_legend_    496         handles, labels = ax[0][0].get_legend_handles_labels()
497         fig.legend(handles, labels, bbox_to_an    497         fig.legend(handles, labels, bbox_to_anchor=(1.15, 0.5))
498                                                   498 
499         plt.savefig(                              499         plt.savefig(
500             f"{VALID_DIR}/E_layer_Geo_{self._g    500             f"{VALID_DIR}/E_layer_Geo_{self._geometry}_E_{self._particle_energy}_Angle_{self._particle_angle}.png",
501             bbox_inches="tight")                  501             bbox_inches="tight")
502         plt.clf()                                 502         plt.clf()
503                                                   503 
504     def plot_and_save(self):                      504     def plot_and_save(self):
505         """ Main plotting function.               505         """ Main plotting function.
506                                                   506 
507         Calls private methods and prints the i    507         Calls private methods and prints the information about progress.
508                                                   508 
509         Returns:                                  509         Returns:
510             None.                                 510             None.
511                                                   511 
512         """                                       512         """
513         print("Plotting total energy...")         513         print("Plotting total energy...")
514         self._plot_total_energy()                 514         self._plot_total_energy()
515         print("Plotting cell energy...")          515         print("Plotting cell energy...")
516         self._plot_cell_energy()                  516         self._plot_cell_energy()
517         print("Plotting energy per layer...")     517         print("Plotting energy per layer...")
518         self._plot_energy_per_layer()             518         self._plot_energy_per_layer()