//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The  Geant4 software  is  copyright of the Copyright Holders  of *
// * the Geant4 Collaboration.  It is provided  under  the terms  and *
// * conditions of the Geant4 Software License,  included in the file *
// * LICENSE and available at  http://cern.ch/geant4/license .  These *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work  make  any representation or  warranty, express or implied, *
// * regarding  this  software system or assume any liability for its *
// * use.  Please see the license in the file  LICENSE  and URL above *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This  code  implementation is the result of  the  scientific and *
// * technical work of the GEANT4 collaboration.                      *
// * By using,  copying,  modifying or  distributing the software (or *
// * any work based  on the software)  you  agree  to acknowledge its *
// * use  in  resulting  scientific  publications,  and indicate your *
// * acceptance of all terms of the Geant4 Software license.          *
// ********************************************************************
//
#ifdef USE_INFERENCE_ONNX
#  include "Par04OnnxInference.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface

#  include <algorithm>  // for copy, max
#  include <cassert>  // for assert
#  include <cstddef>  // for size_t
#  include <cstdint>  // for int64_t
#  include <utility>  // for move

#  include <core/session/onnxruntime_cxx_api.h>  // for Env, Session, Value
#  ifdef USE_CUDA
#    include "cuda_runtime_api.h"
#  endif

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag, G4int optimizeFlag,
                                       G4int intraOpNumThreads, G4int cudaFlag,
                                       std::vector<const char*>& cuda_keys,
                                       std::vector<const char*>& cuda_values,
                                       G4String modelSavePath, G4String profilingOutputSavePath)

  : Par04InferenceInterface()
{
  // initialization of the environment and inference session
  auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
  fEnv = std::move(envLocal);
  // Creating an OrtApi class variable for getting access to the C API, necessary for
  // CUDA
  const auto& ortApi = Ort::GetApi();
  fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
  // graph optimizations of the model:
  // if the flag is not set to true, none of the optimizations will be applied;
  // if it is set to true, all of the optimizations will be applied
  if (optimizeFlag) {
    fSessionOptions.SetOptimizedModelFilePath(modelSavePath.c_str());
    fSessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    // other levels: ORT_ENABLE_BASIC #### ORT_ENABLE_EXTENDED
  }
  else
    fSessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
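  // Note (illustrative, not in the original file): cuda_keys / cuda_values are
  // parallel lists of CUDA execution provider option names and values known to
  // ONNX Runtime, e.g. {"device_id", "gpu_mem_limit"} paired with
  // {"0", "2147483648"}; they are passed unchanged to
  // UpdateCUDAProviderOptions() below.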
#  ifdef USE_CUDA
  if (cudaFlag) {
    OrtCUDAProviderOptionsV2* fCudaOptions = nullptr;
    // Initialize the CUDA provider options; fCudaOptions will then point to a
    // valid CUDA configuration.
    (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
    // Update the CUDA provider options
    (void)ortApi.UpdateCUDAProviderOptions(fCudaOptions, cuda_keys.data(), cuda_values.data(),
                                           cuda_keys.size());
    // Append the CUDA execution provider to the session options; the session will
    // use CUDA for execution
    (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions, fCudaOptions);
  }
#  endif
  // save json file for model execution profiling
  if (profileFlag) fSessionOptions.EnableProfiling(profilingOutputSavePath.c_str());

  auto sessionLocal = std::make_unique<Ort::Session>(*fEnv, modelPath.c_str(), fSessionOptions);
  fSession = std::move(sessionLocal);
  fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
}

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
                                      std::vector<G4double>& aEnergies, int aSize)
{
  // input nodes
  Ort::AllocatorWithDefaultOptions allocator;
#  if ORT_API_VERSION < 13
  // Before 1.13 we have to roll our own unique_ptr wrapper here
  auto allocDeleter = [&allocator](char* p) {
    allocator.Free(p);
  };
  using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#  endif
  std::vector<int64_t> input_node_dims;
  size_t num_input_nodes = fSession->GetInputCount();
  std::vector<const char*> input_node_names(num_input_nodes);
  for (std::size_t i = 0; i < num_input_nodes; i++) {
#  if ORT_API_VERSION < 13
    const auto input_name =
      AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter).release();
#  else
    const auto input_name = fSession->GetInputNameAllocated(i, allocator).release();
#  endif
    fInames = {input_name};
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_dims = tensor_info.GetShape();
    for (std::size_t j = 0; j < input_node_dims.size(); j++) {
      if (input_node_dims[j] < 0) input_node_dims[j] = 1;
    }
  }
  // output nodes
  std::vector<int64_t> output_node_dims;
  size_t num_output_nodes = fSession->GetOutputCount();
  std::vector<const char*> output_node_names(num_output_nodes);
  for (std::size_t i = 0; i < num_output_nodes; i++) {
#  if ORT_API_VERSION < 13
    const auto output_name =
      AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter).release();
#  else
    const auto output_name = fSession->GetOutputNameAllocated(i, allocator).release();
#  endif
    output_node_names[i] = output_name;
    Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims = tensor_info.GetShape();
    for (std::size_t j = 0; j < output_node_dims.size(); j++) {
      if (output_node_dims[j] < 0) output_node_dims[j] = 1;
    }
  }

  // create input tensor object from data values
  std::vector<int64_t> dims = {1, static_cast<int64_t>(aGenVector.size())};
  Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
    fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
  assert(Input_noise_tensor.IsTensor());
  std::vector<Ort::Value> ort_inputs;
  ort_inputs.push_back(std::move(Input_noise_tensor));
  // run the inference session
  std::vector<Ort::Value> ort_outputs =
    fSession->Run(Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(), ort_inputs.size(),
                  output_node_names.data(), output_node_names.size());
  // get pointer to output tensor float values
  float* floatarr = ort_outputs.front().GetTensorMutableData<float>();
  aEnergies.assign(aSize, 0);
  for (int i = 0; i < aSize; ++i)
    aEnergies[i] = floatarr[i];
}

#endif
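// A minimal driver sketch (illustrative only, not part of this file): how the
// class above could be constructed and queried. The model path, the empty CUDA
// key/value lists, the 10-element latent size, and the aSize value are all
// assumptions made for the sake of the example.
//
//   #ifdef USE_INFERENCE_ONNX
//   std::vector<const char*> cudaKeys;    // empty: run on the CPU provider
//   std::vector<const char*> cudaValues;
//   Par04OnnxInference inference("MLModels/Generator.onnx",
//                                /*profileFlag*/ 0, /*optimizeFlag*/ 0,
//                                /*intraOpNumThreads*/ 1, /*cudaFlag*/ 0,
//                                cudaKeys, cudaValues,
//                                /*modelSavePath*/ "optModel.onnx",
//                                /*profilingOutputSavePath*/ "profile");
//   std::vector<float> latent(10, 0.f);   // latent noise + conditioning values
//   std::vector<G4double> energies;       // one entry per readout-mesh cell
//   inference.RunInference(latent, energies, /*aSize*/ 500);
//   #endif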