Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/convert.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

  1 """
  2 ** convert **
  3 defines the conversion function to an ONNX file
  4 """
  5 
  6 import argparse
  7 import sys
  8 
  9 import tf2onnx
 10 import numpy as np
 11 from onnxruntime import InferenceSession
 12 
 13 from core.constants import GLOBAL_CHECKPOINT_DIR, CONV_DIR, ORIGINAL_DIM
 14 from core.model import VAEHandler
 15 """
 16     epoch: epoch of the saved checkpoint model
 17     study-name: name of the study for which the model was trained
 18 """
 19 
 20 
 21 def parse_args(argv):
 22     p = argparse.ArgumentParser()
 23     p.add_argument("--epoch", type=int, default=None)
 24     p.add_argument("--study-name", type=str, default="default_study_name")
 25     args = p.parse_args()
 26     return args
 27 
 28 
 29 # main function
 30 def main(argv):
 31     # 1. Set up the model to convert
 32     # Parse commandline arguments
 33     args = parse_args(argv)
 34     epoch = args.epoch
 35     study_name = args.study_name
 36 
 37     # Instantiate and load a saved model
 38     vae = VAEHandler()
 39 
 40     # Load the saved weights
 41     weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
 42     vae.model.load_weights(
 43         f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights"
 44     ).expect_partial()
 45 
 46     # 2. Convert the model to ONNX format
 47     # Create the Keras model, convert it into an ONNX model, and save.
 48     keras_model = vae.model.decoder
 49     output_path = f"{CONV_DIR}/{study_name}/Generator_{weights_dir}.onnx"
 50     onnx_model = tf2onnx.convert.from_keras(keras_model,
 51                                             output_path=output_path)
 52 
 53     # Checking the converted model
 54     input_1 = np.random.randn(10).astype(np.float32).reshape(1, -1)
 55     input_2 = np.random.randn(1).astype(np.float32).reshape(1, -1)
 56     input_3 = np.random.randn(1).astype(np.float32).reshape(1, -1)
 57     input_4 = np.random.randn(2).astype(np.float32).reshape(1, -1)
 58 
 59     sess = InferenceSession(output_path)
 60     # TODO: @Piyush-555 Find a way to use predefined names
 61     result = sess.run(
 62         None, {
 63             'input_9': input_1,
 64             'input_6': input_2,
 65             'input_7': input_3,
 66             'input_8': input_4
 67         })
 68     assert result[0].shape[1] == ORIGINAL_DIM
 69 
 70 
 71 if __name__ == "__main__":
 72     exit(main(sys.argv[1:]))