Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3 ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

  1 //
  2 // ********************************************************************
  3 // * License and Disclaimer                                           *
  4 // *                                                                  *
  5 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
  6 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
  7 // * conditions of the Geant4 Software License,  included in the file *
  8 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
  9 // * include a list of copyright holders.                             *
 10 // *                                                                  *
 11 // * Neither the authors of this software system, nor their employing *
 12 // * institutes,nor the agencies providing financial support for this *
 13 // * work  make  any representation or  warranty, express or implied, *
 14 // * regarding  this  software system or assume any liability for its *
 15 // * use.  Please see the license in the file  LICENSE  and URL above *
 16 // * for the full disclaimer and the limitation of liability.         *
 17 // *                                                                  *
 18 // * This  code  implementation is the result of  the  scientific and *
 19 // * technical work of the GEANT4 collaboration.                      *
 20 // * By using,  copying,  modifying or  distributing the software (or *
 21 // * any work based  on the software)  you  agree  to acknowledge its *
 22 // * use  in  resulting  scientific  publications,  and indicate your *
 23 // * acceptance of all terms of the Geant4 Software license.          *
 24 // ********************************************************************
 25 //
 26 #ifdef USE_INFERENCE
#  include "Par04InferenceSetup.hh"

#  include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
#  include "Par04InferenceMessenger.hh"  // for Par04InferenceMessenger
#  ifdef USE_INFERENCE_ONNX
#    include "Par04OnnxInference.hh"  // for Par04OnnxInference
#  endif
#  ifdef USE_INFERENCE_LWTNN
#    include "Par04LwtnnInference.hh"  // for Par04LwtnnInference
#  endif
#  ifdef USE_INFERENCE_TORCH
#    include "Par04TorchInference.hh"  // for Par04TorchInference
#  endif
#  include "CLHEP/Random/RandGauss.h"  // for RandGauss

#  include "G4RotationMatrix.hh"  // for G4RotationMatrix

#  include <CLHEP/Units/SystemOfUnits.h>  // for pi, GeV, deg
#  include <CLHEP/Vector/Rotation.h>  // for HepRotation
#  include <CLHEP/Vector/ThreeVector.h>  // for Hep3Vector
#  include <G4Exception.hh>  // for G4Exception
#  include <G4ExceptionSeverity.hh>  // for FatalException
#  include <G4ThreeVector.hh>  // for G4ThreeVector
#  include <algorithm>  // for max, copy
#  include <cmath>  // for cos, sin
#  include <cstddef>  // for size_t
#  include <memory>  // for unique_ptr, make_unique
#  include <string>  // for char_traits, basic_string, to_string
#  include <vector>  // for vector

#  include <ext/alloc_traits.h>  // for __alloc_traits<>::value_type
 55 
 56 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 57 
 58 Par04InferenceSetup::Par04InferenceSetup() : fInferenceMessenger(new Par04InferenceMessenger(this))
 59 {}
 60 
 61 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 62 
 63 Par04InferenceSetup::~Par04InferenceSetup() {}
 64 
 65 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 66 
 67 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
 68 {
 69   /// Energy of electrons used in training dataset
 70   if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) return true;
 71   return false;
 72 }
 73 
 74 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 75 
 76 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
 77 {
 78   fInferenceLibrary = aName;
 79 
 80 #  ifdef USE_INFERENCE_ONNX
 81   if (fInferenceLibrary == "ONNX")
 82     fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference(
 83       fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads, fCudaFlag, cuda_keys,
 84       cuda_values, fModelSavePath, fProfilingOutputSavePath));
 85 #  endif
 86 #  ifdef USE_INFERENCE_LWTNN
 87   if (fInferenceLibrary == "LWTNN")
 88     fInferenceInterface =
 89       std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
 90 #  endif
 91 #  ifdef USE_INFERENCE_TORCH
 92   if (fInferenceLibrary == "TORCH")
 93     fInferenceInterface =
 94       std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName));
 95 #  endif
 96 
 97   CheckInferenceLibrary();
 98 }
 99 
100 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
101 
102 void Par04InferenceSetup::CheckInferenceLibrary()
103 {
104   G4String msg = "Please choose inference library from available libraries (";
105 #  ifdef USE_INFERENCE_ONNX
106   msg += "ONNX,";
107 #  endif
108 #  ifdef USE_INFERENCE_LWTNN
109   msg += "LWTNN,";
110 #  endif
111 #  ifdef USE_INFERENCE_TORCH
112   msg += "TORCH";
113 #  endif
114   if (fInferenceInterface == nullptr)
115     G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
116                 (msg + "). Current name: " + fInferenceLibrary).c_str());
117 }
118 
119 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
120 
121 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
122                                       G4float aInitialAngle)
123 {
124   // First check if inference library was set correctly
125   CheckInferenceLibrary();
126   // size represents the size of the output vector
127   int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
128 
129   // randomly sample from a gaussian distribution in the latent space
130   std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
131   for (int i = 0; i < fSizeLatentVector; ++i) {
132     genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
133   }
134 
135   // Vector of condition
136   // this is application specific it depdens on what the model was condition on
137   // and it depends on how the condition values were encoded at the training
138   // time in this example the energy of each particle is normlaized to the
139   // highest energy in the considered range (1GeV-500GeV) the angle is also is
140   // normlaized to the highest angle in the considered range (0-90 in dergrees)
141   // the model in this example was trained on two detector geometries PBW04
142   // and SiW  a one hot encoding vector is used to represent the geometry with
143   // [0,1] for PBW04 and [1,0] for SiW
144   // 1. energy
145   genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
146   // 2. angle
147   genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
148   // 3. geometry
149   genVector[fSizeLatentVector + 2] = 0;
150   genVector[fSizeLatentVector + 3] = 1;
151 
152   // Run the inference
153   fInferenceInterface->RunInference(genVector, aEnergies, size);
154 
155   // After the inference rescale back to the initial energy (in this example the
156   // energies of cells were normalized to the energy of the particle)
157   for (int i = 0; i < size; ++i) {
158     aEnergies[i] = aEnergies[i] * aInitialEnergy;
159   }
160 }
161 
162 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
163 
164 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
165                                        G4ThreeVector direction)
166 {
167   aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
168 
169   // Calculate rotation matrix along the particle momentum direction
170   // It will rotate the shower axes to match the incoming particle direction
171   G4RotationMatrix rotMatrix = G4RotationMatrix();
172   double particleTheta = direction.theta();
173   double particlePhi = direction.phi();
174   rotMatrix.rotateZ(-particlePhi);
175   rotMatrix.rotateY(-particleTheta);
176   G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
177 
178   int cpt = 0;
179   for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) {
180     for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) {
181       for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) {
182         aPositions[cpt] =
183           pos0
184           + rotMatrixInv
185               * G4ThreeVector(
186                 (iCellR + 0.5) * fMeshSize.x()
187                   * std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
188                 (iCellR + 0.5) * fMeshSize.x()
189                   * std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
190                 (iCellZ + 0.5) * fMeshSize.z());
191         cpt++;
192       }
193     }
194   }
195 }
196 
197 #endif
198