Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/utils/hyperparameter_tuner.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/utils/hyperparameter_tuner.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/utils/hyperparameter_tuner.py (Version 11.1)


  1 from dataclasses import dataclass                   1 from dataclasses import dataclass
  2 from typing import Tuple, Dict, Any, List           2 from typing import Tuple, Dict, Any, List
  3                                                     3 
  4 import numpy as np                                  4 import numpy as np
  5 from optuna import Trial, create_study, get_al      5 from optuna import Trial, create_study, get_all_study_summaries, load_study
  6 from optuna.pruners import MedianPruner             6 from optuna.pruners import MedianPruner
  7 from optuna.samplers import TPESampler              7 from optuna.samplers import TPESampler
  8 from optuna.trial import TrialState                 8 from optuna.trial import TrialState
  9                                                     9 
 10 from core.constants import LEARNING_RATE, BATC     10 from core.constants import LEARNING_RATE, BATCH_SIZE_PER_REPLICA, ACTIVATION, OUT_ACTIVATION, \
 11     OPTIMIZER_TYPE, KERNEL_INITIALIZER, BIAS_I     11     OPTIMIZER_TYPE, KERNEL_INITIALIZER, BIAS_INITIALIZER, N_TRIALS, LATENT_DIM, \
 12     INTERMEDIATE_DIMS, MAX_HIDDEN_LAYER_DIM, G     12     INTERMEDIATE_DIMS, MAX_HIDDEN_LAYER_DIM, GLOBAL_CHECKPOINT_DIR
 13 from core.model import VAEHandler                  13 from core.model import VAEHandler
 14 from utils.preprocess import preprocess            14 from utils.preprocess import preprocess
 15                                                    15 
 16                                                    16 
 17 @dataclass                                         17 @dataclass
 18 class HyperparameterTuner:                         18 class HyperparameterTuner:
 19     """Tuner which looks for the best hyperpar     19     """Tuner which looks for the best hyperparameters of a Variational Autoencoder specified in model.py.
 20                                                    20 
 21     Currently, supported hyperparameters are:      21     Currently, supported hyperparameters are: dimension of latent space, number of hidden layers, learning rate,
 22     activation function, activation function a     22     activation function, activation function after the final layer, optimizer type, kernel initializer,
 23     bias initializer, batch size.                  23     bias initializer, batch size.
 24                                                    24 
 25     Attributes:                                    25     Attributes:
 26         _discrete_parameters: A dictionary of      26         _discrete_parameters: A dictionary of hyperparameters taking discrete values in the range [low, high].
 27         _continuous_parameters: A dictionary o     27         _continuous_parameters: A dictionary of hyperparameters taking continuous values in the range [low, high].
 28         _categorical_parameters: A dictionary      28         _categorical_parameters: A dictionary of hyperparameters taking values specified by the list of them.
 29         _storage: A string representing URL to     29         _storage: A string representing URL to a database required for a distributed training
 30         _study_name: A string, a name of study     30         _study_name: A string, a name of study.
 31                                                    31 
 32     """                                            32     """
 33     _discrete_parameters: Dict[str, Tuple[int,     33     _discrete_parameters: Dict[str, Tuple[int, int]]
 34     _continuous_parameters: Dict[str, Tuple[fl     34     _continuous_parameters: Dict[str, Tuple[float, float]]
 35     _categorical_parameters: Dict[str, List[An     35     _categorical_parameters: Dict[str, List[Any]]
 36     _storage: str = None                           36     _storage: str = None
 37     _study_name: str = None                        37     _study_name: str = None
 38                                                    38 
 39     def _check_hyperparameters(self):              39     def _check_hyperparameters(self):
 40         available_hyperparameters = ["latent_d     40         available_hyperparameters = ["latent_dim", "nb_hidden_layers", "learning_rate", "activation", "out_activation",
 41                                      "optimize     41                                      "optimizer_type", "kernel_initializer", "bias_initializer",
 42                                      "batch_si     42                                      "batch_size_per_replica"]
 43         hyperparameters_to_be_optimized = list     43         hyperparameters_to_be_optimized = list(self._discrete_parameters.keys()) + list(
 44             self._continuous_parameters.keys()     44             self._continuous_parameters.keys()) + list(self._categorical_parameters.keys())
 45         for hyperparameter_name in hyperparame     45         for hyperparameter_name in hyperparameters_to_be_optimized:
 46             if hyperparameter_name not in avai     46             if hyperparameter_name not in available_hyperparameters:
 47                 raise Exception(f"Unknown hype     47                 raise Exception(f"Unknown hyperparameter: {hyperparameter_name}")
 48                                                    48 
 49     def __post_init__(self):                       49     def __post_init__(self):
 50         self._check_hyperparameters()              50         self._check_hyperparameters()
 51         self._energies_train, self._cond_e_tra     51         self._energies_train, self._cond_e_train, self._cond_angle_train, self._cond_geo_train = preprocess()
 52                                                    52 
 53         if self._storage is not None and self.     53         if self._storage is not None and self._study_name is not None:
 54             # Parallel optimization                54             # Parallel optimization
 55             study_summaries = get_all_study_su     55             study_summaries = get_all_study_summaries(self._storage)
 56             if any(self._study_name == study_s     56             if any(self._study_name == study_summary.study_name for study_summary in study_summaries):
 57                 # The study is already created     57                 # The study is already created in the database. Load it.
 58                 self._study = load_study(self.     58                 self._study = load_study(self._study_name, self._storage)
 59             else:                                  59             else:
 60                 # The study does not exist in      60                 # The study does not exist in the database. Create a new one.
 61                 self._study = create_study(sto     61                 self._study = create_study(storage=self._storage, sampler=TPESampler(), pruner=MedianPruner(),
 62                                            stu     62                                            study_name=self._study_name, direction="minimize")
 63         else:                                      63         else:
 64             # Single optimization                  64             # Single optimization
 65             self._study = create_study(sampler     65             self._study = create_study(sampler=TPESampler(), pruner=MedianPruner(), direction="minimize")
 66                                                    66 
 67     def _create_model_handler(self, trial: Tri     67     def _create_model_handler(self, trial: Trial) -> VAEHandler:
 68         """For a given trail builds the model.     68         """For a given trail builds the model.
 69                                                    69 
 70         Optuna suggests parameters like dimens     70         Optuna suggests parameters like dimensions of particular layers of the model, learning rate, optimizer, etc.
 71                                                    71 
 72         Args:                                      72         Args:
 73             trial: Optuna's trial                  73             trial: Optuna's trial
 74                                                    74 
 75         Returns:                                   75         Returns:
 76             Variational Autoencoder (VAE)          76             Variational Autoencoder (VAE)
 77         """                                        77         """
 78                                                    78 
 79         # Discrete parameters                      79         # Discrete parameters
 80         if "latent_dim" in self._discrete_para     80         if "latent_dim" in self._discrete_parameters.keys():
 81             latent_dim = trial.suggest_int(nam     81             latent_dim = trial.suggest_int(name="latent_dim",
 82                                            low     82                                            low=self._discrete_parameters["latent_dim"][0],
 83                                            hig     83                                            high=self._discrete_parameters["latent_dim"][1])
 84         else:                                      84         else:
 85             latent_dim = LATENT_DIM                85             latent_dim = LATENT_DIM
 86                                                    86 
 87         if "nb_hidden_layers" in self._discret     87         if "nb_hidden_layers" in self._discrete_parameters.keys():
 88             nb_hidden_layers = trial.suggest_i     88             nb_hidden_layers = trial.suggest_int(name="nb_hidden_layers",
 89                                                    89                                                  low=self._discrete_parameters["nb_hidden_layers"][0],
 90                                                    90                                                  high=self._discrete_parameters["nb_hidden_layers"][1])
 91                                                    91 
 92             all_possible = np.arange(start=lat     92             all_possible = np.arange(start=latent_dim + 5, stop=MAX_HIDDEN_LAYER_DIM)
 93             chunks = np.array_split(all_possib     93             chunks = np.array_split(all_possible, nb_hidden_layers)
 94             ranges = [(chunk[0], chunk[-1]) fo     94             ranges = [(chunk[0], chunk[-1]) for chunk in chunks]
 95             ranges = reversed(ranges)              95             ranges = reversed(ranges)
 96                                                    96 
 97             # Cast from np.int to int allows t     97             # Cast from np.int to int allows to become JSON serializable.
 98             intermediate_dims = [trial.suggest     98             intermediate_dims = [trial.suggest_int(name=f"intermediate_dim_{i}", low=int(low), high=int(high)) for
 99                                  i, (low, high     99                                  i, (low, high)
100                                  in enumerate(    100                                  in enumerate(ranges)]
101         else:                                     101         else:
102             intermediate_dims = INTERMEDIATE_D    102             intermediate_dims = INTERMEDIATE_DIMS
103                                                   103 
104         if "batch_size_per_replica" in self._d    104         if "batch_size_per_replica" in self._discrete_parameters.keys():
105             batch_size_per_replica = trial.sug    105             batch_size_per_replica = trial.suggest_int(name="batch_size_per_replica",
106                                                   106                                                        low=self._discrete_parameters["batch_size_per_replica"][0],
107                                                   107                                                        high=self._discrete_parameters["batch_size_per_replica"][1])
108         else:                                     108         else:
109             batch_size_per_replica = BATCH_SIZ    109             batch_size_per_replica = BATCH_SIZE_PER_REPLICA
110                                                   110 
111         # Continuous parameters                   111         # Continuous parameters
112         if "learning_rate" in self._continuous    112         if "learning_rate" in self._continuous_parameters.keys():
113             learning_rate = trial.suggest_floa    113             learning_rate = trial.suggest_float(name="learning_rate",
114                                                   114                                                 low=self._continuous_parameters["learning_rate"][0],
115                                                   115                                                 high=self._continuous_parameters["learning_rate"][1])
116         else:                                     116         else:
117             learning_rate = LEARNING_RATE         117             learning_rate = LEARNING_RATE
118                                                   118 
119         # Categorical parameters                  119         # Categorical parameters
120         if "activation" in self._categorical_p    120         if "activation" in self._categorical_parameters.keys():
121             activation = trial.suggest_categor    121             activation = trial.suggest_categorical(name="activation",
122                                                   122                                                    choices=self._categorical_parameters["activation"])
123         else:                                     123         else:
124             activation = ACTIVATION               124             activation = ACTIVATION
125                                                   125 
126         if "out_activation" in self._categoric    126         if "out_activation" in self._categorical_parameters.keys():
127             out_activation = trial.suggest_cat    127             out_activation = trial.suggest_categorical(name="out_activation",
128                                                   128                                                        choices=self._categorical_parameters["out_activation"])
129         else:                                     129         else:
130             out_activation = OUT_ACTIVATION       130             out_activation = OUT_ACTIVATION
131                                                   131 
132         if "optimizer_type" in self._categoric    132         if "optimizer_type" in self._categorical_parameters.keys():
133             optimizer_type = trial.suggest_cat    133             optimizer_type = trial.suggest_categorical(name="optimizer_type",
134                                                   134                                                        choices=self._categorical_parameters["optimizer_type"])
135         else:                                     135         else:
136             optimizer_type = OPTIMIZER_TYPE       136             optimizer_type = OPTIMIZER_TYPE
137                                                   137 
138         if "kernel_initializer" in self._categ    138         if "kernel_initializer" in self._categorical_parameters.keys():
139             kernel_initializer = trial.suggest    139             kernel_initializer = trial.suggest_categorical(name="kernel_initializer",
140                                                   140                                                            choices=self._categorical_parameters["kernel_initializer"])
141         else:                                     141         else:
142             kernel_initializer = KERNEL_INITIA    142             kernel_initializer = KERNEL_INITIALIZER
143                                                   143 
144         if "bias_initializer" in self._categor    144         if "bias_initializer" in self._categorical_parameters.keys():
145             bias_initializer = trial.suggest_c    145             bias_initializer = trial.suggest_categorical(name="bias_initializer",
146                                                   146                                                          choices=self._categorical_parameters["bias_initializer"])
147         else:                                     147         else:
148             bias_initializer = BIAS_INITIALIZE    148             bias_initializer = BIAS_INITIALIZER
149                                                   149 
150         checkpoint_dir = f"{GLOBAL_CHECKPOINT_    150         checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR}/{self._study_name}/trial_{trial.number:03d}"
151                                                   151 
152         return VAEHandler(_wandb_project_name=    152         return VAEHandler(_wandb_project_name=self._study_name,
153                           _wandb_tags=["hyperp    153                           _wandb_tags=["hyperparameter tuning", f"trial {trial.number}"],
154                           _batch_size_per_repl    154                           _batch_size_per_replica=batch_size_per_replica,
155                           _intermediate_dims=i    155                           _intermediate_dims=intermediate_dims,
156                           latent_dim=latent_di    156                           latent_dim=latent_dim,
157                           _learning_rate=learn    157                           _learning_rate=learning_rate,
158                           _activation=activati    158                           _activation=activation,
159                           _out_activation=out_    159                           _out_activation=out_activation,
160                           _optimizer_type=opti    160                           _optimizer_type=optimizer_type,
161                           _kernel_initializer=    161                           _kernel_initializer=kernel_initializer,
162                           _bias_initializer=bi    162                           _bias_initializer=bias_initializer,
163                           _checkpoint_dir=chec    163                           _checkpoint_dir=checkpoint_dir,
164                           _early_stop=True,       164                           _early_stop=True,
165                           _save_model_every_ep    165                           _save_model_every_epoch=False,
166                           _save_best_model=Tru    166                           _save_best_model=True,
167                           )                       167                           )
168                                                   168 
169     def _objective(self, trial: Trial) -> floa    169     def _objective(self, trial: Trial) -> float:
170         """For a given trial trains the model     170         """For a given trial trains the model and returns an average validation loss.
171                                                   171 
172         Args:                                     172         Args:
173             trial: Optuna's trial                 173             trial: Optuna's trial
174                                                   174 
175         Returns: One float numer which is a va    175         Returns: One float numer which is a validation loss. It can be either calculated as an average of k trainings
176         performed in cross validation mode or     176         performed in cross validation mode or is one number obtained from  validation on unseen before, some fraction
177         of the dataset.                           177         of the dataset.
178         """                                       178         """
179                                                   179 
180         # Generate the trial model.               180         # Generate the trial model.
181         model_handler = self._create_model_han    181         model_handler = self._create_model_handler(trial)
182                                                   182 
183         # Train the model.                        183         # Train the model.
184         verbose = True                            184         verbose = True
185         histories = model_handler.train(self._    185         histories = model_handler.train(self._energies_train, self._cond_e_train, self._cond_angle_train,
186                                         self._    186                                         self._cond_geo_train, verbose)
187                                                   187 
188         # Return validation loss (currently it    188         # Return validation loss (currently it is treated as an objective goal). Notice that we take into account the
189         # best model according to the validati    189         # best model according to the validation loss.
190         final_validation_losses = [np.min(hist    190         final_validation_losses = [np.min(history.history["val_loss"]) for history in histories]
191         avg_validation_loss = np.mean(final_va    191         avg_validation_loss = np.mean(final_validation_losses).item()
192         return avg_validation_loss                192         return avg_validation_loss
193                                                   193 
194     def tune(self) -> None:                       194     def tune(self) -> None:
195         """Main tuning function.                  195         """Main tuning function.
196                                                   196 
197         Based on a given study, tunes the mode    197         Based on a given study, tunes the model and prints detailed information about the best trial (value of the
198         objective function and adjusted parame    198         objective function and adjusted parameters).
199         """                                       199         """
200                                                   200 
201         self._study.optimize(func=self._object    201         self._study.optimize(func=self._objective, n_trials=N_TRIALS, gc_after_trial=True)
202         pruned_trials = self._study.get_trials    202         pruned_trials = self._study.get_trials(deepcopy=False, states=(TrialState.PRUNED,))
203         complete_trials = self._study.get_tria    203         complete_trials = self._study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
204         print("Study statistics: ")               204         print("Study statistics: ")
205         print("  Number of finished trials: ",    205         print("  Number of finished trials: ", len(self._study.trials))
206         print("  Number of pruned trials: ", l    206         print("  Number of pruned trials: ", len(pruned_trials))
207         print("  Number of complete trials: ",    207         print("  Number of complete trials: ", len(complete_trials))
208                                                   208 
209         print("Best trial:")                      209         print("Best trial:")
210         trial = self._study.best_trial            210         trial = self._study.best_trial
211                                                   211 
212         print("  Value: ", trial.value)           212         print("  Value: ", trial.value)
213                                                   213 
214         print("  Params: ")                       214         print("  Params: ")
215         for key, value in trial.params.items()    215         for key, value in trial.params.items():
216             print(f"    {key}: {value}")          216             print(f"    {key}: {value}")