Geant4/examples/extended/parameterisations/Par04/training/core/model.py


import gc
from dataclasses import dataclass, field
from typing import List, Tuple

import numpy as np
import tensorflow as tf
import wandb
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, History, Callback
from tensorflow.keras.layers import BatchNormalization, Input, Dense, Layer, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.keras.models import Model
from tensorflow.python.data import Dataset
from tensorflow.python.distribute.distribute_lib import Strategy
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from wandb.keras import WandbCallback

from core.constants import ORIGINAL_DIM, LATENT_DIM, BATCH_SIZE_PER_REPLICA, EPOCHS, LEARNING_RATE, ACTIVATION, \
    OUT_ACTIVATION, OPTIMIZER_TYPE, KERNEL_INITIALIZER, GLOBAL_CHECKPOINT_DIR, EARLY_STOP, BIAS_INITIALIZER, \
    INTERMEDIATE_DIMS, SAVE_MODEL_EVERY_EPOCH, SAVE_BEST_MODEL, PATIENCE, MIN_DELTA, BEST_MODEL_FILENAME, \
    NUMBER_OF_K_FOLD_SPLITS, VALIDATION_SPLIT, WANDB_ENTITY
from utils.optimizer import OptimizerFactory, OptimizerType


class _Sampling(Layer):
    """ Custom layer to do the reparameterization trick: sample random latent vectors z from the latent Gaussian
    distribution.

    The sampled vector z is given by sampled_z = mean + std * epsilon
    """

    def __call__(self, inputs, **kwargs):
        z_mean, z_log_var, epsilon = inputs
        z_sigma = K.exp(0.5 * z_log_var)
        return z_mean + z_sigma * epsilon

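# Editorial sketch (not part of the original file) of how _Sampling realizes the reparameterization trick;
# the shapes below are assumptions made only for illustration:
#
#     z_mean = tf.zeros((1, LATENT_DIM))
#     z_log_var = tf.zeros((1, LATENT_DIM))
#     epsilon = tf.random.normal((1, LATENT_DIM))
#     z = _Sampling()([z_mean, z_log_var, epsilon])  # equals z_mean + exp(0.5 * z_log_var) * epsilon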

# KL divergence computation
class _KLDivergenceLayer(Layer):
    """ Identity layer which computes the Kullback-Leibler divergence between the encoded Gaussian N(mu, sigma)
    and the standard normal N(0, 1), and registers it as an additional loss via add_loss().
    """

    def call(self, inputs, **kwargs):
        mu, log_var = inputs
        kl_loss = -0.5 * (1 + log_var - K.square(mu) - K.exp(log_var))
        kl_loss = K.mean(K.sum(kl_loss, axis=-1))
        self.add_loss(kl_loss)
        return inputs

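# For reference (editorial note, not in the original source): the kl_loss above is the closed-form divergence
#     KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum_i (1 + log(sigma_i^2) - mu_i^2 - sigma_i^2),
# summed over the latent dimensions and averaged over the batch.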

class VAE(Model):
    """ Conditional variational autoencoder assembled from an encoder and a decoder sub-model. """

    def get_config(self):
        config = super().get_config()
        config["encoder"] = self.encoder
        config["decoder"] = self.decoder
        return config

    def call(self, inputs, training=None, mask=None):
        _, e_input, angle_input, geo_input, _ = inputs
        z = self.encoder(inputs)
        return self.decoder([z, e_input, angle_input, geo_input])

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self._set_inputs(inputs=self.encoder.inputs, outputs=self(self.encoder.inputs))

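# Illustrative usage sketch (editorial, not part of the module): with the encoder and decoder built by VAEHandler
# below, a full forward pass reconstructs a batch of showers from their conditions and epsilon noise:
#
#     vae = VAE(encoder, decoder)
#     reconstructed = vae([showers, energies, angles, geometries, epsilon_noise])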

@dataclass
class VAEHandler:
    """
    Class to handle building and training VAE models.
    """
    _wandb_project_name: str = None
    _wandb_tags: List[str] = field(default_factory=list)
    _original_dim: int = ORIGINAL_DIM
    latent_dim: int = LATENT_DIM
    _batch_size_per_replica: int = BATCH_SIZE_PER_REPLICA
    _intermediate_dims: List[int] = field(default_factory=lambda: INTERMEDIATE_DIMS)
    _learning_rate: float = LEARNING_RATE
    _epochs: int = EPOCHS
    _activation: str = ACTIVATION
    _out_activation: str = OUT_ACTIVATION
    _number_of_k_fold_splits: int = NUMBER_OF_K_FOLD_SPLITS
    _optimizer_type: OptimizerType = OPTIMIZER_TYPE
    _kernel_initializer: str = KERNEL_INITIALIZER
    _bias_initializer: str = BIAS_INITIALIZER
    _checkpoint_dir: str = GLOBAL_CHECKPOINT_DIR
    _early_stop: bool = EARLY_STOP
    _save_model_every_epoch: bool = SAVE_MODEL_EVERY_EPOCH
    _save_best_model: bool = SAVE_BEST_MODEL
    _patience: int = PATIENCE
    _min_delta: float = MIN_DELTA
    _best_model_filename: str = BEST_MODEL_FILENAME
    _validation_split: float = VALIDATION_SPLIT
    _strategy: Strategy = MirroredStrategy()

    def __post_init__(self) -> None:
        # Calculate true batch size.
        self._batch_size = self._batch_size_per_replica * self._strategy.num_replicas_in_sync
        self._build_and_compile_new_model()
        # Setup Wandb.
        if self._wandb_project_name is not None:
            self._setup_wandb()

    def _setup_wandb(self) -> None:
        config = {
            "learning_rate": self._learning_rate,
            "batch_size": self._batch_size,
            "epochs": self._epochs,
            "optimizer_type": self._optimizer_type,
            "intermediate_dims": self._intermediate_dims,
            "latent_dim": self.latent_dim
        }
        # The reinit flag is needed for hyperparameter tuning: whenever a new training is started, a new Wandb run
        # should be created.
        wandb.init(project=self._wandb_project_name, entity=WANDB_ENTITY, reinit=True, config=config,
                   tags=self._wandb_tags)

    def _build_and_compile_new_model(self) -> None:
        """ Builds and compiles a new model.

        While k-fold cross validation is performed, each fold requires a new, clean instance of the model. Every call
        of this method therefore replaces self.model with a freshly built and compiled VAE.

        Returns: None

        """
        # Build the encoder and the decoder.
        encoder = self._build_encoder()
        decoder = self._build_decoder()

        # Compile the model within a distributed strategy.
        with self._strategy.scope():
            # Build the VAE.
            self.model = VAE(encoder, decoder)
            # Manufacture an optimizer and compile the model with it.
            optimizer = OptimizerFactory.create_optimizer(self._optimizer_type, self._learning_rate)
            reconstruction_loss = BinaryCrossentropy(reduction=Reduction.SUM)
            self.model.compile(optimizer=optimizer, loss=[reconstruction_loss], loss_weights=[ORIGINAL_DIM])
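            # Editorial note (not in the original source): with this configuration the optimized objective is,
            # up to reduction details across replicas, ORIGINAL_DIM times the sum-reduced binary cross-entropy
            # between the input and the reconstructed shower, plus the KL term injected by _KLDivergenceLayer
            # via add_loss().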

    def _prepare_input_layers(self, for_encoder: bool) -> List[Input]:
        """
        Creates the Input layers. They take, respectively: a batch of showers (for the encoder) or a batch of latent
        vectors (for the decoder), a batch of energies, a batch of angles, a batch of geometries and, for the encoder
        only, a batch of epsilon noise vectors used by the reparameterization trick.

        Args:
            for_encoder: Boolean which decides whether the first input is a full dimensional shower or a latent vector.

        Returns:
            List of Input layers (five for the encoder and four for the decoder).

        """
        e_input = Input(shape=(1,))
        angle_input = Input(shape=(1,))
        geo_input = Input(shape=(2,))
        if for_encoder:
            x_input = Input(shape=self._original_dim)
            eps_input = Input(shape=self.latent_dim)
            return [x_input, e_input, angle_input, geo_input, eps_input]
        else:
            x_input = Input(shape=self.latent_dim)
            return [x_input, e_input, angle_input, geo_input]

    def _build_encoder(self) -> Model:
        """ Builds the encoder based on the list of intermediate dimensions, the activation function and the
        initializers for kernel and bias.

        Returns:
             The encoder as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layers.
            x_input, e_input, angle_input, geo_input, eps_input = self._prepare_input_layers(for_encoder=True)
            x = concatenate([x_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in self._intermediate_dims:
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add Dense layers describing the multidimensional Gaussian distribution in terms of its mean
            # and log(variance).
            z_mean = Dense(self.latent_dim, name="z_mean")(x)
            z_log_var = Dense(self.latent_dim, name="z_log_var")(x)
            # Add the KLDivergenceLayer responsible for the calculation of the KL loss.
            z_mean, z_log_var = _KLDivergenceLayer()([z_mean, z_log_var])
            # Sample a latent vector z from the distribution.
            encoder_output = _Sampling()([z_mean, z_log_var, eps_input])
            # Create the model.
            encoder = Model(inputs=[x_input, e_input, angle_input, geo_input, eps_input], outputs=encoder_output,
                            name="encoder")
        return encoder

    def _build_decoder(self) -> Model:
        """ Builds the decoder based on the list of intermediate dimensions, the activation function and the
        initializers for kernel and bias.

        Returns:
             The decoder as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layers.
            latent_input, e_input, angle_input, geo_input = self._prepare_input_layers(for_encoder=False)
            x = concatenate([latent_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in reversed(self._intermediate_dims):
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add a Dense layer whose output shape matches the shape of the original input.
            decoder_outputs = Dense(units=self._original_dim, activation=self._out_activation)(x)
            # Create the model.
            decoder = Model(inputs=[latent_input, e_input, angle_input, geo_input], outputs=decoder_outputs,
                            name="decoder")
        return decoder

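    # Editorial sketch (not part of the original source): at generation time only the decoder is typically needed;
    # latent vectors are drawn from a standard normal and combined with the conditions, e.g.:
    #
    #     z = np.random.normal(0, 1, size=(n_events, LATENT_DIM))
    #     showers = decoder.predict([z, energies, angles, geometries])
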
    def _manufacture_callbacks(self) -> List[Callback]:
        """
        Based on parameters set by the user, manufactures callbacks required for training.

        Returns:
            A list of `Callback` objects.

        """
        callbacks = []
        # If the early stopping flag is on, then stop the training when the monitored metric (validation loss) has
        # stopped improving for (patience) number of epochs.
        if self._early_stop:
            callbacks.append(
                EarlyStopping(monitor="val_loss",
                              min_delta=self._min_delta,
                              patience=self._patience,
                              verbose=True,
                              restore_best_weights=True))
        # Save the model after every epoch.
        if self._save_model_every_epoch:
            callbacks.append(ModelCheckpoint(filepath=f"{self._checkpoint_dir}/VAE_epoch_{{epoch:03}}/model_weights",
                                             monitor="val_loss",
                                             verbose=True,
                                             save_weights_only=True,
                                             mode="min",
                                             save_freq="epoch"))
        # Pass metadata to Wandb, but only when a Wandb run has been configured (WandbCallback requires a prior
        # wandb.init() call).
        if self._wandb_project_name is not None:
            callbacks.append(WandbCallback(
                monitor="val_loss", verbose=0, mode="auto", save_model=False))
        return callbacks

    def _get_train_and_val_data(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                                noise: np.array, train_indexes: np.array, validation_indexes: np.array) \
            -> Tuple[Dataset, Dataset]:
        """
        Splits the data into a training and a validation set based on the given lists of indexes.

        """

        # Prepare training data.
        train_dataset = dataset[train_indexes, :]
        train_e_cond = e_cond[train_indexes]
        train_angle_cond = angle_cond[train_indexes]
        train_geo_cond = geo_cond[train_indexes, :]
        train_noise = noise[train_indexes, :]

        # Prepare validation data.
        val_dataset = dataset[validation_indexes, :]
        val_e_cond = e_cond[validation_indexes]
        val_angle_cond = angle_cond[validation_indexes]
        val_geo_cond = geo_cond[validation_indexes, :]
        val_noise = noise[validation_indexes, :]

        # Gather them into tuples.
        train_x = (train_dataset, train_e_cond, train_angle_cond, train_geo_cond, train_noise)
        train_y = train_dataset
        val_x = (val_dataset, val_e_cond, val_angle_cond, val_geo_cond, val_noise)
        val_y = val_dataset

        # Wrap data in Dataset objects.
        # TODO(@mdragula): This approach requires loading the whole dataset into RAM. It
        #  would be better to read the data partially when needed. One should also bear in mind that using tf.Dataset
        #  slows down the training process.
        train_data = Dataset.from_tensor_slices((train_x, train_y))
        val_data = Dataset.from_tensor_slices((val_x, val_y))

        # The batch size must now be set on the Dataset objects.
        train_data = train_data.batch(self._batch_size)
        val_data = val_data.batch(self._batch_size)

        # Disable AutoShard.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
        train_data = train_data.with_options(options)
        val_data = val_data.with_options(options)

        return train_data, val_data

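    # For illustration (editorial note, not in the original source): each element of the datasets returned above is
    # a batched pair of the form
    #     ((showers, e_cond, angle_cond, geo_cond, noise), showers)
    # i.e. the shower acts both as the conditioned input and as the reconstruction target.
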
    def _k_fold_training(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                         noise: np.array, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs K-fold cross validation training.

        The number of folds is defined by (self._number_of_k_fold_splits). The dataset is always shuffled.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values and
        metrics values at successive epochs, as well as validation loss values and validation metrics values (if
        applicable).

        """
        # TODO(@mdragula): KFold cross validation can be parallelized. Each fold is independent of the others.
        k_fold = KFold(n_splits=self._number_of_k_fold_splits, shuffle=True)
        histories = []

        for i, (train_indexes, validation_indexes) in enumerate(k_fold.split(dataset)):
            print(f"K-fold: {i + 1}/{self._number_of_k_fold_splits}...")
            train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                                train_indexes, validation_indexes)

            self._build_and_compile_new_model()

            history = self.model.fit(x=train_data,
                                     shuffle=True,
                                     epochs=self._epochs,
                                     verbose=verbose,
                                     validation_data=val_data,
                                     callbacks=callbacks
                                     )
            histories.append(history)

            if self._save_best_model:
                self.model.save_weights(f"{self._checkpoint_dir}/VAE_fold_{i + 1}/model_weights")
                print(f"Best model from fold {i + 1} was saved.")

            # Remove all unnecessary data from the previous fold.
            del self.model
            del train_data
            del val_data
            tf.keras.backend.clear_session()
            gc.collect()

        return histories

    def _single_training(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                         noise: np.ndarray, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs a single training.

        A fraction of the dataset (self._validation_split) is used as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A one-element list of `History` objects. The `History.history` attribute is a record of training loss
        values and metrics values at successive epochs, as well as validation loss values and validation metrics
        values (if applicable).

        """
        dataset_size, _ = dataset.shape
        permutation = np.random.permutation(dataset_size)
        split = int(dataset_size * self._validation_split)
        train_indexes, validation_indexes = permutation[split:], permutation[:split]

        train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise, train_indexes,
                                                            validation_indexes)

        history = self.model.fit(x=train_data,
                                 shuffle=True,
                                 epochs=self._epochs,
                                 verbose=verbose,
                                 validation_data=val_data,
                                 callbacks=callbacks
                                 )
        if self._save_best_model:
            self.model.save_weights(f"{self._checkpoint_dir}/VAE_best/model_weights")
            print("Best model was saved.")

        return [history]

    def train(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
              verbose: bool = True) -> List[History]:
        """
        Trains and validates the model on the given input data.

        If the number of K-fold splits is > 1, then K-fold cross validation is run; otherwise a single training is
        performed which uses (self._validation_split * 100) % of the dataset as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values and
        metrics values at successive epochs, as well as validation loss values and validation metrics values (if
        applicable).

        """

        callbacks = self._manufacture_callbacks()

        noise = np.random.normal(0, 1, size=(dataset.shape[0], self.latent_dim))

        if self._number_of_k_fold_splits > 1:
            return self._k_fold_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
        else:
            return self._single_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
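
# Minimal usage sketch (editorial, not part of the original file); the array names and shapes below are assumptions
# made only for illustration:
#
#     handler = VAEHandler(_wandb_project_name=None)
#     # dataset: (n_samples, ORIGINAL_DIM), e_cond and angle_cond: (n_samples,), geo_cond: (n_samples, 2)
#     histories = handler.train(dataset, e_cond, angle_cond, geo_cond)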