Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/validate.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/validate.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/validate.py (Version 9.6.p3)


  1 import argparse                                   
  2                                                   
  3 import numpy as np                                
  4                                                   
  5 from core.constants import INIT_DIR, GEN_DIR,     
  6 from utils.observables import LongitudinalProf    
  7 from utils.plotters import ProfilePlotter, Ene    
  8 from utils.preprocess import load_showers         
  9                                                   
 10                                                   
 11 def parse_args():                                 
 12     p = argparse.ArgumentParser()                 
 13     p.add_argument("--geometry", type=str, def    
 14     p.add_argument("--energy", type=int, defau    
 15     p.add_argument("--angle", type=int, defaul    
 16     args = p.parse_args()                         
 17     return args                                   
 18                                                   
 19                                                   
 20 # main function                                   
 21 def main():                                       
 22     # Parse commandline arguments                 
 23     args = parse_args()                           
 24     particle_energy = args.energy                 
 25     particle_angle = args.angle                   
 26     geometry = args.geometry                      
 27     # 1. Full simulation data loading             
 28     # Load energy of showers from a single geo    
 29     e_layer_g4 = load_showers(INIT_DIR, geomet    
 30                               particle_angle)     
 31     # 2. Fast simulation data loading, scaling    
 32     vae_energies = np.load(f"{GEN_DIR}/VAE_Gen    
 33     # Reshape the events into 3D                  
 34     e_layer_vae = vae_energies.reshape((len(va    
 35                                                   
 36     print("Data has been loaded.")                
 37                                                   
 38     # 3. Create observables from raw data.        
 39     full_sim_long = LongitudinalProfile(_input    
 40     full_sim_lat = LateralProfile(_input=e_lay    
 41     full_sim_energy = Energy(_input=e_layer_g4    
 42     ml_sim_long = LongitudinalProfile(_input=e    
 43     ml_sim_lat = LateralProfile(_input=e_layer    
 44     ml_sim_energy = Energy(_input=e_layer_vae)    
 45                                                   
 46     print("Created observables.")                 
 47                                                   
 48     # 4. Plot observables                         
 49     longitudinal_profile_plotter = ProfilePlot    
 50                                                   
 51     lateral_profile_plotter = ProfilePlotter(p    
 52                                              g    
 53     energy_plotter = EnergyPlotter(particle_en    
 54                                                   
 55     longitudinal_profile_plotter.plot_and_save    
 56     lateral_profile_plotter.plot_and_save()       
 57     energy_plotter.plot_and_save()                
 58     print("Done.")                                
 59                                                   
 60                                                   
 61 if __name__ == "__main__":                        
 62     exit(main())