Geant4 Cross Reference |
1 from argparse import ArgumentParser 1 from argparse import ArgumentParser 2 2 3 from core.constants import GPU_IDS, MAX_GPU_ME 3 from core.constants import GPU_IDS, MAX_GPU_MEMORY_ALLOCATION, GLOBAL_CHECKPOINT_DIR 4 from utils.gpu_limiter import GPULimiter 4 from utils.gpu_limiter import GPULimiter 5 from utils.preprocess import preprocess 5 from utils.preprocess import preprocess 6 6 7 7 8 def parse_args(): 8 def parse_args(): 9 argument_parser = ArgumentParser() 9 argument_parser = ArgumentParser() 10 argument_parser.add_argument("--max-gpu-me 10 argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION) 11 argument_parser.add_argument("--gpu-ids", 11 argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS) 12 argument_parser.add_argument("--study-name 12 argument_parser.add_argument("--study-name", type=str, default="default_study_name") 13 args = argument_parser.parse_args() 13 args = argument_parser.parse_args() 14 return args 14 return args 15 15 16 16 17 def main(): 17 def main(): 18 # 0. Parse arguments. 18 # 0. Parse arguments. 19 args = parse_args() 19 args = parse_args() 20 max_gpu_memory_allocation = args.max_gpu_m 20 max_gpu_memory_allocation = args.max_gpu_memory_allocation 21 gpu_ids = args.gpu_ids 21 gpu_ids = args.gpu_ids 22 study_name = args.study_name 22 study_name = args.study_name 23 checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR} 23 checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR}/{study_name}" 24 24 25 # 1. Set GPU memory limits. 25 # 1. Set GPU memory limits. 26 GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memo 26 GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)() 27 27 28 # 2. Data loading/preprocessing 28 # 2. Data loading/preprocessing 29 29 30 # The preprocess function reads the data a 30 # The preprocess function reads the data and performs preprocessing and encoding for the values of energy, 31 # angle and geometry 31 # angle and geometry 32 energies_train, cond_e_train, cond_angle_t 32 energies_train, cond_e_train, cond_angle_train, cond_geo_train = preprocess() 33 33 34 # 3. Manufacture model handler. 34 # 3. Manufacture model handler. 35 35 36 # This import must be local because otherw 36 # This import must be local because otherwise it is impossible to call GPULimiter. 37 from core.model import VAEHandler 37 from core.model import VAEHandler 38 vae = VAEHandler(_wandb_project_name=study 38 vae = VAEHandler(_wandb_project_name=study_name, _wandb_tags=["single training"], _checkpoint_dir=checkpoint_dir) 39 39 40 # 4. Train model. 40 # 4. Train model. 41 histories = vae.train(energies_train, 41 histories = vae.train(energies_train, 42 cond_e_train, 42 cond_e_train, 43 cond_angle_train, 43 cond_angle_train, 44 cond_geo_train 44 cond_geo_train 45 ) 45 ) 46 46 47 # Note : One history object can be used to 47 # Note : One history object can be used to plot the loss evaluation as function of the epochs. Remember that the 48 # function returns a list of those objects 48 # function returns a list of those objects. Each of them represents a different fold of cross validation. 49 49 50 50 51 if __name__ == "__main__": 51 if __name__ == "__main__": 52 exit(main()) 52 exit(main())