Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/train.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/train.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/train.py (Version 11.2.2)


  1 from argparse import ArgumentParser                 1 from argparse import ArgumentParser
  2                                                     2 
  3 from core.constants import GPU_IDS, MAX_GPU_ME      3 from core.constants import GPU_IDS, MAX_GPU_MEMORY_ALLOCATION, GLOBAL_CHECKPOINT_DIR
  4 from utils.gpu_limiter import GPULimiter            4 from utils.gpu_limiter import GPULimiter
  5 from utils.preprocess import preprocess             5 from utils.preprocess import preprocess
  6                                                     6 
  7                                                     7 
  8 def parse_args():                                   8 def parse_args():
  9     argument_parser = ArgumentParser()              9     argument_parser = ArgumentParser()
 10     argument_parser.add_argument("--max-gpu-me     10     argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION)
 11     argument_parser.add_argument("--gpu-ids",      11     argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS)
 12     argument_parser.add_argument("--study-name     12     argument_parser.add_argument("--study-name", type=str, default="default_study_name")
 13     args = argument_parser.parse_args()            13     args = argument_parser.parse_args()
 14     return args                                    14     return args
 15                                                    15 
 16                                                    16 
 17 def main():                                        17 def main():
 18     # 0. Parse arguments.                          18     # 0. Parse arguments.
 19     args = parse_args()                            19     args = parse_args()
 20     max_gpu_memory_allocation = args.max_gpu_m     20     max_gpu_memory_allocation = args.max_gpu_memory_allocation
 21     gpu_ids = args.gpu_ids                         21     gpu_ids = args.gpu_ids
 22     study_name = args.study_name                   22     study_name = args.study_name
 23     checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR}     23     checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR}/{study_name}"
 24                                                    24 
 25     # 1. Set GPU memory limits.                    25     # 1. Set GPU memory limits.
 26     GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memo     26     GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()
 27                                                    27 
 28     # 2. Data loading/preprocessing                28     # 2. Data loading/preprocessing
 29                                                    29 
 30     # The preprocess function reads the data a     30     # The preprocess function reads the data and performs preprocessing and encoding for the values of energy,
 31     # angle and geometry                           31     # angle and geometry
 32     energies_train, cond_e_train, cond_angle_t     32     energies_train, cond_e_train, cond_angle_train, cond_geo_train = preprocess()
 33                                                    33 
 34     # 3. Manufacture model handler.                34     # 3. Manufacture model handler.
 35                                                    35 
 36     # This import must be local because otherw     36     # This import must be local because otherwise it is impossible to call GPULimiter.
 37     from core.model import VAEHandler              37     from core.model import VAEHandler
 38     vae = VAEHandler(_wandb_project_name=study     38     vae = VAEHandler(_wandb_project_name=study_name, _wandb_tags=["single training"], _checkpoint_dir=checkpoint_dir)
 39                                                    39 
 40     # 4. Train model.                              40     # 4. Train model.
 41     histories = vae.train(energies_train,          41     histories = vae.train(energies_train,
 42                           cond_e_train,            42                           cond_e_train,
 43                           cond_angle_train,        43                           cond_angle_train,
 44                           cond_geo_train           44                           cond_geo_train
 45                           )                        45                           )
 46                                                    46 
 47     # Note : One history object can be used to     47     # Note : One history object can be used to plot the loss evaluation as function of the epochs. Remember that the
 48     # function returns a list of those objects     48     # function returns a list of those objects. Each of them represents a different fold of cross validation.
 49                                                    49 
 50                                                    50 
 51 if __name__ == "__main__":                         51 if __name__ == "__main__":
 52     exit(main())                                   52     exit(main())