Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/src/Par04TorchInference.cc

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3 ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

  1 //
  2 // ********************************************************************
  3 // * License and Disclaimer                                           *
  4 // *                                                                  *
  5 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
  6 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
  7 // * conditions of the Geant4 Software License,  included in the file *
  8 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
  9 // * include a list of copyright holders.                             *
 10 // *                                                                  *
 11 // * Neither the authors of this software system, nor their employing *
 12 // * institutes,nor the agencies providing financial support for this *
 13 // * work  make  any representation or  warranty, express or implied, *
 14 // * regarding  this  software system or assume any liability for its *
 15 // * use.  Please see the license in the file  LICENSE  and URL above *
 16 // * for the full disclaimer and the limitation of liability.         *
 17 // *                                                                  *
 18 // * This  code  implementation is the result of  the  scientific and *
 19 // * technical work of the GEANT4 collaboration.                      *
 20 // * By using,  copying,  modifying or  distributing the software (or *
 21 // * any work based  on the software)  you  agree  to acknowledge its *
 22 // * use  in  resulting  scientific  publications,  and indicate your *
 23 // * acceptance of all terms of the Geant4 Software license.          *
 24 // ********************************************************************
 25 //
 26 
 27 #ifdef USE_INFERENCE_TORCH
 28 #  include "Par04TorchInference.hh"
 29 
 30 #  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
 31 
 32 #  include <algorithm>  // for copy, max
 33 #  include <cassert>  // for assert
 34 #  include <cstddef>  // for size_t
 35 #  include <cstdint>  // for int64_t
 36 #  include <torch/torch.h>
 37 #  include <utility>  // for move
 38 
 39 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 40 
 41 Par04TorchInference::Par04TorchInference(G4String modelPath) : Par04InferenceInterface()
 42 {
 43   fModule = torch::jit::load(modelPath);
 44 }
 45 
 46 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 47 
 48 void Par04TorchInference::RunInference(std::vector<float> aGenVector,
 49                                        std::vector<G4double>& aEnergies, int aSize)
 50 {
 51   // latentSize : size of the latent space
 52   // 4 is the size of the condition vector
 53   int latentSize = aGenVector.size() - 4;
 54   // split into latent and condition vectors
 55   std::vector<float> latent;
 56   for (int i = 0; i < latentSize; i++) {
 57     latent.push_back(aGenVector[i]);
 58   }
 59   std::vector<float> energy;
 60   energy.push_back(aGenVector[latentSize + 1]);
 61   std::vector<float> angle;
 62   energy.push_back(aGenVector[latentSize + 2]);
 63   std::vector<float> geo;
 64   for (int i = latentSize + 2; i < latentSize + 4; i++) {
 65     geo.push_back(aGenVector[i]);
 66   }
 67 
 68   // convert vectors to tensors
 69   torch::Tensor latentVector = torch::tensor(latent);
 70   torch::Tensor eTensor = torch::tensor(energy);
 71   torch::Tensor angleTensor = torch::tensor(angle);
 72   torch::Tensor geoTensor = torch::tensor(geo);
 73 
 74   std::vector<torch::jit::IValue> genInput;
 75 
 76   genInput.push_back(latentVector);
 77   genInput.push_back(eTensor);
 78   genInput.push_back(angleTensor);
 79   genInput.push_back(geoTensor);
 80 
 81   at::Tensor outTensor = fModule.forward(genInput).toTensor().contiguous();
 82 
 83   std::vector<G4double> output(outTensor.data_ptr<float>(),
 84                                outTensor.data_ptr<float>() + outTensor.numel());
 85 
 86   aEnergies.assign(aSize, 0);
 87   for (int i = 0; i < aSize; i++) {
 88     aEnergies[i] = output[i];
 89   }
 90 }
 91 
 92 #endif
 93