Geant4/examples/extended/parameterisations/Par04/src/Par04OnnxInference.cc

//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The  Geant4 software  is  copyright of the Copyright Holders  of *
// * the Geant4 Collaboration.  It is provided  under  the terms  and *
// * conditions of the Geant4 Software License,  included in the file *
// * LICENSE and available at  http://cern.ch/geant4/license .  These *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work  make  any representation or  warranty, express or implied, *
// * regarding  this  software system or assume any liability for its *
// * use.  Please see the license in the file  LICENSE  and URL above *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This  code  implementation is the result of  the  scientific and *
// * technical work of the GEANT4 collaboration.                      *
// * By using,  copying,  modifying or  distributing the software (or *
// * any work based  on the software)  you  agree  to acknowledge its *
// * use  in  resulting  scientific  publications,  and indicate your *
// * acceptance of all terms of the Geant4 Software license.          *
// ********************************************************************
//
#ifdef USE_INFERENCE_ONNX
#  include "Par04OnnxInference.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface

#  include <algorithm>  // for copy, max
#  include <cassert>  // for assert
#  include <cstddef>  // for size_t
#  include <cstdint>  // for int64_t
#  include <memory>  // for unique_ptr, make_unique
#  include <utility>  // for move

#  include <core/session/onnxruntime_cxx_api.h>  // for Value, Session, Env
#  ifdef USE_CUDA
#    include "cuda_runtime_api.h"
#  endif
//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
                                       G4int intraOpNumThreads, G4int cudaFlag,
                                       std::vector<const char*>& cuda_keys,
                                       std::vector<const char*>& cuda_values,
                                       G4String ModelSavePath, G4String profilingOutputSavePath)
  : Par04InferenceInterface()
{
  // initialization of the environment and inference session
  auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
  fEnv = std::move(envLocal);
  // get a reference to the OrtApi, giving access to the C API that is needed
  // for the CUDA provider options
  const auto& ortApi = Ort::GetApi();
  fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
  // graph optimizations of the model:
  // if the flag is not set, none of the optimizations is applied;
  // if it is set, all of them are applied
  if (optimizeFlag) {
    fSessionOptions.SetOptimizedModelFilePath("opt-graph");
    fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    // intermediate levels also exist: ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED
  }
  else
    fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
#  ifdef USE_CUDA
  if (cudaFlag) {
    OrtCUDAProviderOptionsV2* fCudaOptions = nullptr;
    // Initialize the CUDA provider options; fCudaOptions should now point to a
    // valid CUDA configuration.
    (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
    // Update the CUDA provider options
    (void)ortApi.UpdateCUDAProviderOptions(fCudaOptions, cuda_keys.data(), cuda_values.data(),
                                           cuda_keys.size());
    // Append the CUDA execution provider to the session options, indicating to
    // use CUDA for execution
    (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, fCudaOptions);
  }
#  endif
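  // For reference (not from the original file): UpdateCUDAProviderOptions takes
  // parallel key/value C-string arrays; documented option keys include
  // "device_id", "gpu_mem_limit" and "arena_extend_strategy", e.g.
  //   std::vector<const char*> keys = {"device_id"}, values = {"0"};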
  // save a JSON file for model execution profiling
  // (the hard-coded name "opt.json" is used rather than profilingOutputSavePath)
  if (profileFlag) fSessionOptions.EnableProfiling("opt.json");

  auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
  fSession = std::move(sessionLocal);
  fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
}
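
// Construction sketch (illustrative only; in the Par04 example these arguments
// are normally supplied through UI macro commands, and the values below are
// hypothetical):
//   std::vector<const char*> keys = {"device_id"};
//   std::vector<const char*> values = {"0"};
//   Par04OnnxInference inference("model.onnx", /*profileFlag=*/0,
//                                /*optimizeFlag=*/1, /*intraOpNumThreads=*/1,
//                                /*cudaFlag=*/0, keys, values, "", "");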

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
                                      std::vector<G4double>& aEnergies, int aSize)
{
  // input nodes
  Ort::AllocatorWithDefaultOptions allocator;
#  if ORT_API_VERSION < 13
  // Before 1.13 we have to roll our own unique_ptr wrapper here
  auto allocDeleter = [&allocator](char* p) {
    allocator.Free(p);
  };
  using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#  endif
  std::vector<int64_t> input_node_dims;
  size_t num_input_nodes = fSession->GetInputCount();
  std::vector<const char*> input_node_names(num_input_nodes);
  for (std::size_t i = 0; i < num_input_nodes; i++) {
#  if ORT_API_VERSION < 13
    const auto input_name =
      AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter).release();
#  else
    const auto input_name = fSession->GetInputNameAllocated(i, allocator).release();
#  endif
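    // release() deliberately leaks the allocator-owned name string (here and in
    // the output loop below) so the raw pointer stays valid for the Run() call.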
    fInames = {input_name};  // keep the name for Run(); assumes a single input
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_dims = tensor_info.GetShape();
    // replace dynamic (negative) dimensions with a batch size of 1
    for (std::size_t j = 0; j < input_node_dims.size(); j++) {
      if (input_node_dims[j] < 0) input_node_dims[j] = 1;
    }
  }
  // output nodes
  std::vector<int64_t> output_node_dims;
  size_t num_output_nodes = fSession->GetOutputCount();
  std::vector<const char*> output_node_names(num_output_nodes);
  for (std::size_t i = 0; i < num_output_nodes; i++) {
#  if ORT_API_VERSION < 13
    const auto output_name =
      AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter).release();
#  else
    const auto output_name = fSession->GetOutputNameAllocated(i, allocator).release();
#  endif
    output_node_names[i] = output_name;
    Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims = tensor_info.GetShape();
    // replace dynamic (negative) dimensions with a batch size of 1
    for (std::size_t j = 0; j < output_node_dims.size(); j++) {
      if (output_node_dims[j] < 0) output_node_dims[j] = 1;
    }
  }

  // create the input tensor object from the data values;
  // shape {1, N}: a batch of one generation vector of size N
  std::vector<int64_t> dims = {1, static_cast<int64_t>(aGenVector.size())};
  Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
    fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
  assert(Input_noise_tensor.IsTensor());
  std::vector<Ort::Value> ort_inputs;
  ort_inputs.push_back(std::move(Input_noise_tensor));
  // run the inference session
  std::vector<Ort::Value> ort_outputs =
    fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
                  output_node_names.data(), output_node_names.size());
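  // Run() executes the graph synchronously and returns one Ort::Value per
  // entry in output_node_names, in the same order.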
  // get the pointer to the output tensor's float values
  float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
  // copy the first aSize predicted values into the output energy vector
  aEnergies.assign(aSize, 0);
  for (int i = 0; i < aSize; ++i)
    aEnergies[i] = floatarr[i];
}
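
// Usage sketch (illustrative only, not part of the original file; the names
// and sizes below are hypothetical placeholders):
//   std::vector<float> genVector(latentSize + numConditions);  // model input
//   std::vector<G4double> energies;
//   onnxInference->RunInference(genVector, energies, /*aSize=*/numMeshCells);
//   // energies[i] then holds the predicted energy deposited in mesh cell i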

#endif