Geant4 Cross Reference |
1 """ 1 """ 2 ** convert ** 2 ** convert ** 3 defines the conversion function to and ONNX fi 3 defines the conversion function to and ONNX file 4 """ 4 """ 5 5 6 import argparse 6 import argparse 7 import sys 7 import sys 8 8 9 import tf2onnx 9 import tf2onnx 10 import numpy as np 10 import numpy as np 11 from onnxruntime import InferenceSession 11 from onnxruntime import InferenceSession 12 12 13 from core.constants import GLOBAL_CHECKPOINT_D 13 from core.constants import GLOBAL_CHECKPOINT_DIR, CONV_DIR, ORIGINAL_DIM 14 from core.model import VAEHandler 14 from core.model import VAEHandler 15 """ 15 """ 16 epoch: epoch of the saved checkpoint model 16 epoch: epoch of the saved checkpoint model 17 study-name: study-name for which the model 17 study-name: study-name for which the model is trained for 18 """ 18 """ 19 19 20 20 21 def parse_args(argv): 21 def parse_args(argv): 22 p = argparse.ArgumentParser() 22 p = argparse.ArgumentParser() 23 p.add_argument("--epoch", type=int, defaul 23 p.add_argument("--epoch", type=int, default=None) 24 p.add_argument("--study-name", type=str, d 24 p.add_argument("--study-name", type=str, default="default_study_name") 25 args = p.parse_args() 25 args = p.parse_args() 26 return args 26 return args 27 27 28 28 29 # main function 29 # main function 30 def main(argv): 30 def main(argv): 31 # 1. Set up the model to convert 31 # 1. Set up the model to convert 32 # Parse commandline arguments 32 # Parse commandline arguments 33 args = parse_args(argv) 33 args = parse_args(argv) 34 epoch = args.epoch 34 epoch = args.epoch 35 study_name = args.study_name 35 study_name = args.study_name 36 36 37 # Instantiate and load a saved model 37 # Instantiate and load a saved model 38 vae = VAEHandler() 38 vae = VAEHandler() 39 39 40 # Load the saved weights 40 # Load the saved weights 41 weights_dir = f"VAE_epoch_{epoch:03}" if e 41 weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best" 42 vae.model.load_weights( 42 vae.model.load_weights( 43 f"{GLOBAL_CHECKPOINT_DIR}/{study_name} 43 f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights" 44 ).expect_partial() 44 ).expect_partial() 45 45 46 # 2. Convert the model to ONNX format 46 # 2. Convert the model to ONNX format 47 # Create the Keras model, convert it into 47 # Create the Keras model, convert it into an ONNX model, and save. 48 keras_model = vae.model.decoder 48 keras_model = vae.model.decoder 49 output_path = f"{CONV_DIR}/{study_name}/Ge 49 output_path = f"{CONV_DIR}/{study_name}/Generator_{weights_dir}.onnx" 50 onnx_model = tf2onnx.convert.from_keras(ke 50 onnx_model = tf2onnx.convert.from_keras(keras_model, 51 ou 51 output_path=output_path) 52 52 53 # Checking the converted model 53 # Checking the converted model 54 input_1 = np.random.randn(10).astype(np.fl 54 input_1 = np.random.randn(10).astype(np.float32).reshape(1, -1) 55 input_2 = np.random.randn(1).astype(np.flo 55 input_2 = np.random.randn(1).astype(np.float32).reshape(1, -1) 56 input_3 = np.random.randn(1).astype(np.flo 56 input_3 = np.random.randn(1).astype(np.float32).reshape(1, -1) 57 input_4 = np.random.randn(2).astype(np.flo 57 input_4 = np.random.randn(2).astype(np.float32).reshape(1, -1) 58 58 59 sess = InferenceSession(output_path) 59 sess = InferenceSession(output_path) 60 # TODO: @Piyush-555 Find a way to use pred 60 # TODO: @Piyush-555 Find a way to use predefined names 61 result = sess.run( 61 result = sess.run( 62 None, { 62 None, { 63 'input_9': input_1, 63 'input_9': input_1, 64 'input_6': input_2, 64 'input_6': input_2, 65 'input_7': input_3, 65 'input_7': input_3, 66 'input_8': input_4 66 'input_8': input_4 67 }) 67 }) 68 assert result[0].shape[1] == ORIGINAL_DIM 68 assert result[0].shape[1] == ORIGINAL_DIM 69 69 70 70 71 if __name__ == "__main__": 71 if __name__ == "__main__": 72 exit(main(sys.argv[1:])) 72 exit(main(sys.argv[1:]))