Geant4 Cross Reference |
1 """ 2 ** convert ** 3 defines the conversion function to and ONNX fi 4 """ 5 6 import argparse 7 import sys 8 9 import tf2onnx 10 import numpy as np 11 from onnxruntime import InferenceSession 12 13 from core.constants import GLOBAL_CHECKPOINT_D 14 from core.model import VAEHandler 15 """ 16 epoch: epoch of the saved checkpoint model 17 study-name: study-name for which the model 18 """ 19 20 21 def parse_args(argv): 22 p = argparse.ArgumentParser() 23 p.add_argument("--epoch", type=int, defaul 24 p.add_argument("--study-name", type=str, d 25 args = p.parse_args() 26 return args 27 28 29 # main function 30 def main(argv): 31 # 1. Set up the model to convert 32 # Parse commandline arguments 33 args = parse_args(argv) 34 epoch = args.epoch 35 study_name = args.study_name 36 37 # Instantiate and load a saved model 38 vae = VAEHandler() 39 40 # Load the saved weights 41 weights_dir = f"VAE_epoch_{epoch:03}" if e 42 vae.model.load_weights( 43 f"{GLOBAL_CHECKPOINT_DIR}/{study_name} 44 ).expect_partial() 45 46 # 2. Convert the model to ONNX format 47 # Create the Keras model, convert it into 48 keras_model = vae.model.decoder 49 output_path = f"{CONV_DIR}/{study_name}/Ge 50 onnx_model = tf2onnx.convert.from_keras(ke 51 ou 52 53 # Checking the converted model 54 input_1 = np.random.randn(10).astype(np.fl 55 input_2 = np.random.randn(1).astype(np.flo 56 input_3 = np.random.randn(1).astype(np.flo 57 input_4 = np.random.randn(2).astype(np.flo 58 59 sess = InferenceSession(output_path) 60 # TODO: @Piyush-555 Find a way to use pred 61 result = sess.run( 62 None, { 63 'input_9': input_1, 64 'input_6': input_2, 65 'input_7': input_3, 66 'input_8': input_4 67 }) 68 assert result[0].shape[1] == ORIGINAL_DIM 69 70 71 if __name__ == "__main__": 72 exit(main(sys.argv[1:]))