Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/convert.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/convert.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/convert.py (Version 11.2.1)


  1 """                                                 1 """
  2 ** convert **                                       2 ** convert **
  3 defines the conversion function to and ONNX fi      3 defines the conversion function to and ONNX file
  4 """                                                 4 """
  5                                                     5 
  6 import argparse                                     6 import argparse
  7 import sys                                          7 import sys
  8                                                     8 
  9 import tf2onnx                                      9 import tf2onnx
 10 import numpy as np                                 10 import numpy as np
 11 from onnxruntime import InferenceSession           11 from onnxruntime import InferenceSession
 12                                                    12 
 13 from core.constants import GLOBAL_CHECKPOINT_D     13 from core.constants import GLOBAL_CHECKPOINT_DIR, CONV_DIR, ORIGINAL_DIM
 14 from core.model import VAEHandler                  14 from core.model import VAEHandler
 15 """                                                15 """
 16     epoch: epoch of the saved checkpoint model     16     epoch: epoch of the saved checkpoint model
 17     study-name: study-name for which the model     17     study-name: study-name for which the model is trained for
 18 """                                                18 """
 19                                                    19 
 20                                                    20 
 21 def parse_args(argv):                              21 def parse_args(argv):
 22     p = argparse.ArgumentParser()                  22     p = argparse.ArgumentParser()
 23     p.add_argument("--epoch", type=int, defaul     23     p.add_argument("--epoch", type=int, default=None)
 24     p.add_argument("--study-name", type=str, d     24     p.add_argument("--study-name", type=str, default="default_study_name")
 25     args = p.parse_args()                          25     args = p.parse_args()
 26     return args                                    26     return args
 27                                                    27 
 28                                                    28 
 29 # main function                                    29 # main function
 30 def main(argv):                                    30 def main(argv):
 31     # 1. Set up the model to convert               31     # 1. Set up the model to convert
 32     # Parse commandline arguments                  32     # Parse commandline arguments
 33     args = parse_args(argv)                        33     args = parse_args(argv)
 34     epoch = args.epoch                             34     epoch = args.epoch
 35     study_name = args.study_name                   35     study_name = args.study_name
 36                                                    36 
 37     # Instantiate and load a saved model           37     # Instantiate and load a saved model
 38     vae = VAEHandler()                             38     vae = VAEHandler()
 39                                                    39 
 40     # Load the saved weights                       40     # Load the saved weights
 41     weights_dir = f"VAE_epoch_{epoch:03}" if e     41     weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
 42     vae.model.load_weights(                        42     vae.model.load_weights(
 43         f"{GLOBAL_CHECKPOINT_DIR}/{study_name}     43         f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights"
 44     ).expect_partial()                             44     ).expect_partial()
 45                                                    45 
 46     # 2. Convert the model to ONNX format          46     # 2. Convert the model to ONNX format
 47     # Create the Keras model, convert it into      47     # Create the Keras model, convert it into an ONNX model, and save.
 48     keras_model = vae.model.decoder                48     keras_model = vae.model.decoder
 49     output_path = f"{CONV_DIR}/{study_name}/Ge     49     output_path = f"{CONV_DIR}/{study_name}/Generator_{weights_dir}.onnx"
 50     onnx_model = tf2onnx.convert.from_keras(ke     50     onnx_model = tf2onnx.convert.from_keras(keras_model,
 51                                             ou     51                                             output_path=output_path)
 52                                                    52 
 53     # Checking the converted model                 53     # Checking the converted model
 54     input_1 = np.random.randn(10).astype(np.fl     54     input_1 = np.random.randn(10).astype(np.float32).reshape(1, -1)
 55     input_2 = np.random.randn(1).astype(np.flo     55     input_2 = np.random.randn(1).astype(np.float32).reshape(1, -1)
 56     input_3 = np.random.randn(1).astype(np.flo     56     input_3 = np.random.randn(1).astype(np.float32).reshape(1, -1)
 57     input_4 = np.random.randn(2).astype(np.flo     57     input_4 = np.random.randn(2).astype(np.float32).reshape(1, -1)
 58                                                    58 
 59     sess = InferenceSession(output_path)           59     sess = InferenceSession(output_path)
 60     # TODO: @Piyush-555 Find a way to use pred     60     # TODO: @Piyush-555 Find a way to use predefined names
 61     result = sess.run(                             61     result = sess.run(
 62         None, {                                    62         None, {
 63             'input_9': input_1,                    63             'input_9': input_1,
 64             'input_6': input_2,                    64             'input_6': input_2,
 65             'input_7': input_3,                    65             'input_7': input_3,
 66             'input_8': input_4                     66             'input_8': input_4
 67         })                                         67         })
 68     assert result[0].shape[1] == ORIGINAL_DIM      68     assert result[0].shape[1] == ORIGINAL_DIM
 69                                                    69 
 70                                                    70 
 71 if __name__ == "__main__":                         71 if __name__ == "__main__":
 72     exit(main(sys.argv[1:]))                       72     exit(main(sys.argv[1:]))