Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/generate.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/generate.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/generate.py (Version 11.2.2)


  1 """                                                 1 """
  2 ** generate **                                      2 ** generate **
  3 generate showers using a saved VAE model            3 generate showers using a saved VAE model 
  4 """                                                 4 """
  5 import argparse                                     5 import argparse
  6                                                     6 
  7 import numpy as np                                  7 import numpy as np
  8 import tensorflow as tf                             8 import tensorflow as tf
  9 from tensorflow.python.data import Dataset          9 from tensorflow.python.data import Dataset
 10                                                    10 
 11 from core.constants import GLOBAL_CHECKPOINT_D     11 from core.constants import GLOBAL_CHECKPOINT_DIR, GEN_DIR, BATCH_SIZE_PER_REPLICA, MAX_GPU_MEMORY_ALLOCATION, GPU_IDS
 12 from utils.gpu_limiter import GPULimiter           12 from utils.gpu_limiter import GPULimiter
 13 from utils.preprocess import get_condition_arr     13 from utils.preprocess import get_condition_arrays
 14                                                    14 
 15                                                    15 
 16 def parse_args():                                  16 def parse_args():
 17     argument_parser = argparse.ArgumentParser(     17     argument_parser = argparse.ArgumentParser()
 18     argument_parser.add_argument("--geometry",     18     argument_parser.add_argument("--geometry", type=str, default="")
 19     argument_parser.add_argument("--energy", t     19     argument_parser.add_argument("--energy", type=int, default="")
 20     argument_parser.add_argument("--angle", ty     20     argument_parser.add_argument("--angle", type=int, default="")
 21     argument_parser.add_argument("--events", t     21     argument_parser.add_argument("--events", type=int, default=10000)
 22     argument_parser.add_argument("--epoch", ty     22     argument_parser.add_argument("--epoch", type=int, default=None)
 23     argument_parser.add_argument("--study-name     23     argument_parser.add_argument("--study-name", type=str, default="default_study_name")
 24     argument_parser.add_argument("--max-gpu-me     24     argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION)
 25     argument_parser.add_argument("--gpu-ids",      25     argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS)
 26     args = argument_parser.parse_args()            26     args = argument_parser.parse_args()
 27     return args                                    27     return args
 28                                                    28 
 29                                                    29 
 30 # main function                                    30 # main function
 31 def main():                                        31 def main():
 32     # 0. Parse arguments.                          32     # 0. Parse arguments.
 33     args = parse_args()                            33     args = parse_args()
 34     energy = args.energy                           34     energy = args.energy
 35     angle = args.angle                             35     angle = args.angle
 36     geometry = args.geometry                       36     geometry = args.geometry
 37     events = args.events                           37     events = args.events
 38     epoch = args.epoch                             38     epoch = args.epoch
 39     study_name = args.study_name                   39     study_name = args.study_name
 40     max_gpu_memory_allocation = args.max_gpu_m     40     max_gpu_memory_allocation = args.max_gpu_memory_allocation
 41     gpu_ids = args.gpu_ids                         41     gpu_ids = args.gpu_ids
 42                                                    42 
 43     # 1. Set GPU memory limits.                    43     # 1. Set GPU memory limits.
 44     GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memo     44     GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()
 45                                                    45 
 46     # 2. Load a saved model.                       46     # 2. Load a saved model.
 47                                                    47 
 48     # Create a handler and build model.            48     # Create a handler and build model.
 49     # This import must be local because otherw     49     # This import must be local because otherwise it is impossible to call GPULimiter.
 50     from core.model import VAEHandler              50     from core.model import VAEHandler
 51     vae = VAEHandler()                             51     vae = VAEHandler()
 52                                                    52 
 53     # Load the saved weights                       53     # Load the saved weights
 54     weights_dir = f"VAE_epoch_{epoch:03}" if e     54     weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
 55     vae.model.load_weights(f"{GLOBAL_CHECKPOIN     55     vae.model.load_weights(f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights").expect_partial()
 56                                                    56 
 57     # The generator is defined as the decoder      57     # The generator is defined as the decoder part only
 58     generator = vae.model.decoder                  58     generator = vae.model.decoder
 59                                                    59 
 60     # 3. Prepare data. Get condition values. S     60     # 3. Prepare data. Get condition values. Sample from the prior (normal distribution) in d dimension (d=latent_dim,
 61     # latent space dimension). Gather them int     61     # latent space dimension). Gather them into tuples. Wrap data in Dataset objects. The batch size must now be set
 62     # on the Dataset objects. Disable AutoShar     62     # on the Dataset objects. Disable AutoShard.
 63     e_cond, angle_cond, geo_cond = get_conditi     63     e_cond, angle_cond, geo_cond = get_condition_arrays(geometry, energy, events)
 64                                                    64 
 65     z_r = np.random.normal(loc=0, scale=1, siz     65     z_r = np.random.normal(loc=0, scale=1, size=(events, vae.latent_dim))
 66                                                    66 
 67     data = ((z_r, e_cond, angle_cond, geo_cond     67     data = ((z_r, e_cond, angle_cond, geo_cond),)
 68                                                    68 
 69     data = Dataset.from_tensor_slices(data)        69     data = Dataset.from_tensor_slices(data)
 70                                                    70 
 71     batch_size = BATCH_SIZE_PER_REPLICA            71     batch_size = BATCH_SIZE_PER_REPLICA
 72                                                    72 
 73     data = data.batch(batch_size)                  73     data = data.batch(batch_size)
 74                                                    74 
 75     options = tf.data.Options()                    75     options = tf.data.Options()
 76     options.experimental_distribute.auto_shard     76     options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
 77     data = data.with_options(options)              77     data = data.with_options(options)
 78                                                    78 
 79     # 4. Generate showers using the VAE model.     79     # 4. Generate showers using the VAE model.
 80     generated_events = generator.predict(data)     80     generated_events = generator.predict(data) * (energy * 1000)
 81                                                    81 
 82     # 5. Save the generated showers.               82     # 5. Save the generated showers.
 83     np.save(f"{GEN_DIR}/VAE_Generated_Geo_{geo     83     np.save(f"{GEN_DIR}/VAE_Generated_Geo_{geometry}_E_{energy}_Angle_{angle}.npy", generated_events)
 84                                                    84 
 85                                                    85 
 86 if __name__ == "__main__":                         86 if __name__ == "__main__":
 87     exit(main())                                   87     exit(main())