1 """ 1 """ 2 ** generate ** 2 ** generate ** 3 generate showers using a saved VAE model 3 generate showers using a saved VAE model 4 """ 4 """ 5 import argparse 5 import argparse 6 6 7 import numpy as np 7 import numpy as np 8 import tensorflow as tf 8 import tensorflow as tf 9 from tensorflow.python.data import Dataset 9 from tensorflow.python.data import Dataset 10 10 11 from core.constants import GLOBAL_CHECKPOINT_D 11 from core.constants import GLOBAL_CHECKPOINT_DIR, GEN_DIR, BATCH_SIZE_PER_REPLICA, MAX_GPU_MEMORY_ALLOCATION, GPU_IDS 12 from utils.gpu_limiter import GPULimiter 12 from utils.gpu_limiter import GPULimiter 13 from utils.preprocess import get_condition_arr 13 from utils.preprocess import get_condition_arrays 14 14 15 15 16 def parse_args(): 16 def parse_args(): 17 argument_parser = argparse.ArgumentParser( 17 argument_parser = argparse.ArgumentParser() 18 argument_parser.add_argument("--geometry", 18 argument_parser.add_argument("--geometry", type=str, default="") 19 argument_parser.add_argument("--energy", t 19 argument_parser.add_argument("--energy", type=int, default="") 20 argument_parser.add_argument("--angle", ty 20 argument_parser.add_argument("--angle", type=int, default="") 21 argument_parser.add_argument("--events", t 21 argument_parser.add_argument("--events", type=int, default=10000) 22 argument_parser.add_argument("--epoch", ty 22 argument_parser.add_argument("--epoch", type=int, default=None) 23 argument_parser.add_argument("--study-name 23 argument_parser.add_argument("--study-name", type=str, default="default_study_name") 24 argument_parser.add_argument("--max-gpu-me 24 argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION) 25 argument_parser.add_argument("--gpu-ids", 25 argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS) 26 args = argument_parser.parse_args() 26 args = argument_parser.parse_args() 27 return args 27 return args 28 28 29 29 30 # main function 30 # main function 31 def main(): 31 def main(): 32 # 0. Parse arguments. 32 # 0. Parse arguments. 33 args = parse_args() 33 args = parse_args() 34 energy = args.energy 34 energy = args.energy 35 angle = args.angle 35 angle = args.angle 36 geometry = args.geometry 36 geometry = args.geometry 37 events = args.events 37 events = args.events 38 epoch = args.epoch 38 epoch = args.epoch 39 study_name = args.study_name 39 study_name = args.study_name 40 max_gpu_memory_allocation = args.max_gpu_m 40 max_gpu_memory_allocation = args.max_gpu_memory_allocation 41 gpu_ids = args.gpu_ids 41 gpu_ids = args.gpu_ids 42 42 43 # 1. Set GPU memory limits. 43 # 1. Set GPU memory limits. 44 GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memo 44 GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)() 45 45 46 # 2. Load a saved model. 46 # 2. Load a saved model. 47 47 48 # Create a handler and build model. 48 # Create a handler and build model. 49 # This import must be local because otherw 49 # This import must be local because otherwise it is impossible to call GPULimiter. 
    from core.model import VAEHandler
    vae = VAEHandler()

    # Load the saved weights.
    weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
    vae.model.load_weights(f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights").expect_partial()

    # The generator is defined as the decoder part only.
    generator = vae.model.decoder

    # 3. Prepare data. Get the condition values. Sample from the prior (a standard normal distribution) in
    # d dimensions (d = latent_dim, the dimension of the latent space). Gather everything into tuples and wrap
    # the data in a Dataset object. The batch size must now be set on the Dataset object. Disable AutoShard.
    e_cond, angle_cond, geo_cond = get_condition_arrays(geometry, energy, events)

    z_r = np.random.normal(loc=0, scale=1, size=(events, vae.latent_dim))

    data = ((z_r, e_cond, angle_cond, geo_cond),)

    data = Dataset.from_tensor_slices(data)

    batch_size = BATCH_SIZE_PER_REPLICA

    data = data.batch(batch_size)

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    data = data.with_options(options)

    # 4. Generate showers using the VAE model. The decoder output is normalised, so rescale it by the incident
    # particle energy (given in GeV, hence the factor of 1000 to convert to MeV).
    generated_events = generator.predict(data) * (energy * 1000)

    # 5. Save the generated showers.
    np.save(f"{GEN_DIR}/VAE_Generated_Geo_{geometry}_E_{energy}_Angle_{angle}.npy", generated_events)


if __name__ == "__main__":
    exit(main())
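
# Example invocation (a sketch only; the script file name, geometry label, energy, angle,
# epoch, and study name below are illustrative placeholders, not values prescribed by
# this script or its constants):
#
#   python generate.py --geometry SiW --energy 64 --angle 90 --events 10000 \
#       --study-name default_study_name --epoch 100
#
# With these placeholder values, the script would load the weights saved under
# {GLOBAL_CHECKPOINT_DIR}/default_study_name/VAE_epoch_100/model_weights (or VAE_best if
# --epoch is omitted) and write the generated showers to
# {GEN_DIR}/VAE_Generated_Geo_SiW_E_64_Angle_90.npy.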