Geant4 Cross Reference |
1 // 1 2 // ******************************************* 3 // * License and Disclaimer 4 // * 5 // * The Geant4 software is copyright of th 6 // * the Geant4 Collaboration. It is provided 7 // * conditions of the Geant4 Software License 8 // * LICENSE and available at http://cern.ch/ 9 // * include a list of copyright holders. 10 // * 11 // * Neither the authors of this software syst 12 // * institutes,nor the agencies providing fin 13 // * work make any representation or warran 14 // * regarding this software system or assum 15 // * use. Please see the license in the file 16 // * for the full disclaimer and the limitatio 17 // * 18 // * This code implementation is the result 19 // * technical work of the GEANT4 collaboratio 20 // * By using, copying, modifying or distri 21 // * any work based on the software) you ag 22 // * use in resulting scientific publicati 23 // * acceptance of all terms of the Geant4 Sof 24 // ******************************************* 25 // 26 #ifdef USE_INFERENCE 27 # include "Par04InferenceSetup.hh" 28 29 # include "Par04InferenceInterface.hh" // fo 30 # include "Par04InferenceMessenger.hh" // fo 31 # ifdef USE_INFERENCE_ONNX 32 # include "Par04OnnxInference.hh" // for P 33 # endif 34 # ifdef USE_INFERENCE_LWTNN 35 # include "Par04LwtnnInference.hh" // for 36 # endif 37 # ifdef USE_INFERENCE_TORCH 38 # include "Par04TorchInference.hh" // for 39 # endif 40 # include "CLHEP/Random/RandGauss.h" // for 41 42 # include "G4RotationMatrix.hh" // for G4Rot 43 44 # include <CLHEP/Units/SystemOfUnits.h> // f 45 # include <CLHEP/Vector/Rotation.h> // for H 46 # include <CLHEP/Vector/ThreeVector.h> // fo 47 # include <G4Exception.hh> // for G4Exceptio 48 # include <G4ExceptionSeverity.hh> // for Fa 49 # include <G4ThreeVector.hh> // for G4ThreeV 50 # include <algorithm> // for max, copy 51 # include <cmath> // for cos, sin 52 # include <string> // for char_traits, basic 53 54 # include <ext/alloc_traits.h> // for __allo 55 56 //....oooOO0OOooo........oooOO0OOooo........oo 57 58 Par04InferenceSetup::Par04InferenceSetup() : f 59 {} 60 61 //....oooOO0OOooo........oooOO0OOooo........oo 62 63 Par04InferenceSetup::~Par04InferenceSetup() {} 64 65 //....oooOO0OOooo........oooOO0OOooo........oo 66 67 G4bool Par04InferenceSetup::IfTrigger(G4double 68 { 69 /// Energy of electrons used in training dat 70 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 10 71 return false; 72 } 73 74 //....oooOO0OOooo........oooOO0OOooo........oo 75 76 void Par04InferenceSetup::SetInferenceLibrary( 77 { 78 fInferenceLibrary = aName; 79 80 # ifdef USE_INFERENCE_ONNX 81 if (fInferenceLibrary == "ONNX") 82 fInferenceInterface = std::unique_ptr<Par0 83 fModelPathName, fProfileFlag, fOptimizat 84 cuda_values, fModelSavePath, fProfilingO 85 # endif 86 # ifdef USE_INFERENCE_LWTNN 87 if (fInferenceLibrary == "LWTNN") 88 fInferenceInterface = 89 std::unique_ptr<Par04InferenceInterface> 90 # endif 91 # ifdef USE_INFERENCE_TORCH 92 if (fInferenceLibrary == "TORCH") 93 fInferenceInterface = 94 std::unique_ptr<Par04InferenceInterface> 95 # endif 96 97 CheckInferenceLibrary(); 98 } 99 100 //....oooOO0OOooo........oooOO0OOooo........oo 101 102 void Par04InferenceSetup::CheckInferenceLibrar 103 { 104 G4String msg = "Please choose inference libr 105 # ifdef USE_INFERENCE_ONNX 106 msg += "ONNX,"; 107 # endif 108 # ifdef USE_INFERENCE_LWTNN 109 msg += "LWTNN,"; 110 # endif 111 # ifdef USE_INFERENCE_TORCH 112 msg += "TORCH"; 113 # endif 114 if (fInferenceInterface == nullptr) 115 G4Exception("Par04InferenceSetup::CheckInf 116 (msg + "). Current name: " + f 117 } 118 119 //....oooOO0OOooo........oooOO0OOooo........oo 120 121 void Par04InferenceSetup::GetEnergies(std::vec 122 G4float 123 { 124 // First check if inference library was set 125 CheckInferenceLibrary(); 126 // size represents the size of the output ve 127 int size = fMeshNumber.x() * fMeshNumber.y() 128 129 // randomly sample from a gaussian distribut 130 std::vector<G4float> genVector(fSizeLatentVe 131 for (int i = 0; i < fSizeLatentVector; ++i) 132 genVector[i] = CLHEP::RandGauss::shoot(0., 133 } 134 135 // Vector of condition 136 // this is application specific it depdens o 137 // and it depends on how the condition value 138 // time in this example the energy of each p 139 // highest energy in the considered range (1 140 // normlaized to the highest angle in the co 141 // the model in this example was trained on 142 // and SiW a one hot encoding vector is use 143 // [0,1] for PBW04 and [1,0] for SiW 144 // 1. energy 145 genVector[fSizeLatentVector] = aInitialEnerg 146 // 2. angle 147 genVector[fSizeLatentVector + 1] = (aInitial 148 // 3. geometry 149 genVector[fSizeLatentVector + 2] = 0; 150 genVector[fSizeLatentVector + 3] = 1; 151 152 // Run the inference 153 fInferenceInterface->RunInference(genVector, 154 155 // After the inference rescale back to the i 156 // energies of cells were normalized to the 157 for (int i = 0; i < size; ++i) { 158 aEnergies[i] = aEnergies[i] * aInitialEner 159 } 160 } 161 162 //....oooOO0OOooo........oooOO0OOooo........oo 163 164 void Par04InferenceSetup::GetPositions(std::ve 165 G4Three 166 { 167 aPositions.resize(fMeshNumber.x() * fMeshNum 168 169 // Calculate rotation matrix along the parti 170 // It will rotate the shower axes to match t 171 G4RotationMatrix rotMatrix = G4RotationMatri 172 double particleTheta = direction.theta(); 173 double particlePhi = direction.phi(); 174 rotMatrix.rotateZ(-particlePhi); 175 rotMatrix.rotateY(-particleTheta); 176 G4RotationMatrix rotMatrixInv = CLHEP::inver 177 178 int cpt = 0; 179 for (G4int iCellR = 0; iCellR < fMeshNumber. 180 for (G4int iCellPhi = 0; iCellPhi < fMeshN 181 for (G4int iCellZ = 0; iCellZ < fMeshNum 182 aPositions[cpt] = 183 pos0 184 + rotMatrixInv 185 * G4ThreeVector( 186 (iCellR + 0.5) * fMeshSize.x() 187 * std::cos((iCellPhi + 0.5) 188 (iCellR + 0.5) * fMeshSize.x() 189 * std::sin((iCellPhi + 0.5) 190 (iCellZ + 0.5) * fMeshSize.z() 191 cpt++; 192 } 193 } 194 } 195 } 196 197 #endif 198