import gc
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf
import wandb
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, History, Callback
from tensorflow.keras.layers import BatchNormalization, Input, Dense, Layer, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.keras.models import Model
from tensorflow.python.data import Dataset
from tensorflow.python.distribute.distribute_lib import Strategy
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from wandb.keras import WandbCallback

from core.constants import ORIGINAL_DIM, LATENT_DIM, BATCH_SIZE_PER_REPLICA, EPOCHS, LEARNING_RATE, ACTIVATION, \
    OUT_ACTIVATION, OPTIMIZER_TYPE, KERNEL_INITIALIZER, GLOBAL_CHECKPOINT_DIR, EARLY_STOP, BIAS_INITIALIZER, \
    INTERMEDIATE_DIMS, SAVE_MODEL_EVERY_EPOCH, SAVE_BEST_MODEL, PATIENCE, MIN_DELTA, BEST_MODEL_FILENAME, \
    NUMBER_OF_K_FOLD_SPLITS, VALIDATION_SPLIT, WANDB_ENTITY
from utils.optimizer import OptimizerFactory, OptimizerType


class _Sampling(Layer):
    """ Custom layer to do the reparameterization trick: sample random latent vectors z from the latent Gaussian
    distribution.

    The sampled vector z is given by sampled_z = mean + std * epsilon.
    """

    def call(self, inputs, **kwargs):
        z_mean, z_log_var, epsilon = inputs
        z_sigma = K.exp(0.5 * z_log_var)
        return z_mean + z_sigma * epsilon


class _KLDivergenceLayer(Layer):
    """ Custom layer that computes the Kullback-Leibler divergence between the approximate posterior N(mu, sigma^2)
    and the standard normal prior, and registers it as an additional loss of the model.
    """

    def call(self, inputs, **kwargs):
        mu, log_var = inputs
        kl_loss = -0.5 * (1 + log_var - K.square(mu) - K.exp(log_var))
        kl_loss = K.mean(K.sum(kl_loss, axis=-1))
        self.add_loss(kl_loss)
        return inputs


class VAE(Model):
    def get_config(self):
        config = super().get_config()
        config["encoder"] = self.encoder
        config["decoder"] = self.decoder
        return config

    def call(self, inputs, training=None, mask=None):
        _, e_input, angle_input, geo_input, _ = inputs
        z = self.encoder(inputs)
        return self.decoder([z, e_input, angle_input, geo_input])

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self._set_inputs(inputs=self.encoder.inputs, outputs=self(self.encoder.inputs))


@dataclass
class VAEHandler:
    """
    Class to handle building and training VAE models.
73 """ 74 _wandb_project_name: str = None 75 _wandb_tags: List[str] = field(default_factory=list) 76 _original_dim: int = ORIGINAL_DIM 77 latent_dim: int = LATENT_DIM 78 _batch_size_per_replica: int = BATCH_SIZE_PER_REPLICA 79 _intermediate_dims: List[int] = field(default_factory=lambda: INTERMEDIATE_DIMS) 80 _learning_rate: float = LEARNING_RATE 81 _epochs: int = EPOCHS 82 _activation: str = ACTIVATION 83 _out_activation: str = OUT_ACTIVATION 84 _number_of_k_fold_splits: float = NUMBER_OF_K_FOLD_SPLITS 85 _optimizer_type: OptimizerType = OPTIMIZER_TYPE 86 _kernel_initializer: str = KERNEL_INITIALIZER 87 _bias_initializer: str = BIAS_INITIALIZER 88 _checkpoint_dir: str = GLOBAL_CHECKPOINT_DIR 89 _early_stop: bool = EARLY_STOP 90 _save_model_every_epoch: bool = SAVE_MODEL_EVERY_EPOCH 91 _save_best_model: bool = SAVE_BEST_MODEL 92 _patience: int = PATIENCE 93 _min_delta: float = MIN_DELTA 94 _best_model_filename: str = BEST_MODEL_FILENAME 95 _validation_split: float = VALIDATION_SPLIT 96 _strategy: Strategy = MirroredStrategy() 97 98 def __post_init__(self) -> None: 99 # Calculate true batch size. 100 self._batch_size = self._batch_size_per_replica * self._strategy.num_replicas_in_sync 101 self._build_and_compile_new_model() 102 # Setup Wandb. 103 if self._wandb_project_name is not None: 104 self._setup_wandb() 105 106 def _setup_wandb(self) -> None: 107 config = { 108 "learning_rate": self._learning_rate, 109 "batch_size": self._batch_size, 110 "epochs": self._epochs, 111 "optimizer_type": self._optimizer_type, 112 "intermediate_dims": self._intermediate_dims, 113 "latent_dim": self.latent_dim 114 } 115 # Reinit flag is needed for hyperparameter tuning. Whenever new training is started, new Wandb run should be 116 # created. 117 wandb.init(project=self._wandb_project_name, entity=WANDB_ENTITY, reinit=True, config=config, 118 tags=self._wandb_tags) 119 120 def _build_and_compile_new_model(self) -> None: 121 """ Builds and compiles a new model. 122 123 VAEHandler keep a list of VAE instance. The reason is that while k-fold cross validation is performed, 124 each fold requires a new, clear instance of model. New model is always added at the end of the list of 125 existing ones. 126 127 Returns: None 128 129 """ 130 # Build encoder and decoder. 131 encoder = self._build_encoder() 132 decoder = self._build_decoder() 133 134 # Compile model within a distributed strategy. 135 with self._strategy.scope(): 136 # Build VAE. 137 self.model = VAE(encoder, decoder) 138 # Manufacture an optimizer and compile model with. 139 optimizer = OptimizerFactory.create_optimizer(self._optimizer_type, self._learning_rate) 140 reconstruction_loss = BinaryCrossentropy(reduction=Reduction.SUM) 141 self.model.compile(optimizer=optimizer, loss=[reconstruction_loss], loss_weights=[ORIGINAL_DIM]) 142 143 def _prepare_input_layers(self, for_encoder: bool) -> List[Input]: 144 """ 145 Create four Input layers. Each of them is responsible to take respectively: batch of showers/batch of latent 146 vectors, batch of energies, batch of angles, batch of geometries. 147 148 Args: 149 for_encoder: Boolean which decides whether an input is full dimensional shower or a latent vector. 150 151 Returns: 152 List of Input layers (five for encoder and four for decoder). 
        """
        e_input = Input(shape=(1,))
        angle_input = Input(shape=(1,))
        geo_input = Input(shape=(2,))
        if for_encoder:
            x_input = Input(shape=(self._original_dim,))
            eps_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input, eps_input]
        else:
            x_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input]

    def _build_encoder(self) -> Model:
        """ Builds the encoder from the list of intermediate dimensions, the activation function and the kernel and
        bias initializers.

        Returns:
            The encoder as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layers.
            x_input, e_input, angle_input, geo_input, eps_input = self._prepare_input_layers(for_encoder=True)
            x = concatenate([x_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in self._intermediate_dims:
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add Dense layers describing the multidimensional Gaussian distribution in terms of its mean
            # and log(variance).
            z_mean = Dense(self.latent_dim, name="z_mean")(x)
            z_log_var = Dense(self.latent_dim, name="z_log_var")(x)
            # Add a KLDivergenceLayer responsible for calculating the KL loss.
            z_mean, z_log_var = _KLDivergenceLayer()([z_mean, z_log_var])
            # Sample a latent vector from the distribution (reparameterization trick).
            encoder_output = _Sampling()([z_mean, z_log_var, eps_input])
            # Create the model.
            encoder = Model(inputs=[x_input, e_input, angle_input, geo_input, eps_input], outputs=encoder_output,
                            name="encoder")
        return encoder

    def _build_decoder(self) -> Model:
        """ Builds the decoder from the list of intermediate dimensions, the activation function and the kernel and
        bias initializers.

        Returns:
            The decoder as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layers.
            latent_input, e_input, angle_input, geo_input = self._prepare_input_layers(for_encoder=False)
            x = concatenate([latent_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in reversed(self._intermediate_dims):
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add a Dense layer whose output shape matches the original input shape.
            decoder_outputs = Dense(units=self._original_dim, activation=self._out_activation)(x)
            # Create the model.
            decoder = Model(inputs=[latent_input, e_input, angle_input, geo_input], outputs=decoder_outputs,
                            name="decoder")
        return decoder

    def _manufacture_callbacks(self) -> List[Callback]:
        """
        Based on parameters set by the user, manufactures the callbacks required for training.

        Returns:
            A list of `Callback` objects.

        """
        callbacks = []
        # If the early-stopping flag is set, stop the training when the monitored metric (validation loss) has
        # stopped improving for (patience) epochs.
        if self._early_stop:
            callbacks.append(
                EarlyStopping(monitor="val_loss",
                              min_delta=self._min_delta,
                              patience=self._patience,
                              verbose=True,
                              restore_best_weights=True))
        # Save the model after every epoch.
        if self._save_model_every_epoch:
            callbacks.append(ModelCheckpoint(filepath=f"{self._checkpoint_dir}/VAE_epoch_{{epoch:03}}/model_weights",
                                             monitor="val_loss",
                                             verbose=True,
                                             save_weights_only=True,
                                             mode="min",
                                             save_freq="epoch"))
        # Pass metadata to Wandb (only if a Wandb run was initialized).
        if self._wandb_project_name is not None:
            callbacks.append(WandbCallback(
                monitor="val_loss", verbose=0, mode="auto", save_model=False))
        return callbacks

    def _get_train_and_val_data(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                                geo_cond: np.ndarray, noise: np.ndarray, train_indexes: np.ndarray,
                                validation_indexes: np.ndarray) -> Tuple[Dataset, Dataset]:
        """
        Splits the data into a training and a validation set based on the given lists of indexes.

        """

        # Prepare training data.
        train_dataset = dataset[train_indexes, :]
        train_e_cond = e_cond[train_indexes]
        train_angle_cond = angle_cond[train_indexes]
        train_geo_cond = geo_cond[train_indexes, :]
        train_noise = noise[train_indexes, :]

        # Prepare validation data.
        val_dataset = dataset[validation_indexes, :]
        val_e_cond = e_cond[validation_indexes]
        val_angle_cond = angle_cond[validation_indexes]
        val_geo_cond = geo_cond[validation_indexes, :]
        val_noise = noise[validation_indexes, :]

        # Gather them into tuples.
        train_x = (train_dataset, train_e_cond, train_angle_cond, train_geo_cond, train_noise)
        train_y = train_dataset
        val_x = (val_dataset, val_e_cond, val_angle_cond, val_geo_cond, val_noise)
        val_y = val_dataset

        # Wrap the data in Dataset objects.
        # TODO(@mdragula): This approach requires loading the whole dataset into RAM. It would be better to read the
        #  data lazily, when needed. Also, one should bear in mind that using tf.data.Dataset slows down the training
        #  process.
        train_data = Dataset.from_tensor_slices((train_x, train_y))
        val_data = Dataset.from_tensor_slices((val_x, val_y))

        # The batch size must now be set on the Dataset objects.
        train_data = train_data.batch(self._batch_size)
        val_data = val_data.batch(self._batch_size)

        # Disable AutoShard.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
        train_data = train_data.with_options(options)
        val_data = val_data.with_options(options)

        return train_data, val_data

    def _k_fold_training(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray, geo_cond: np.ndarray,
                         noise: np.ndarray, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs K-fold cross-validation training.

        The number of folds is defined by self._number_of_k_fold_splits. The dataset is always shuffled.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A vector with the energy of each sample. Shape = (number of samples, ).
            angle_cond: A vector with the angle of each sample. Shape = (number of samples, ).
            geo_cond: A matrix with the detector geometry of each sample. Shape = (number of samples, 2).
            noise: A matrix with the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fit function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values and
            metrics values at successive epochs, as well as validation loss values and validation metrics values (if
            applicable).

        """
        # TODO(@mdragula): KFold cross-validation can be parallelized. Each fold is independent of the others.
        k_fold = KFold(n_splits=self._number_of_k_fold_splits, shuffle=True)
        histories = []

        for i, (train_indexes, validation_indexes) in enumerate(k_fold.split(dataset)):
            print(f"K-fold: {i + 1}/{self._number_of_k_fold_splits}...")
            train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                                train_indexes, validation_indexes)

            self._build_and_compile_new_model()

            history = self.model.fit(x=train_data,
                                     shuffle=True,
                                     epochs=self._epochs,
                                     verbose=verbose,
                                     validation_data=val_data,
                                     callbacks=callbacks
                                     )
            histories.append(history)

            if self._save_best_model:
                self.model.save_weights(f"{self._checkpoint_dir}/VAE_fold_{i + 1}/model_weights")
                print(f"Best model from fold {i + 1} was saved.")

            # Remove all unnecessary data from the previous fold.
            del self.model
            del train_data
            del val_data
            tf.keras.backend.clear_session()
            gc.collect()

        return histories

    def _single_training(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray, geo_cond: np.ndarray,
                         noise: np.ndarray, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs a single training.

        A fraction of the dataset (self._validation_split) is used as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A vector with the energy of each sample. Shape = (number of samples, ).
            angle_cond: A vector with the angle of each sample. Shape = (number of samples, ).
            geo_cond: A matrix with the detector geometry of each sample. Shape = (number of samples, 2).
            noise: A matrix with the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fit function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A one-element list of `History` objects. The `History.history` attribute is a record of training loss
            values and metrics values at successive epochs, as well as validation loss values and validation metrics
            values (if applicable).

        """
        dataset_size, _ = dataset.shape
        permutation = np.random.permutation(dataset_size)
        split = int(dataset_size * self._validation_split)
        train_indexes, validation_indexes = permutation[split:], permutation[:split]

        train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                            train_indexes, validation_indexes)

        history = self.model.fit(x=train_data,
                                 shuffle=True,
                                 epochs=self._epochs,
                                 verbose=verbose,
                                 validation_data=val_data,
                                 callbacks=callbacks
                                 )
        if self._save_best_model:
            self.model.save_weights(f"{self._checkpoint_dir}/VAE_best/model_weights")
            print("Best model was saved.")

        return [history]

    def train(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray, geo_cond: np.ndarray,
              verbose: bool = True) -> List[History]:
        """
        Trains and validates the model on the given input data.

        If the number of K-fold splits > 1, it runs K-fold cross validation; otherwise it runs a single training
        which uses (self._validation_split * 100) % of the dataset as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A vector with the energy of each sample. Shape = (number of samples, ).
            angle_cond: A vector with the angle of each sample. Shape = (number of samples, ).
            geo_cond: A matrix with the detector geometry of each sample. Shape = (number of samples, 2).
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values and
            metrics values at successive epochs, as well as validation loss values and validation metrics values (if
            applicable).

        """

        callbacks = self._manufacture_callbacks()

        # Epsilon noise for the reparameterization trick, drawn once per sample from a standard normal distribution.
        noise = np.random.normal(0, 1, size=(dataset.shape[0], self.latent_dim))

        if self._number_of_k_fold_splits > 1:
            return self._k_fold_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
        else:
            return self._single_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
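

# Usage sketch (not part of the original module): a minimal example of how VAEHandler might be driven, assuming the
# defaults from core.constants are sensible and that showers have already been flattened to ORIGINAL_DIM and
# normalised to [0, 1] (the BinaryCrossentropy reconstruction loss suggests values in that range). The inputs below
# are random placeholders; in a real workflow they would come from the preprocessing pipeline.
if __name__ == "__main__":
    num_samples = 1024

    # Placeholder inputs with the shapes documented in VAEHandler.train().
    showers = np.random.uniform(0.0, 1.0, size=(num_samples, ORIGINAL_DIM)).astype(np.float32)
    energies = np.random.uniform(0.0, 1.0, size=(num_samples,)).astype(np.float32)
    angles = np.random.uniform(0.0, 1.0, size=(num_samples,)).astype(np.float32)
    geometries = np.random.randint(0, 2, size=(num_samples, 2)).astype(np.float32)

    # Build the handler; with _wandb_project_name left as None, no Wandb run is created.
    handler = VAEHandler()

    # Runs either K-fold cross validation or a single training, depending on NUMBER_OF_K_FOLD_SPLITS.
    histories = handler.train(showers, energies, angles, geometries, verbose=True)
    for history in histories:
        print(history.history.get("val_loss"))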