Geant4 Cross Reference |
1 from argparse import ArgumentParser 2 3 from core.constants import GPU_IDS, MAX_GPU_ME 4 from utils.gpu_limiter import GPULimiter 5 from utils.preprocess import preprocess 6 7 8 def parse_args(): 9 argument_parser = ArgumentParser() 10 argument_parser.add_argument("--max-gpu-me 11 argument_parser.add_argument("--gpu-ids", 12 argument_parser.add_argument("--study-name 13 args = argument_parser.parse_args() 14 return args 15 16 17 def main(): 18 # 0. Parse arguments. 19 args = parse_args() 20 max_gpu_memory_allocation = args.max_gpu_m 21 gpu_ids = args.gpu_ids 22 study_name = args.study_name 23 checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR} 24 25 # 1. Set GPU memory limits. 26 GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memo 27 28 # 2. Data loading/preprocessing 29 30 # The preprocess function reads the data a 31 # angle and geometry 32 energies_train, cond_e_train, cond_angle_t 33 34 # 3. Manufacture model handler. 35 36 # This import must be local because otherw 37 from core.model import VAEHandler 38 vae = VAEHandler(_wandb_project_name=study 39 40 # 4. Train model. 41 histories = vae.train(energies_train, 42 cond_e_train, 43 cond_angle_train, 44 cond_geo_train 45 ) 46 47 # Note : One history object can be used to 48 # function returns a list of those objects 49 50 51 if __name__ == "__main__": 52 exit(main())