Geant4 Cross Reference |
1 // 1 // 2 // ******************************************* 2 // ******************************************************************** 3 // * License and Disclaimer 3 // * License and Disclaimer * 4 // * 4 // * * 5 // * The Geant4 software is copyright of th 5 // * The Geant4 software is copyright of the Copyright Holders of * 6 // * the Geant4 Collaboration. It is provided 6 // * the Geant4 Collaboration. It is provided under the terms and * 7 // * conditions of the Geant4 Software License 7 // * conditions of the Geant4 Software License, included in the file * 8 // * LICENSE and available at http://cern.ch/ 8 // * LICENSE and available at http://cern.ch/geant4/license . These * 9 // * include a list of copyright holders. 9 // * include a list of copyright holders. * 10 // * 10 // * * 11 // * Neither the authors of this software syst 11 // * Neither the authors of this software system, nor their employing * 12 // * institutes,nor the agencies providing fin 12 // * institutes,nor the agencies providing financial support for this * 13 // * work make any representation or warran 13 // * work make any representation or warranty, express or implied, * 14 // * regarding this software system or assum 14 // * regarding this software system or assume any liability for its * 15 // * use. Please see the license in the file 15 // * use. Please see the license in the file LICENSE and URL above * 16 // * for the full disclaimer and the limitatio 16 // * for the full disclaimer and the limitation of liability. * 17 // * 17 // * * 18 // * This code implementation is the result 18 // * This code implementation is the result of the scientific and * 19 // * technical work of the GEANT4 collaboratio 19 // * technical work of the GEANT4 collaboration. * 20 // * By using, copying, modifying or distri 20 // * By using, copying, modifying or distributing the software (or * 21 // * any work based on the software) you ag 21 // * any work based on the software) you agree to acknowledge its * 22 // * use in resulting scientific publicati 22 // * use in resulting scientific publications, and indicate your * 23 // * acceptance of all terms of the Geant4 Sof 23 // * acceptance of all terms of the Geant4 Software license. * 24 // ******************************************* 24 // ******************************************************************** 25 // 25 // 26 #ifdef USE_INFERENCE 26 #ifdef USE_INFERENCE 27 # include "Par04InferenceSetup.hh" << 27 #include "Par04InferenceSetup.hh" 28 << 28 #include "Par04InferenceInterface.hh" // for Par04InferenceInterface 29 # include "Par04InferenceInterface.hh" // fo << 29 #include "Par04InferenceMessenger.hh" // for Par04InferenceMessenger 30 # include "Par04InferenceMessenger.hh" // fo << 30 #ifdef USE_INFERENCE_ONNX 31 # ifdef USE_INFERENCE_ONNX << 31 #include "Par04OnnxInference.hh" // for Par04OnnxInference 32 # include "Par04OnnxInference.hh" // for P << 32 #endif 33 # endif << 33 #ifdef USE_INFERENCE_LWTNN 34 # ifdef USE_INFERENCE_LWTNN << 34 #include "Par04LwtnnInference.hh" // for Par04LwtnnInference 35 # include "Par04LwtnnInference.hh" // for << 35 #endif 36 # endif << 36 #ifdef USE_INFERENCE_TORCH 37 # ifdef USE_INFERENCE_TORCH << 37 #include "Par04TorchInference.hh" // for Par04TorchInference 38 # include "Par04TorchInference.hh" // for << 38 #endif 39 # endif << 39 #include <CLHEP/Units/SystemOfUnits.h> // for pi, GeV, deg 40 # include "CLHEP/Random/RandGauss.h" // for << 40 #include <CLHEP/Vector/Rotation.h> // for HepRotation 41 << 41 #include <CLHEP/Vector/ThreeVector.h> // for Hep3Vector 42 # include "G4RotationMatrix.hh" // for G4Rot << 42 #include <G4Exception.hh> // for G4Exception 43 << 43 #include <G4ExceptionSeverity.hh> // for FatalException 44 # include <CLHEP/Units/SystemOfUnits.h> // f << 44 #include <G4ThreeVector.hh> // for G4ThreeVector 45 # include <CLHEP/Vector/Rotation.h> // for H << 45 #include "CLHEP/Random/RandGauss.h" // for RandGauss 46 # include <CLHEP/Vector/ThreeVector.h> // fo << 46 #include "G4RotationMatrix.hh" // for G4RotationMatrix 47 # include <G4Exception.hh> // for G4Exceptio << 47 #include <algorithm> // for max, copy 48 # include <G4ExceptionSeverity.hh> // for Fa << 48 #include <string> // for char_traits, basic_string 49 # include <G4ThreeVector.hh> // for G4ThreeV << 49 #include <ext/alloc_traits.h> // for __alloc_traits<>::value_type 50 # include <algorithm> // for max, copy << 50 #include <cmath> // for cos, sin 51 # include <cmath> // for cos, sin << 52 # include <string> // for char_traits, basic << 53 << 54 # include <ext/alloc_traits.h> // for __allo << 55 51 56 //....oooOO0OOooo........oooOO0OOooo........oo 52 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 57 53 58 Par04InferenceSetup::Par04InferenceSetup() : f << 54 Par04InferenceSetup::Par04InferenceSetup() >> 55 : fInferenceMessenger(new Par04InferenceMessenger(this)) 59 {} 56 {} 60 57 61 //....oooOO0OOooo........oooOO0OOooo........oo 58 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 62 59 63 Par04InferenceSetup::~Par04InferenceSetup() {} 60 Par04InferenceSetup::~Par04InferenceSetup() {} 64 61 65 //....oooOO0OOooo........oooOO0OOooo........oo 62 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 66 63 67 G4bool Par04InferenceSetup::IfTrigger(G4double 64 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy) 68 { 65 { 69 /// Energy of electrons used in training dat 66 /// Energy of electrons used in training dataset 70 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 10 << 67 if(aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) >> 68 return true; 71 return false; 69 return false; 72 } 70 } 73 71 74 //....oooOO0OOooo........oooOO0OOooo........oo 72 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 75 73 76 void Par04InferenceSetup::SetInferenceLibrary( 74 void Par04InferenceSetup::SetInferenceLibrary(G4String aName) 77 { 75 { 78 fInferenceLibrary = aName; 76 fInferenceLibrary = aName; 79 77 80 # ifdef USE_INFERENCE_ONNX << 78 #ifdef USE_INFERENCE_ONNX 81 if (fInferenceLibrary == "ONNX") << 79 if(fInferenceLibrary == "ONNX") 82 fInferenceInterface = std::unique_ptr<Par0 << 80 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>( 83 fModelPathName, fProfileFlag, fOptimizat << 81 new Par04OnnxInference(fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads)); 84 cuda_values, fModelSavePath, fProfilingO << 82 #endif 85 # endif << 83 #ifdef USE_INFERENCE_LWTNN 86 # ifdef USE_INFERENCE_LWTNN << 84 if(fInferenceLibrary == "LWTNN") 87 if (fInferenceLibrary == "LWTNN") << 88 fInferenceInterface = 85 fInferenceInterface = 89 std::unique_ptr<Par04InferenceInterface> 86 std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName)); 90 # endif << 87 #endif 91 # ifdef USE_INFERENCE_TORCH << 88 #ifdef USE_INFERENCE_TORCH 92 if (fInferenceLibrary == "TORCH") << 89 if(fInferenceLibrary == "TORCH") 93 fInferenceInterface = 90 fInferenceInterface = 94 std::unique_ptr<Par04InferenceInterface> 91 std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName)); 95 # endif << 92 #endif 96 93 97 CheckInferenceLibrary(); 94 CheckInferenceLibrary(); 98 } 95 } 99 96 100 //....oooOO0OOooo........oooOO0OOooo........oo 97 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 101 98 102 void Par04InferenceSetup::CheckInferenceLibrar 99 void Par04InferenceSetup::CheckInferenceLibrary() 103 { 100 { 104 G4String msg = "Please choose inference libr 101 G4String msg = "Please choose inference library from available libraries ("; 105 # ifdef USE_INFERENCE_ONNX << 102 #ifdef USE_INFERENCE_ONNX 106 msg += "ONNX,"; 103 msg += "ONNX,"; 107 # endif << 104 #endif 108 # ifdef USE_INFERENCE_LWTNN << 105 #ifdef USE_INFERENCE_LWTNN 109 msg += "LWTNN,"; 106 msg += "LWTNN,"; 110 # endif << 107 #endif 111 # ifdef USE_INFERENCE_TORCH << 108 #ifdef USE_INFERENCE_TORCH 112 msg += "TORCH"; 109 msg += "TORCH"; 113 # endif << 110 #endif 114 if (fInferenceInterface == nullptr) << 111 if(fInferenceInterface == nullptr) 115 G4Exception("Par04InferenceSetup::CheckInf 112 G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException, 116 (msg + "). Current name: " + f 113 (msg + "). Current name: " + fInferenceLibrary).c_str()); 117 } 114 } 118 115 119 //....oooOO0OOooo........oooOO0OOooo........oo 116 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 120 117 121 void Par04InferenceSetup::GetEnergies(std::vec 118 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy, 122 G4float 119 G4float aInitialAngle) 123 { 120 { 124 // First check if inference library was set 121 // First check if inference library was set correctly 125 CheckInferenceLibrary(); 122 CheckInferenceLibrary(); 126 // size represents the size of the output ve 123 // size represents the size of the output vector 127 int size = fMeshNumber.x() * fMeshNumber.y() 124 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z(); 128 125 129 // randomly sample from a gaussian distribut 126 // randomly sample from a gaussian distribution in the latent space 130 std::vector<G4float> genVector(fSizeLatentVe 127 std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0); 131 for (int i = 0; i < fSizeLatentVector; ++i) << 128 for(int i = 0; i < fSizeLatentVector; ++i) >> 129 { 132 genVector[i] = CLHEP::RandGauss::shoot(0., 130 genVector[i] = CLHEP::RandGauss::shoot(0., 1.); 133 } 131 } 134 132 135 // Vector of condition 133 // Vector of condition 136 // this is application specific it depdens o 134 // this is application specific it depdens on what the model was condition on 137 // and it depends on how the condition value << 135 // and it depends on how the condition values were encoded at the training time 138 // time in this example the energy of each p << 136 // in this example the energy of each particle is normlaized to the highest 139 // highest energy in the considered range (1 << 137 // energy in the considered range (1GeV-500GeV) 140 // normlaized to the highest angle in the co << 138 // the angle is also is normlaized to the highest angle in the considered range >> 139 // (0-90 in dergrees) 141 // the model in this example was trained on 140 // the model in this example was trained on two detector geometries PBW04 142 // and SiW a one hot encoding vector is use 141 // and SiW a one hot encoding vector is used to represent the geometry with 143 // [0,1] for PBW04 and [1,0] for SiW 142 // [0,1] for PBW04 and [1,0] for SiW 144 // 1. energy << 143 // 1.energy 145 genVector[fSizeLatentVector] = aInitialEnerg 144 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy; 146 // 2. angle 145 // 2. angle 147 genVector[fSizeLatentVector + 1] = (aInitial 146 genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle; 148 // 3. geometry << 147 // 3.geometry 149 genVector[fSizeLatentVector + 2] = 0; 148 genVector[fSizeLatentVector + 2] = 0; 150 genVector[fSizeLatentVector + 3] = 1; 149 genVector[fSizeLatentVector + 3] = 1; 151 150 152 // Run the inference 151 // Run the inference 153 fInferenceInterface->RunInference(genVector, 152 fInferenceInterface->RunInference(genVector, aEnergies, size); 154 153 155 // After the inference rescale back to the i 154 // After the inference rescale back to the initial energy (in this example the 156 // energies of cells were normalized to the 155 // energies of cells were normalized to the energy of the particle) 157 for (int i = 0; i < size; ++i) { << 156 for(int i = 0; i < size; ++i) >> 157 { 158 aEnergies[i] = aEnergies[i] * aInitialEner 158 aEnergies[i] = aEnergies[i] * aInitialEnergy; 159 } 159 } 160 } 160 } 161 161 162 //....oooOO0OOooo........oooOO0OOooo........oo 162 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 163 163 164 void Par04InferenceSetup::GetPositions(std::ve 164 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0, 165 G4Three 165 G4ThreeVector direction) 166 { 166 { 167 aPositions.resize(fMeshNumber.x() * fMeshNum 167 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z()); 168 168 169 // Calculate rotation matrix along the parti 169 // Calculate rotation matrix along the particle momentum direction 170 // It will rotate the shower axes to match t 170 // It will rotate the shower axes to match the incoming particle direction 171 G4RotationMatrix rotMatrix = G4RotationMatri 171 G4RotationMatrix rotMatrix = G4RotationMatrix(); 172 double particleTheta = direction.theta(); << 172 double particleTheta = direction.theta(); 173 double particlePhi = direction.phi(); << 173 double particlePhi = direction.phi(); 174 rotMatrix.rotateZ(-particlePhi); 174 rotMatrix.rotateZ(-particlePhi); 175 rotMatrix.rotateY(-particleTheta); 175 rotMatrix.rotateY(-particleTheta); 176 G4RotationMatrix rotMatrixInv = CLHEP::inver 176 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix); 177 177 178 int cpt = 0; 178 int cpt = 0; 179 for (G4int iCellR = 0; iCellR < fMeshNumber. << 179 for(G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) 180 for (G4int iCellPhi = 0; iCellPhi < fMeshN << 180 { 181 for (G4int iCellZ = 0; iCellZ < fMeshNum << 181 for(G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) >> 182 { >> 183 for(G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) >> 184 { 182 aPositions[cpt] = 185 aPositions[cpt] = 183 pos0 << 186 pos0 + 184 + rotMatrixInv << 187 rotMatrixInv * 185 * G4ThreeVector( << 188 G4ThreeVector((iCellR + 0.5) * fMeshSize.x() * 186 (iCellR + 0.5) * fMeshSize.x() << 189 std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi), 187 * std::cos((iCellPhi + 0.5) << 190 (iCellR + 0.5) * fMeshSize.x() * 188 (iCellR + 0.5) * fMeshSize.x() << 191 std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi), 189 * std::sin((iCellPhi + 0.5) << 192 (iCellZ + 0.5) * fMeshSize.z()); 190 (iCellZ + 0.5) * fMeshSize.z() << 191 cpt++; 193 cpt++; 192 } 194 } 193 } 195 } 194 } 196 } 195 } 197 } 196 198 197 #endif 199 #endif 198 200