//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The  Geant4 software  is  copyright of the Copyright Holders  of *
// * the Geant4 Collaboration.  It is provided  under  the terms  and *
// * conditions of the Geant4 Software License,  included in the file *
// * LICENSE and available at  http://cern.ch/geant4/license .  These *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work  make  any representation or  warranty, express or implied, *
// * regarding  this  software system or assume any liability for its *
// * use.  Please see the license in the file  LICENSE  and URL above *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This  code  implementation is the result of  the  scientific and *
// * technical work of the GEANT4 collaboration.                      *
// * By using,  copying,  modifying or  distributing the software (or *
// * any work based  on the software)  you  agree  to acknowledge its *
// * use  in  resulting  scientific  publications,  and indicate your *
// * acceptance of all terms of the Geant4 Software license.          *
// ********************************************************************
//
#ifdef USE_INFERENCE_ONNX
#  include "Par04OnnxInference.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface

#  include <algorithm>  // for copy, max
#  include <cassert>  // for assert
#  include <cstddef>  // for size_t
#  include <cstdint>  // for int64_t
#  include <utility>  // for move

#  include <core/session/onnxruntime_cxx_api.h>  // for Value, Session, Env
#  ifdef USE_CUDA
#    include "cuda_runtime_api.h"
#  endif

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
                                       G4int intraOpNumThreads, G4int cudaFlag,
                                       std::vector<const char*>& cuda_keys,
                                       std::vector<const char*>& cuda_values,
                                       G4String ModelSavePath, G4String profilingOutputSavePath)

  : Par04InferenceInterface()
{
  // initialization of the environment and inference session
  auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
  fEnv = std::move(envLocal);
  // Creating an OrtApi reference for access to the C API, necessary for CUDA
  const auto& ortApi = Ort::GetApi();
  fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
  // graph optimizations of the model:
  // if the flag is not set to true, none of the optimizations will be applied;
  // if it is set to true, all of them will be applied
  if (optimizeFlag) {
    fSessionOptions.SetOptimizedModelFilePath("opt-graph");
    fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    // other available levels: ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED
  }
  else
    fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
#  ifdef USE_CUDA
  if (cudaFlag) {
    OrtCUDAProviderOptionsV2* fCudaOptions = nullptr;
    // Initialize the CUDA provider options, fCudaOptions should now point to a
    // valid CUDA configuration.
    (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
    // Update the CUDA provider options
    (void)ortApi.UpdateCUDAProviderOptions(fCudaOptions, cuda_keys.data(), cuda_values.data(),
                                           cuda_keys.size());
    // Append the CUDA execution provider to the session options, indicating to
    // use CUDA for execution
    (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, fCudaOptions);
  }
#  endif
  // save JSON file for model execution profiling
  if (profileFlag) fSessionOptions.EnableProfiling("opt.json");

  auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
  fSession = std::move(sessionLocal);
  fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
}

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
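// RunInference (summary): creates a (1, aGenVector.size()) float input tensor with the
// CPU memory info, executes the ONNX Runtime session, and copies the first aSize values
// of the output tensor into aEnergies.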
void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
                                      std::vector<G4double>& aEnergies, int aSize)
{
  // input nodes
  Ort::AllocatorWithDefaultOptions allocator;
#  if ORT_API_VERSION < 13
  // Before 1.13 we have to roll our own unique_ptr wrapper here
  auto allocDeleter = [&allocator](char* p) {
    allocator.Free(p);
  };
  using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#  endif
  std::vector<int64_t> input_node_dims;
  size_t num_input_nodes = fSession->GetInputCount();
  std::vector<const char*> input_node_names(num_input_nodes);
  for (std::size_t i = 0; i < num_input_nodes; i++) {
#  if ORT_API_VERSION < 13
    const auto input_name =
      AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter).release();
#  else
    const auto input_name = fSession->GetInputNameAllocated(i, allocator).release();
#  endif
    fInames = {input_name};
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_dims = tensor_info.GetShape();
    for (std::size_t j = 0; j < input_node_dims.size(); j++) {
      if (input_node_dims[j] < 0) input_node_dims[j] = 1;
    }
  }
  // output nodes
  std::vector<int64_t> output_node_dims;
  size_t num_output_nodes = fSession->GetOutputCount();
  std::vector<const char*> output_node_names(num_output_nodes);
  for (std::size_t i = 0; i < num_output_nodes; i++) {
#  if ORT_API_VERSION < 13
    const auto output_name =
      AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter).release();
#  else
    const auto output_name = fSession->GetOutputNameAllocated(i, allocator).release();
#  endif
    output_node_names[i] = output_name;
    Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims = tensor_info.GetShape();
    for (std::size_t j = 0; j < output_node_dims.size(); j++) {
      if (output_node_dims[j] < 0) output_node_dims[j] = 1;
    }
  }

  // create input tensor object from data values
  std::vector<int64_t> dims = {1, (unsigned)(aGenVector.size())};
  Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
    fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
  assert(Input_noise_tensor.IsTensor());
  std::vector<Ort::Value> ort_inputs;
  ort_inputs.push_back(std::move(Input_noise_tensor));
  // run the inference session
  std::vector<Ort::Value> ort_outputs =
    fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
                  output_node_names.data(), output_node_names.size());
  // get pointer to output tensor float values
  float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
  aEnergies.assign(aSize, 0);
  for (int i = 0; i < aSize; ++i)
    aEnergies[i] = floatarr[i];
}

#endif
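// Minimal usage sketch (illustrative only; "model.onnx", latentVector and nCells are
// hypothetical placeholders, not values defined by this file; the example itself wires
// these parameters in elsewhere):
//
//   std::vector<const char*> cudaKeys;    // CUDA provider option names (unused when cudaFlag = 0)
//   std::vector<const char*> cudaValues;  // matching option values
//   Par04OnnxInference inference("model.onnx", /*profileFlag*/ 0, /*optimizeFlag*/ 1,
//                                /*intraOpNumThreads*/ 1, /*cudaFlag*/ 0,
//                                cudaKeys, cudaValues, ".", ".");
//   std::vector<G4double> energies;
//   inference.RunInference(latentVector, energies, nCells);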