Geant4 Cross Reference |
1 // 2 // ******************************************************************** 3 // * License and Disclaimer * 4 // * * 5 // * The Geant4 software is copyright of the Copyright Holders of * 6 // * the Geant4 Collaboration. It is provided under the terms and * 7 // * conditions of the Geant4 Software License, included in the file * 8 // * LICENSE and available at http://cern.ch/geant4/license . These * 9 // * include a list of copyright holders. * 10 // * * 11 // * Neither the authors of this software system, nor their employing * 12 // * institutes,nor the agencies providing financial support for this * 13 // * work make any representation or warranty, express or implied, * 14 // * regarding this software system or assume any liability for its * 15 // * use. Please see the license in the file LICENSE and URL above * 16 // * for the full disclaimer and the limitation of liability. * 17 // * * 18 // * This code implementation is the result of the scientific and * 19 // * technical work of the GEANT4 collaboration. * 20 // * By using, copying, modifying or distributing the software (or * 21 // * any work based on the software) you agree to acknowledge its * 22 // * use in resulting scientific publications, and indicate your * 23 // * acceptance of all terms of the Geant4 Software license. * 24 // ******************************************************************** 25 // 26 #ifdef USE_INFERENCE_LWTNN 27 # include "Par04LwtnnInference.hh" 28 29 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface 30 31 # include <fstream> // for ifstream 32 33 # include <lwtnn/parse_json.hh> // for parse_json_graph 34 35 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 36 37 Par04LwtnnInference::Par04LwtnnInference(G4String modelPath) : Par04InferenceInterface() 38 { 39 // file to read 40 std::ifstream input(modelPath); 41 // build the graph 42 fGraph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input)); 43 input.close(); 44 } 45 46 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 47 48 void Par04LwtnnInference::RunInference(std::vector<float> aGenVector, 49 std::vector<G4double>& aEnergies, int aSize) 50 { 51 // generation vector 52 fNetworkInputs inputs; 53 for (std::size_t i = 0; i < aGenVector.size(); ++i) { 54 inputs["node_0"]["variable_" + std::to_string(i)] = aGenVector[i]; 55 } 56 57 // run the inference 58 fNetworkOutputs outputs = fGraph->compute(inputs); 59 aEnergies.assign(aSize, 0); 60 for (int i = 0; i < aSize; i++) 61 aEnergies[i] = outputs["out_" + std::to_string(i)]; 62 } 63 64 #endif 65