Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/training/convert.py

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/training/convert.py (Version 11.3.0) and /examples/extended/parameterisations/Par04/training/convert.py (Version 10.0)


  1 """                                               
  2 ** convert **                                     
  3 defines the conversion function to and ONNX fi    
  4 """                                               
  5                                                   
  6 import argparse                                   
  7 import sys                                        
  8                                                   
  9 import tf2onnx                                    
 10 import numpy as np                                
 11 from onnxruntime import InferenceSession          
 12                                                   
 13 from core.constants import GLOBAL_CHECKPOINT_D    
 14 from core.model import VAEHandler                 
 15 """                                               
 16     epoch: epoch of the saved checkpoint model    
 17     study-name: study-name for which the model    
 18 """                                               
 19                                                   
 20                                                   
 21 def parse_args(argv):                             
 22     p = argparse.ArgumentParser()                 
 23     p.add_argument("--epoch", type=int, defaul    
 24     p.add_argument("--study-name", type=str, d    
 25     args = p.parse_args()                         
 26     return args                                   
 27                                                   
 28                                                   
 29 # main function                                   
 30 def main(argv):                                   
 31     # 1. Set up the model to convert              
 32     # Parse commandline arguments                 
 33     args = parse_args(argv)                       
 34     epoch = args.epoch                            
 35     study_name = args.study_name                  
 36                                                   
 37     # Instantiate and load a saved model          
 38     vae = VAEHandler()                            
 39                                                   
 40     # Load the saved weights                      
 41     weights_dir = f"VAE_epoch_{epoch:03}" if e    
 42     vae.model.load_weights(                       
 43         f"{GLOBAL_CHECKPOINT_DIR}/{study_name}    
 44     ).expect_partial()                            
 45                                                   
 46     # 2. Convert the model to ONNX format         
 47     # Create the Keras model, convert it into     
 48     keras_model = vae.model.decoder               
 49     output_path = f"{CONV_DIR}/{study_name}/Ge    
 50     onnx_model = tf2onnx.convert.from_keras(ke    
 51                                             ou    
 52                                                   
 53     # Checking the converted model                
 54     input_1 = np.random.randn(10).astype(np.fl    
 55     input_2 = np.random.randn(1).astype(np.flo    
 56     input_3 = np.random.randn(1).astype(np.flo    
 57     input_4 = np.random.randn(2).astype(np.flo    
 58                                                   
 59     sess = InferenceSession(output_path)          
 60     # TODO: @Piyush-555 Find a way to use pred    
 61     result = sess.run(                            
 62         None, {                                   
 63             'input_9': input_1,                   
 64             'input_6': input_2,                   
 65             'input_7': input_3,                   
 66             'input_8': input_4                    
 67         })                                        
 68     assert result[0].shape[1] == ORIGINAL_DIM     
 69                                                   
 70                                                   
 71 if __name__ == "__main__":                        
 72     exit(main(sys.argv[1:]))