Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/utils/hyperparameter_tuner.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/utils/hyperparameter_tuner.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/utils/hyperparameter_tuner.py (Version 4.0)


  1 from dataclasses import dataclass                 
  2 from typing import Tuple, Dict, Any, List         
  3                                                   
  4 import numpy as np                                
  5 from optuna import Trial, create_study, get_al    
  6 from optuna.pruners import MedianPruner           
  7 from optuna.samplers import TPESampler            
  8 from optuna.trial import TrialState               
  9                                                   
 10 from core.constants import LEARNING_RATE, BATC    
 11     OPTIMIZER_TYPE, KERNEL_INITIALIZER, BIAS_I    
 12     INTERMEDIATE_DIMS, MAX_HIDDEN_LAYER_DIM, G    
 13 from core.model import VAEHandler                 
 14 from utils.preprocess import preprocess           
 15                                                   
 16                                                   
 17 @dataclass                                        
 18 class HyperparameterTuner:                        
 19     """Tuner which looks for the best hyperpar    
 20                                                   
 21     Currently, supported hyperparameters are:     
 22     activation function, activation function a    
 23     bias initializer, batch size.                 
 24                                                   
 25     Attributes:                                   
 26         _discrete_parameters: A dictionary of     
 27         _continuous_parameters: A dictionary o    
 28         _categorical_parameters: A dictionary     
 29         _storage: A string representing URL to    
 30         _study_name: A string, a name of study    
 31                                                   
 32     """                                           
 33     _discrete_parameters: Dict[str, Tuple[int,    
 34     _continuous_parameters: Dict[str, Tuple[fl    
 35     _categorical_parameters: Dict[str, List[An    
 36     _storage: str = None                          
 37     _study_name: str = None                       
 38                                                   
 39     def _check_hyperparameters(self):             
 40         available_hyperparameters = ["latent_d    
 41                                      "optimize    
 42                                      "batch_si    
 43         hyperparameters_to_be_optimized = list    
 44             self._continuous_parameters.keys()    
 45         for hyperparameter_name in hyperparame    
 46             if hyperparameter_name not in avai    
 47                 raise Exception(f"Unknown hype    
 48                                                   
 49     def __post_init__(self):                      
 50         self._check_hyperparameters()             
 51         self._energies_train, self._cond_e_tra    
 52                                                   
 53         if self._storage is not None and self.    
 54             # Parallel optimization               
 55             study_summaries = get_all_study_su    
 56             if any(self._study_name == study_s    
 57                 # The study is already created    
 58                 self._study = load_study(self.    
 59             else:                                 
 60                 # The study does not exist in     
 61                 self._study = create_study(sto    
 62                                            stu    
 63         else:                                     
 64             # Single optimization                 
 65             self._study = create_study(sampler    
 66                                                   
 67     def _create_model_handler(self, trial: Tri    
 68         """For a given trail builds the model.    
 69                                                   
 70         Optuna suggests parameters like dimens    
 71                                                   
 72         Args:                                     
 73             trial: Optuna's trial                 
 74                                                   
 75         Returns:                                  
 76             Variational Autoencoder (VAE)         
 77         """                                       
 78                                                   
 79         # Discrete parameters                     
 80         if "latent_dim" in self._discrete_para    
 81             latent_dim = trial.suggest_int(nam    
 82                                            low    
 83                                            hig    
 84         else:                                     
 85             latent_dim = LATENT_DIM               
 86                                                   
 87         if "nb_hidden_layers" in self._discret    
 88             nb_hidden_layers = trial.suggest_i    
 89                                                   
 90                                                   
 91                                                   
 92             all_possible = np.arange(start=lat    
 93             chunks = np.array_split(all_possib    
 94             ranges = [(chunk[0], chunk[-1]) fo    
 95             ranges = reversed(ranges)             
 96                                                   
 97             # Cast from np.int to int allows t    
 98             intermediate_dims = [trial.suggest    
 99                                  i, (low, high    
100                                  in enumerate(    
101         else:                                     
102             intermediate_dims = INTERMEDIATE_D    
103                                                   
104         if "batch_size_per_replica" in self._d    
105             batch_size_per_replica = trial.sug    
106                                                   
107                                                   
108         else:                                     
109             batch_size_per_replica = BATCH_SIZ    
110                                                   
111         # Continuous parameters                   
112         if "learning_rate" in self._continuous    
113             learning_rate = trial.suggest_floa    
114                                                   
115                                                   
116         else:                                     
117             learning_rate = LEARNING_RATE         
118                                                   
119         # Categorical parameters                  
120         if "activation" in self._categorical_p    
121             activation = trial.suggest_categor    
122                                                   
123         else:                                     
124             activation = ACTIVATION               
125                                                   
126         if "out_activation" in self._categoric    
127             out_activation = trial.suggest_cat    
128                                                   
129         else:                                     
130             out_activation = OUT_ACTIVATION       
131                                                   
132         if "optimizer_type" in self._categoric    
133             optimizer_type = trial.suggest_cat    
134                                                   
135         else:                                     
136             optimizer_type = OPTIMIZER_TYPE       
137                                                   
138         if "kernel_initializer" in self._categ    
139             kernel_initializer = trial.suggest    
140                                                   
141         else:                                     
142             kernel_initializer = KERNEL_INITIA    
143                                                   
144         if "bias_initializer" in self._categor    
145             bias_initializer = trial.suggest_c    
146                                                   
147         else:                                     
148             bias_initializer = BIAS_INITIALIZE    
149                                                   
150         checkpoint_dir = f"{GLOBAL_CHECKPOINT_    
151                                                   
152         return VAEHandler(_wandb_project_name=    
153                           _wandb_tags=["hyperp    
154                           _batch_size_per_repl    
155                           _intermediate_dims=i    
156                           latent_dim=latent_di    
157                           _learning_rate=learn    
158                           _activation=activati    
159                           _out_activation=out_    
160                           _optimizer_type=opti    
161                           _kernel_initializer=    
162                           _bias_initializer=bi    
163                           _checkpoint_dir=chec    
164                           _early_stop=True,       
165                           _save_model_every_ep    
166                           _save_best_model=Tru    
167                           )                       
168                                                   
169     def _objective(self, trial: Trial) -> floa    
170         """For a given trial trains the model     
171                                                   
172         Args:                                     
173             trial: Optuna's trial                 
174                                                   
175         Returns: One float numer which is a va    
176         performed in cross validation mode or     
177         of the dataset.                           
178         """                                       
179                                                   
180         # Generate the trial model.               
181         model_handler = self._create_model_han    
182                                                   
183         # Train the model.                        
184         verbose = True                            
185         histories = model_handler.train(self._    
186                                         self._    
187                                                   
188         # Return validation loss (currently it    
189         # best model according to the validati    
190         final_validation_losses = [np.min(hist    
191         avg_validation_loss = np.mean(final_va    
192         return avg_validation_loss                
193                                                   
194     def tune(self) -> None:                       
195         """Main tuning function.                  
196                                                   
197         Based on a given study, tunes the mode    
198         objective function and adjusted parame    
199         """                                       
200                                                   
201         self._study.optimize(func=self._object    
202         pruned_trials = self._study.get_trials    
203         complete_trials = self._study.get_tria    
204         print("Study statistics: ")               
205         print("  Number of finished trials: ",    
206         print("  Number of pruned trials: ", l    
207         print("  Number of complete trials: ",    
208                                                   
209         print("Best trial:")                      
210         trial = self._study.best_trial            
211                                                   
212         print("  Value: ", trial.value)           
213                                                   
214         print("  Params: ")                       
215         for key, value in trial.params.items()    
216             print(f"    {key}: {value}")