Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/src/Par04OnnxInference.cc

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/src/Par04OnnxInference.cc (Version 11.3.0) and /examples/extended/parameterisations/Par04/src/Par04OnnxInference.cc (Version 10.3.p1)


  1 //                                                  1 
  2 // *******************************************    
  3 // * License and Disclaimer                       
  4 // *                                              
  5 // * The  Geant4 software  is  copyright of th    
  6 // * the Geant4 Collaboration.  It is provided    
  7 // * conditions of the Geant4 Software License    
  8 // * LICENSE and available at  http://cern.ch/    
  9 // * include a list of copyright holders.         
 10 // *                                              
 11 // * Neither the authors of this software syst    
 12 // * institutes,nor the agencies providing fin    
 13 // * work  make  any representation or  warran    
 14 // * regarding  this  software system or assum    
 15 // * use.  Please see the license in the file     
 16 // * for the full disclaimer and the limitatio    
 17 // *                                              
 18 // * This  code  implementation is the result     
 19 // * technical work of the GEANT4 collaboratio    
 20 // * By using,  copying,  modifying or  distri    
 21 // * any work based  on the software)  you  ag    
 22 // * use  in  resulting  scientific  publicati    
 23 // * acceptance of all terms of the Geant4 Sof    
 24 // *******************************************    
 25 //                                                
 26 #ifdef USE_INFERENCE_ONNX                         
 27 #  include "Par04OnnxInference.hh"                
 28                                                   
 29 #  include "Par04InferenceInterface.hh"  // fo    
 30                                                   
 31 #  include <algorithm>  // for copy, max          
 32 #  include <cassert>  // for assert               
 33 #  include <cstddef>  // for size_t               
 34 #  include <cstdint>  // for int64_t              
 35 #  include <utility>  // for move                 
 36                                                   
 37 #  include <core/session/onnxruntime_cxx_api.h    
 38 #  ifdef USE_CUDA                                 
 39 #    include "cuda_runtime_api.h"                 
 40 #  endif                                          
 41                                                   
 42 //....oooOO0OOooo........oooOO0OOooo........oo    
 43                                                   
 44 Par04OnnxInference::Par04OnnxInference(G4Strin    
 45                                        G4int i    
 46                                        std::ve    
 47                                        std::ve    
 48                                        G4Strin    
 49                                                   
 50   : Par04InferenceInterface()                     
 51 {                                                 
 52   // initialization of the enviroment and infe    
 53   auto envLocal = std::make_unique<Ort::Env>(O    
 54   fEnv = std::move(envLocal);                     
 55   // Creating a OrtApi Class variable for gett    
 56   // CUDA                                         
 57   const auto& ortApi = Ort::GetApi();             
 58   fSessionOptions.SetIntraOpNumThreads(intraOp    
 59   // graph optimizations of the model             
 60   // if the flag is not set to true none of th    
 61   // if it is set to true all the optimization    
 62   if (optimizeFlag) {                             
 63     fSessionOptions.SetOptimizedModelFilePath(    
 64     fSessionOptions.SetGraphOptimizationLevel(    
 65     // ORT_ENABLE_BASIC #### ORT_ENABLE_EXTEND    
 66   }                                               
 67   else                                            
 68     fSessionOptions.SetGraphOptimizationLevel(    
 69 #  ifdef USE_CUDA                                 
 70   if (cudaFlag) {                                 
 71     OrtCUDAProviderOptionsV2* fCudaOptions = n    
 72     // Initialize the CUDA provider options, f    
 73     // valid CUDA configuration.                  
 74     (void)ortApi.CreateCUDAProviderOptions(&fC    
 75     // Update the CUDA provider options           
 76     (void)ortApi.UpdateCUDAProviderOptions(fCu    
 77                                            cud    
 78     // Append the CUDA execution provider to t    
 79     // use CUDA for execution                     
 80     (void)ortApi.SessionOptionsAppendExecution    
 81   }                                               
 82 #  endif                                          
 83   // save json file for model execution profil    
 84   if (profileFlag) fSessionOptions.EnableProfi    
 85                                                   
 86   auto sessionLocal = std::make_unique<Ort::Se    
 87   fSession = std::move(sessionLocal);             
 88   fInfo = Ort::MemoryInfo::CreateCpu(OrtAlloca    
 89 }                                                 
 90                                                   
 91 //....oooOO0OOooo........oooOO0OOooo........oo    
 92                                                   
 93 void Par04OnnxInference::RunInference(std::vec    
 94                                       std::vec    
 95 {                                                 
 96   // input nodes                                  
 97   Ort::AllocatorWithDefaultOptions allocator;     
 98 #  if ORT_API_VERSION < 13                        
 99   // Before 1.13 we have to roll our own uniqu    
100   auto allocDeleter = [&allocator](char* p) {     
101     allocator.Free(p);                            
102   };                                              
103   using AllocatedStringPtr = std::unique_ptr<c    
104 #  endif                                          
105   std::vector<int64_t> input_node_dims;           
106   size_t num_input_nodes = fSession->GetInputC    
107   std::vector<const char*> input_node_names(nu    
108   for (std::size_t i = 0; i < num_input_nodes;    
109 #  if ORT_API_VERSION < 13                        
110     const auto input_name =                       
111       AllocatedStringPtr(fSession->GetInputNam    
112 #  else                                           
113     const auto input_name = fSession->GetInput    
114 #  endif                                          
115     fInames = {input_name};                       
116     input_node_names[i] = input_name;             
117     Ort::TypeInfo type_info = fSession->GetInp    
118     auto tensor_info = type_info.GetTensorType    
119     input_node_dims = tensor_info.GetShape();     
120     for (std::size_t j = 0; j < input_node_dim    
121       if (input_node_dims[j] < 0) input_node_d    
122     }                                             
123   }                                               
124   // output nodes                                 
125   std::vector<int64_t> output_node_dims;          
126   size_t num_output_nodes = fSession->GetOutpu    
127   std::vector<const char*> output_node_names(n    
128   for (std::size_t i = 0; i < num_output_nodes    
129 #  if ORT_API_VERSION < 13                        
130     const auto output_name =                      
131       AllocatedStringPtr(fSession->GetOutputNa    
132 #  else                                           
133     const auto output_name = fSession->GetOutp    
134 #  endif                                          
135     output_node_names[i] = output_name;           
136     Ort::TypeInfo type_info = fSession->GetOut    
137     auto tensor_info = type_info.GetTensorType    
138     output_node_dims = tensor_info.GetShape();    
139     for (std::size_t j = 0; j < output_node_di    
140       if (output_node_dims[j] < 0) output_node    
141     }                                             
142   }                                               
143                                                   
144   // create input tensor object from data valu    
145   std::vector<int64_t> dims = {1, (unsigned)(a    
146   Ort::Value Input_noise_tensor = Ort::Value::    
147     fInfo, aGenVector.data(), aGenVector.size(    
148   assert(Input_noise_tensor.IsTensor());          
149   std::vector<Ort::Value> ort_inputs;             
150   ort_inputs.push_back(std::move(Input_noise_t    
151   // run the inference session                    
152   std::vector<Ort::Value> ort_outputs =           
153     fSession->Run(Ort::RunOptions{nullptr}, fI    
154                   output_node_names.data(), ou    
155   // get pointer to output tensor float values    
156   float* floatarr = ort_outputs.front().GetTen    
157   aEnergies.assign(aSize, 0);                     
158   for (int i = 0; i < aSize; ++i)                 
159     aEnergies[i] = floatarr[i];                   
160 }                                                 
161                                                   
162 #endif                                            
163