Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/tune_model.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/tune_model.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/tune_model.py (Version 11.1.2)


  1 from argparse import ArgumentParser                 1 from argparse import ArgumentParser
  2                                                     2 
  3 from core.constants import MAX_GPU_MEMORY_ALLO      3 from core.constants import MAX_GPU_MEMORY_ALLOCATION, GPU_IDS
  4 from utils.gpu_limiter import GPULimiter            4 from utils.gpu_limiter import GPULimiter
  5 from utils.optimizer import OptimizerType           5 from utils.optimizer import OptimizerType
  6                                                     6 
  7 # Hyperparemeters to be optimized.                  7 # Hyperparemeters to be optimized.
  8 discrete_parameters = {"nb_hidden_layers": (1,      8 discrete_parameters = {"nb_hidden_layers": (1, 6), "latent_dim": (15, 100)}
  9 continuous_parameters = {"learning_rate": (0.0      9 continuous_parameters = {"learning_rate": (0.0001, 0.005)}
 10 categorical_parameters = {"optimizer_type": [O     10 categorical_parameters = {"optimizer_type": [OptimizerType.ADAM, OptimizerType.RMSPROP]}
 11                                                    11 
 12                                                    12 
 13 def parse_args():                                  13 def parse_args():
 14     argument_parser = ArgumentParser()             14     argument_parser = ArgumentParser()
 15     argument_parser.add_argument("--study-name     15     argument_parser.add_argument("--study-name", type=str, default="default_study_name")
 16     argument_parser.add_argument("--storage",      16     argument_parser.add_argument("--storage", type=str)
 17     argument_parser.add_argument("--max-gpu-me     17     argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION)
 18     argument_parser.add_argument("--gpu-ids",      18     argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS)
 19     args = argument_parser.parse_args()            19     args = argument_parser.parse_args()
 20     return args                                    20     return args
 21                                                    21 
 22                                                    22 
 23 def main():                                        23 def main():
 24     # 0. Parse arguments.                          24     # 0. Parse arguments.
 25     args = parse_args()                            25     args = parse_args()
 26     study_name = args.study_name                   26     study_name = args.study_name
 27     storage = args.storage                         27     storage = args.storage
 28     max_gpu_memory_allocation = args.max_gpu_m     28     max_gpu_memory_allocation = args.max_gpu_memory_allocation
 29     gpu_ids = args.gpu_ids                         29     gpu_ids = args.gpu_ids
 30                                                    30 
 31     # 1. Set GPU memory limits.                    31     # 1. Set GPU memory limits.
 32     GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memo     32     GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()
 33                                                    33 
 34     # 2. Manufacture hyperparameter tuner.         34     # 2. Manufacture hyperparameter tuner.
 35                                                    35 
 36     # This import must be local because otherw     36     # This import must be local because otherwise it is impossible to call GPULimiter.
 37     from utils.hyperparameter_tuner import Hyp     37     from utils.hyperparameter_tuner import HyperparameterTuner
 38     hyperparameter_tuner = HyperparameterTuner     38     hyperparameter_tuner = HyperparameterTuner(discrete_parameters, continuous_parameters, categorical_parameters,
 39                                                    39                                                storage, study_name)
 40                                                    40 
 41     # 3. Run main tuning function.                 41     # 3. Run main tuning function.
 42     hyperparameter_tuner.tune()                    42     hyperparameter_tuner.tune()
 43     # Watch out! This script neither deletes t     43     # Watch out! This script neither deletes the study in DB nor deletes the database itself. If you are using
 44     # parallelized optimization, then you shou     44     # parallelized optimization, then you should care about deleting study in the database by yourself.
 45                                                    45 
 46                                                    46 
 47 if __name__ == "__main__":                         47 if __name__ == "__main__":
 48     exit(main())                                   48     exit(main())