Geant4 Cross Reference |
1 // 2 // ******************************************************************** 3 // * License and Disclaimer * 4 // * * 5 // * The Geant4 software is copyright of the Copyright Holders of * 6 // * the Geant4 Collaboration. It is provided under the terms and * 7 // * conditions of the Geant4 Software License, included in the file * 8 // * LICENSE and available at http://cern.ch/geant4/license . These * 9 // * include a list of copyright holders. * 10 // * * 11 // * Neither the authors of this software system, nor their employing * 12 // * institutes,nor the agencies providing financial support for this * 13 // * work make any representation or warranty, express or implied, * 14 // * regarding this software system or assume any liability for its * 15 // * use. Please see the license in the file LICENSE and URL above * 16 // * for the full disclaimer and the limitation of liability. * 17 // * * 18 // * This code implementation is the result of the scientific and * 19 // * technical work of the GEANT4 collaboration. * 20 // * By using, copying, modifying or distributing the software (or * 21 // * any work based on the software) you agree to acknowledge its * 22 // * use in resulting scientific publications, and indicate your * 23 // * acceptance of all terms of the Geant4 Software license. * 24 // ******************************************************************** 25 // 26 27 #ifdef USE_INFERENCE_TORCH 28 # include "Par04TorchInference.hh" 29 30 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface 31 32 # include <algorithm> // for copy, max 33 # include <cassert> // for assert 34 # include <cstddef> // for size_t 35 # include <cstdint> // for int64_t 36 # include <torch/torch.h> 37 # include <utility> // for move 38 39 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 40 41 Par04TorchInference::Par04TorchInference(G4String modelPath) : Par04InferenceInterface() 42 { 43 fModule = torch::jit::load(modelPath); 44 } 45 46 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 47 48 void Par04TorchInference::RunInference(std::vector<float> aGenVector, 49 std::vector<G4double>& aEnergies, int aSize) 50 { 51 // latentSize : size of the latent space 52 // 4 is the size of the condition vector 53 int latentSize = aGenVector.size() - 4; 54 // split into latent and condition vectors 55 std::vector<float> latent; 56 for (int i = 0; i < latentSize; i++) { 57 latent.push_back(aGenVector[i]); 58 } 59 std::vector<float> energy; 60 energy.push_back(aGenVector[latentSize + 1]); 61 std::vector<float> angle; 62 energy.push_back(aGenVector[latentSize + 2]); 63 std::vector<float> geo; 64 for (int i = latentSize + 2; i < latentSize + 4; i++) { 65 geo.push_back(aGenVector[i]); 66 } 67 68 // convert vectors to tensors 69 torch::Tensor latentVector = torch::tensor(latent); 70 torch::Tensor eTensor = torch::tensor(energy); 71 torch::Tensor angleTensor = torch::tensor(angle); 72 torch::Tensor geoTensor = torch::tensor(geo); 73 74 std::vector<torch::jit::IValue> genInput; 75 76 genInput.push_back(latentVector); 77 genInput.push_back(eTensor); 78 genInput.push_back(angleTensor); 79 genInput.push_back(geoTensor); 80 81 at::Tensor outTensor = fModule.forward(genInput).toTensor().contiguous(); 82 83 std::vector<G4double> output(outTensor.data_ptr<float>(), 84 outTensor.data_ptr<float>() + outTensor.numel()); 85 86 aEnergies.assign(aSize, 0); 87 for (int i = 0; i < aSize; i++) { 88 aEnergies[i] = output[i]; 89 } 90 } 91 92 #endif 93