// Geant4 Cross Reference
//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The  Geant4 software  is  copyright of the Copyright Holders  of *
// * the Geant4 Collaboration.  It is provided  under  the terms  and *
// * conditions of the Geant4 Software License,  included in the file *
// * LICENSE and available at  http://cern.ch/geant4/license .  These *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work  make  any representation or  warranty, express or implied, *
// * regarding  this  software system or assume any liability for its *
// * use.  Please see the license in the file  LICENSE  and URL above *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This  code  implementation is the result of  the  scientific and *
// * technical work of the GEANT4 collaboration.                      *
// * By using,  copying,  modifying or  distributing the software (or *
// * any work based  on the software)  you  agree  to acknowledge its *
// * use  in  resulting  scientific  publications,  and indicate your *
// * acceptance of all terms of the Geant4 Software license.          *
* 24 // ******************************************************************** 25 // 26 #ifdef USE_INFERENCE 27 # include "Par04InferenceSetup.hh" 28 29 # include "Par04InferenceInterface.hh" // for Par04InferenceInterface 30 # include "Par04InferenceMessenger.hh" // for Par04InferenceMessenger 31 # ifdef USE_INFERENCE_ONNX 32 # include "Par04OnnxInference.hh" // for Par04OnnxInference 33 # endif 34 # ifdef USE_INFERENCE_LWTNN 35 # include "Par04LwtnnInference.hh" // for Par04LwtnnInference 36 # endif 37 # ifdef USE_INFERENCE_TORCH 38 # include "Par04TorchInference.hh" // for Par04TorchInference 39 # endif 40 # include "CLHEP/Random/RandGauss.h" // for RandGauss 41 42 # include "G4RotationMatrix.hh" // for G4RotationMatrix 43 44 # include <CLHEP/Units/SystemOfUnits.h> // for pi, GeV, deg 45 # include <CLHEP/Vector/Rotation.h> // for HepRotation 46 # include <CLHEP/Vector/ThreeVector.h> // for Hep3Vector 47 # include <G4Exception.hh> // for G4Exception 48 # include <G4ExceptionSeverity.hh> // for FatalException 49 # include <G4ThreeVector.hh> // for G4ThreeVector 50 # include <algorithm> // for max, copy 51 # include <cmath> // for cos, sin 52 # include <string> // for char_traits, basic_string 53 54 # include <ext/alloc_traits.h> // for __alloc_traits<>::value_type 55 56 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 57 58 Par04InferenceSetup::Par04InferenceSetup() : fInferenceMessenger(new Par04InferenceMessenger(this)) 59 {} 60 61 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 62 63 Par04InferenceSetup::~Par04InferenceSetup() {} 64 65 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 
66 67 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy) 68 { 69 /// Energy of electrons used in training dataset 70 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) return true; 71 return false; 72 } 73 74 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 75 76 void Par04InferenceSetup::SetInferenceLibrary(G4String aName) 77 { 78 fInferenceLibrary = aName; 79 80 # ifdef USE_INFERENCE_ONNX 81 if (fInferenceLibrary == "ONNX") 82 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference( 83 fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads, fCudaFlag, cuda_keys, 84 cuda_values, fModelSavePath, fProfilingOutputSavePath)); 85 # endif 86 # ifdef USE_INFERENCE_LWTNN 87 if (fInferenceLibrary == "LWTNN") 88 fInferenceInterface = 89 std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName)); 90 # endif 91 # ifdef USE_INFERENCE_TORCH 92 if (fInferenceLibrary == "TORCH") 93 fInferenceInterface = 94 std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName)); 95 # endif 96 97 CheckInferenceLibrary(); 98 } 99 100 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 101 102 void Par04InferenceSetup::CheckInferenceLibrary() 103 { 104 G4String msg = "Please choose inference library from available libraries ("; 105 # ifdef USE_INFERENCE_ONNX 106 msg += "ONNX,"; 107 # endif 108 # ifdef USE_INFERENCE_LWTNN 109 msg += "LWTNN,"; 110 # endif 111 # ifdef USE_INFERENCE_TORCH 112 msg += "TORCH"; 113 # endif 114 if (fInferenceInterface == nullptr) 115 G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException, 116 (msg + "). Current name: " + fInferenceLibrary).c_str()); 117 } 118 119 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 
120 121 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy, 122 G4float aInitialAngle) 123 { 124 // First check if inference library was set correctly 125 CheckInferenceLibrary(); 126 // size represents the size of the output vector 127 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z(); 128 129 // randomly sample from a gaussian distribution in the latent space 130 std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0); 131 for (int i = 0; i < fSizeLatentVector; ++i) { 132 genVector[i] = CLHEP::RandGauss::shoot(0., 1.); 133 } 134 135 // Vector of condition 136 // this is application specific it depdens on what the model was condition on 137 // and it depends on how the condition values were encoded at the training 138 // time in this example the energy of each particle is normlaized to the 139 // highest energy in the considered range (1GeV-500GeV) the angle is also is 140 // normlaized to the highest angle in the considered range (0-90 in dergrees) 141 // the model in this example was trained on two detector geometries PBW04 142 // and SiW a one hot encoding vector is used to represent the geometry with 143 // [0,1] for PBW04 and [1,0] for SiW 144 // 1. energy 145 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy; 146 // 2. angle 147 genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle; 148 // 3. geometry 149 genVector[fSizeLatentVector + 2] = 0; 150 genVector[fSizeLatentVector + 3] = 1; 151 152 // Run the inference 153 fInferenceInterface->RunInference(genVector, aEnergies, size); 154 155 // After the inference rescale back to the initial energy (in this example the 156 // energies of cells were normalized to the energy of the particle) 157 for (int i = 0; i < size; ++i) { 158 aEnergies[i] = aEnergies[i] * aInitialEnergy; 159 } 160 } 161 162 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 
163 164 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0, 165 G4ThreeVector direction) 166 { 167 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z()); 168 169 // Calculate rotation matrix along the particle momentum direction 170 // It will rotate the shower axes to match the incoming particle direction 171 G4RotationMatrix rotMatrix = G4RotationMatrix(); 172 double particleTheta = direction.theta(); 173 double particlePhi = direction.phi(); 174 rotMatrix.rotateZ(-particlePhi); 175 rotMatrix.rotateY(-particleTheta); 176 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix); 177 178 int cpt = 0; 179 for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) { 180 for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) { 181 for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) { 182 aPositions[cpt] = 183 pos0 184 + rotMatrixInv 185 * G4ThreeVector( 186 (iCellR + 0.5) * fMeshSize.x() 187 * std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi), 188 (iCellR + 0.5) * fMeshSize.x() 189 * std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi), 190 (iCellZ + 0.5) * fMeshSize.z()); 191 cpt++; 192 } 193 } 194 } 195 } 196 197 #endif 198