Geant4 Cross Reference |
1 // 1 2 // ******************************************* 3 // * License and Disclaimer 4 // * 5 // * The Geant4 software is copyright of th 6 // * the Geant4 Collaboration. It is provided 7 // * conditions of the Geant4 Software License 8 // * LICENSE and available at http://cern.ch/ 9 // * include a list of copyright holders. 10 // * 11 // * Neither the authors of this software syst 12 // * institutes,nor the agencies providing fin 13 // * work make any representation or warran 14 // * regarding this software system or assum 15 // * use. Please see the license in the file 16 // * for the full disclaimer and the limitatio 17 // * 18 // * This code implementation is the result 19 // * technical work of the GEANT4 collaboratio 20 // * By using, copying, modifying or distri 21 // * any work based on the software) you ag 22 // * use in resulting scientific publicati 23 // * acceptance of all terms of the Geant4 Sof 24 // ******************************************* 25 // 26 27 #ifdef USE_INFERENCE_ONNX 28 # ifndef PAR04ONNXINFERENCE_HH 29 # define PAR04ONNXINFERENCE_HH 30 # include "Par04InferenceInterface.hh" // 31 # include "core/session/onnxruntime_cxx_api 32 33 # include <G4String.hh> // for G4String 34 # include <G4Types.hh> // for G4int, G4dou 35 # include <memory> // for unique_ptr 36 # include <vector> // for vector 37 38 # include <core/session/onnxruntime_c_api.h 39 40 /** 41 * @brief Inference using the ONNX runtime. 42 * 43 * Creates an enviroment whcih manages an inte 44 * inference session for the model saved as an 45 * Runs the inference in the session using the 46 * 47 **/ 48 49 class Par04OnnxInference : public Par04Inferen 50 { 51 public: 52 Par04OnnxInference(G4String, G4int, G4int, 53 G4int, // For Executio 54 std::vector<const char* 55 G4String, G4String); 56 57 Par04OnnxInference(); 58 59 /// Run inference 60 /// @param[in] aGenVector Input latent spa 61 /// @param[out] aEnergies Model output = g 62 /// @param[in] aSize Size of the output 63 void RunInference(std::vector<float> aGenV 64 65 private: 66 /// Pointer to the ONNX enviroment 67 std::unique_ptr<Ort::Env> fEnv; 68 /// Pointer to the ONNX inference session 69 std::unique_ptr<Ort::Session> fSession; 70 /// ONNX settings 71 Ort::SessionOptions fSessionOptions; 72 /// ONNX memory info 73 const OrtMemoryInfo* fInfo; 74 struct MemoryInfo; 75 /// the input names represent the names gi 76 /// when defining the model's architectur 77 /// they can also be retrieved from model. 78 std::vector<const char*> fInames; 79 }; 80 81 # endif /* PAR04ONNXINFERENCE_HH */ 82 #endif 83