Geant4 Cross Reference |
1 // 1 // 2 // ******************************************* 2 // ******************************************************************** 3 // * License and Disclaimer 3 // * License and Disclaimer * 4 // * 4 // * * 5 // * The Geant4 software is copyright of th 5 // * The Geant4 software is copyright of the Copyright Holders of * 6 // * the Geant4 Collaboration. It is provided 6 // * the Geant4 Collaboration. It is provided under the terms and * 7 // * conditions of the Geant4 Software License 7 // * conditions of the Geant4 Software License, included in the file * 8 // * LICENSE and available at http://cern.ch/ 8 // * LICENSE and available at http://cern.ch/geant4/license . These * 9 // * include a list of copyright holders. 9 // * include a list of copyright holders. * 10 // * 10 // * * 11 // * Neither the authors of this software syst 11 // * Neither the authors of this software system, nor their employing * 12 // * institutes,nor the agencies providing fin 12 // * institutes,nor the agencies providing financial support for this * 13 // * work make any representation or warran 13 // * work make any representation or warranty, express or implied, * 14 // * regarding this software system or assum 14 // * regarding this software system or assume any liability for its * 15 // * use. Please see the license in the file 15 // * use. Please see the license in the file LICENSE and URL above * 16 // * for the full disclaimer and the limitatio 16 // * for the full disclaimer and the limitation of liability. * 17 // * 17 // * * 18 // * This code implementation is the result 18 // * This code implementation is the result of the scientific and * 19 // * technical work of the GEANT4 collaboratio 19 // * technical work of the GEANT4 collaboration. * 20 // * By using, copying, modifying or distri 20 // * By using, copying, modifying or distributing the software (or * 21 // * any work based on the software) you ag 21 // * any work based on the software) you agree to acknowledge its * 22 // * use in resulting scientific publicati 22 // * use in resulting scientific publications, and indicate your * 23 // * acceptance of all terms of the Geant4 Sof 23 // * acceptance of all terms of the Geant4 Software license. * 24 // ******************************************* 24 // ******************************************************************** 25 // 25 // 26 #ifdef USE_INFERENCE 26 #ifdef USE_INFERENCE 27 # include "Par04InferenceSetup.hh" << 27 #include "Par04InferenceSetup.hh" 28 28 29 # include "Par04InferenceInterface.hh" // fo << 29 #include "Par04InferenceInterface.hh" 30 # include "Par04InferenceMessenger.hh" // fo << 30 #ifdef USE_INFERENCE_ONNX 31 # ifdef USE_INFERENCE_ONNX << 31 #include "Par04OnnxInference.hh" 32 # include "Par04OnnxInference.hh" // for P << 32 #endif 33 # endif << 33 #ifdef USE_INFERENCE_LWTNN 34 # ifdef USE_INFERENCE_LWTNN << 34 #include "Par04LwtnnInference.hh" 35 # include "Par04LwtnnInference.hh" // for << 35 #endif 36 # endif << 36 #include "G4RotationMatrix.hh" 37 # ifdef USE_INFERENCE_TORCH << 37 #include "CLHEP/Random/RandGauss.h" 38 # include "Par04TorchInference.hh" // for << 39 # endif << 40 # include "CLHEP/Random/RandGauss.h" // for << 41 << 42 # include "G4RotationMatrix.hh" // for G4Rot << 43 << 44 # include <CLHEP/Units/SystemOfUnits.h> // f << 45 # include <CLHEP/Vector/Rotation.h> // for H << 46 # include <CLHEP/Vector/ThreeVector.h> // fo << 47 # include <G4Exception.hh> // for G4Exceptio << 48 # include <G4ExceptionSeverity.hh> // for Fa << 49 # include <G4ThreeVector.hh> // for G4ThreeV << 50 # include <algorithm> // for max, copy << 51 # include <cmath> // for cos, sin << 52 # include <string> // for char_traits, basic << 53 << 54 # include <ext/alloc_traits.h> // for __allo << 55 38 56 //....oooOO0OOooo........oooOO0OOooo........oo 39 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 57 40 58 Par04InferenceSetup::Par04InferenceSetup() : f << 41 Par04InferenceSetup::Par04InferenceSetup() >> 42 : fInferenceMessenger(new Par04InferenceMessenger(this)) 59 {} 43 {} 60 44 61 //....oooOO0OOooo........oooOO0OOooo........oo 45 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 62 46 63 Par04InferenceSetup::~Par04InferenceSetup() {} 47 Par04InferenceSetup::~Par04InferenceSetup() {} 64 48 65 //....oooOO0OOooo........oooOO0OOooo........oo 49 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 66 50 67 G4bool Par04InferenceSetup::IfTrigger(G4double 51 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy) 68 { 52 { 69 /// Energy of electrons used in training dat 53 /// Energy of electrons used in training dataset 70 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 10 << 54 if(aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV) 71 return false; << 55 return true; 72 } 56 } 73 57 74 //....oooOO0OOooo........oooOO0OOooo........oo 58 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 75 59 76 void Par04InferenceSetup::SetInferenceLibrary( 60 void Par04InferenceSetup::SetInferenceLibrary(G4String aName) 77 { 61 { 78 fInferenceLibrary = aName; 62 fInferenceLibrary = aName; 79 63 80 # ifdef USE_INFERENCE_ONNX << 64 #ifdef USE_INFERENCE_ONNX 81 if (fInferenceLibrary == "ONNX") << 65 if(fInferenceLibrary == "ONNX") 82 fInferenceInterface = std::unique_ptr<Par0 << 66 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>( 83 fModelPathName, fProfileFlag, fOptimizat << 67 new Par04OnnxInference(fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads)); 84 cuda_values, fModelSavePath, fProfilingO << 68 #endif 85 # endif << 69 #ifdef USE_INFERENCE_LWTNN 86 # ifdef USE_INFERENCE_LWTNN << 70 if(fInferenceLibrary == "LWTNN") 87 if (fInferenceLibrary == "LWTNN") << 88 fInferenceInterface = 71 fInferenceInterface = 89 std::unique_ptr<Par04InferenceInterface> 72 std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName)); 90 # endif << 73 #endif 91 # ifdef USE_INFERENCE_TORCH << 92 if (fInferenceLibrary == "TORCH") << 93 fInferenceInterface = << 94 std::unique_ptr<Par04InferenceInterface> << 95 # endif << 96 << 97 CheckInferenceLibrary(); 74 CheckInferenceLibrary(); 98 } 75 } 99 76 100 //....oooOO0OOooo........oooOO0OOooo........oo 77 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 101 78 102 void Par04InferenceSetup::CheckInferenceLibrar 79 void Par04InferenceSetup::CheckInferenceLibrary() 103 { 80 { 104 G4String msg = "Please choose inference libr 81 G4String msg = "Please choose inference library from available libraries ("; 105 # ifdef USE_INFERENCE_ONNX << 82 #ifdef USE_INFERENCE_ONNX 106 msg += "ONNX,"; 83 msg += "ONNX,"; 107 # endif << 84 #endif 108 # ifdef USE_INFERENCE_LWTNN << 85 #ifdef USE_INFERENCE_LWTNN 109 msg += "LWTNN,"; << 86 msg += "LWTNN"; 110 # endif << 87 #endif 111 # ifdef USE_INFERENCE_TORCH << 88 if(fInferenceInterface == nullptr) 112 msg += "TORCH"; << 113 # endif << 114 if (fInferenceInterface == nullptr) << 115 G4Exception("Par04InferenceSetup::CheckInf 89 G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException, 116 (msg + "). Current name: " + f 90 (msg + "). Current name: " + fInferenceLibrary).c_str()); 117 } 91 } 118 92 119 //....oooOO0OOooo........oooOO0OOooo........oo 93 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 120 94 121 void Par04InferenceSetup::GetEnergies(std::vec 95 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy, 122 G4float 96 G4float aInitialAngle) 123 { 97 { 124 // First check if inference library was set 98 // First check if inference library was set correctly 125 CheckInferenceLibrary(); 99 CheckInferenceLibrary(); 126 // size represents the size of the output ve 100 // size represents the size of the output vector 127 int size = fMeshNumber.x() * fMeshNumber.y() 101 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z(); 128 102 129 // randomly sample from a gaussian distribut 103 // randomly sample from a gaussian distribution in the latent space 130 std::vector<G4float> genVector(fSizeLatentVe 104 std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0); 131 for (int i = 0; i < fSizeLatentVector; ++i) << 105 for(int i = 0; i < fSizeLatentVector; ++i) >> 106 { 132 genVector[i] = CLHEP::RandGauss::shoot(0., 107 genVector[i] = CLHEP::RandGauss::shoot(0., 1.); 133 } 108 } 134 109 135 // Vector of condition 110 // Vector of condition 136 // this is application specific it depdens o 111 // this is application specific it depdens on what the model was condition on 137 // and it depends on how the condition value << 112 // and it depends on how the condition values were encoded at the training time 138 // time in this example the energy of each p << 113 // in this example the energy of each particle is normlaized to the highest 139 // highest energy in the considered range (1 << 114 // energy in the considered range (1GeV-500GeV) 140 // normlaized to the highest angle in the co << 115 // the angle is also is normlaized to the highest angle in the considered range >> 116 // (0-90 in dergrees) 141 // the model in this example was trained on 117 // the model in this example was trained on two detector geometries PBW04 142 // and SiW a one hot encoding vector is use 118 // and SiW a one hot encoding vector is used to represent the geometry with 143 // [0,1] for PBW04 and [1,0] for SiW 119 // [0,1] for PBW04 and [1,0] for SiW 144 // 1. energy << 120 // 1.energy 145 genVector[fSizeLatentVector] = aInitialEnerg 121 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy; 146 // 2. angle 122 // 2. angle 147 genVector[fSizeLatentVector + 1] = (aInitial 123 genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle; 148 // 3. geometry << 124 // 3.geometry 149 genVector[fSizeLatentVector + 2] = 0; 125 genVector[fSizeLatentVector + 2] = 0; 150 genVector[fSizeLatentVector + 3] = 1; 126 genVector[fSizeLatentVector + 3] = 1; 151 127 152 // Run the inference 128 // Run the inference 153 fInferenceInterface->RunInference(genVector, 129 fInferenceInterface->RunInference(genVector, aEnergies, size); 154 130 155 // After the inference rescale back to the i 131 // After the inference rescale back to the initial energy (in this example the 156 // energies of cells were normalized to the 132 // energies of cells were normalized to the energy of the particle) 157 for (int i = 0; i < size; ++i) { << 133 for(int i = 0; i < size; ++i) >> 134 { 158 aEnergies[i] = aEnergies[i] * aInitialEner 135 aEnergies[i] = aEnergies[i] * aInitialEnergy; 159 } 136 } 160 } 137 } 161 138 162 //....oooOO0OOooo........oooOO0OOooo........oo 139 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo...... 163 140 164 void Par04InferenceSetup::GetPositions(std::ve 141 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0, 165 G4Three 142 G4ThreeVector direction) 166 { 143 { 167 aPositions.resize(fMeshNumber.x() * fMeshNum 144 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z()); 168 145 169 // Calculate rotation matrix along the parti 146 // Calculate rotation matrix along the particle momentum direction 170 // It will rotate the shower axes to match t 147 // It will rotate the shower axes to match the incoming particle direction 171 G4RotationMatrix rotMatrix = G4RotationMatri 148 G4RotationMatrix rotMatrix = G4RotationMatrix(); 172 double particleTheta = direction.theta(); << 149 double particleTheta = direction.theta(); 173 double particlePhi = direction.phi(); << 150 double particlePhi = direction.phi(); 174 rotMatrix.rotateZ(-particlePhi); 151 rotMatrix.rotateZ(-particlePhi); 175 rotMatrix.rotateY(-particleTheta); 152 rotMatrix.rotateY(-particleTheta); 176 G4RotationMatrix rotMatrixInv = CLHEP::inver 153 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix); 177 154 178 int cpt = 0; 155 int cpt = 0; 179 for (G4int iCellR = 0; iCellR < fMeshNumber. << 156 for(G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++) 180 for (G4int iCellPhi = 0; iCellPhi < fMeshN << 157 { 181 for (G4int iCellZ = 0; iCellZ < fMeshNum << 158 for(G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++) >> 159 { >> 160 for(G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++) >> 161 { 182 aPositions[cpt] = 162 aPositions[cpt] = 183 pos0 << 163 pos0 + 184 + rotMatrixInv << 164 rotMatrixInv * 185 * G4ThreeVector( << 165 G4ThreeVector((iCellR + 0.5) * fMeshSize.x() * 186 (iCellR + 0.5) * fMeshSize.x() << 166 std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi), 187 * std::cos((iCellPhi + 0.5) << 167 (iCellR + 0.5) * fMeshSize.x() * 188 (iCellR + 0.5) * fMeshSize.x() << 168 std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi), 189 * std::sin((iCellPhi + 0.5) << 169 (iCellZ + 0.5) * fMeshSize.z()); 190 (iCellZ + 0.5) * fMeshSize.z() << 191 cpt++; 170 cpt++; 192 } 171 } 193 } 172 } 194 } 173 } 195 } 174 } 196 175 197 #endif 176 #endif 198 177