Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc (Version 11.3.0) and /examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc (Version 11.0.p2)


  1 //                                                  1 //
  2 // *******************************************      2 // ********************************************************************
  3 // * License and Disclaimer                         3 // * License and Disclaimer                                           *
  4 // *                                                4 // *                                                                  *
  5 // * The  Geant4 software  is  copyright of th      5 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
  6 // * the Geant4 Collaboration.  It is provided      6 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
  7 // * conditions of the Geant4 Software License      7 // * conditions of the Geant4 Software License,  included in the file *
  8 // * LICENSE and available at  http://cern.ch/      8 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
  9 // * include a list of copyright holders.           9 // * include a list of copyright holders.                             *
 10 // *                                               10 // *                                                                  *
 11 // * Neither the authors of this software syst     11 // * Neither the authors of this software system, nor their employing *
 12 // * institutes,nor the agencies providing fin     12 // * institutes,nor the agencies providing financial support for this *
 13 // * work  make  any representation or  warran     13 // * work  make  any representation or  warranty, express or implied, *
 14 // * regarding  this  software system or assum     14 // * regarding  this  software system or assume any liability for its *
 15 // * use.  Please see the license in the file      15 // * use.  Please see the license in the file  LICENSE  and URL above *
 16 // * for the full disclaimer and the limitatio     16 // * for the full disclaimer and the limitation of liability.         *
 17 // *                                               17 // *                                                                  *
 18 // * This  code  implementation is the result      18 // * This  code  implementation is the result of  the  scientific and *
 19 // * technical work of the GEANT4 collaboratio     19 // * technical work of the GEANT4 collaboration.                      *
 20 // * By using,  copying,  modifying or  distri     20 // * By using,  copying,  modifying or  distributing the software (or *
 21 // * any work based  on the software)  you  ag     21 // * any work based  on the software)  you  agree  to acknowledge its *
 22 // * use  in  resulting  scientific  publicati     22 // * use  in  resulting  scientific  publications,  and indicate your *
 23 // * acceptance of all terms of the Geant4 Sof     23 // * acceptance of all terms of the Geant4 Software license.          *
 24 // *******************************************     24 // ********************************************************************
 25 //                                                 25 //
 26 #ifdef USE_INFERENCE                               26 #ifdef USE_INFERENCE
 27 #  include "Par04InferenceSetup.hh"            <<  27 #include "Par04InferenceSetup.hh"
 28                                                    28 
 29 #  include "Par04InferenceInterface.hh"  // fo <<  29 #include "Par04InferenceInterface.hh"
 30 #  include "Par04InferenceMessenger.hh"  // fo <<  30 #ifdef USE_INFERENCE_ONNX
 31 #  ifdef USE_INFERENCE_ONNX                    <<  31 #include "Par04OnnxInference.hh"
 32 #    include "Par04OnnxInference.hh"  // for P <<  32 #endif
 33 #  endif                                       <<  33 #ifdef USE_INFERENCE_LWTNN
 34 #  ifdef USE_INFERENCE_LWTNN                   <<  34 #include "Par04LwtnnInference.hh"
 35 #    include "Par04LwtnnInference.hh"  // for  <<  35 #endif
 36 #  endif                                       <<  36 #include "G4RotationMatrix.hh"
 37 #  ifdef USE_INFERENCE_TORCH                   <<  37 #include "CLHEP/Random/RandGauss.h"
 38 #    include "Par04TorchInference.hh"  // for  << 
 39 #  endif                                       << 
 40 #  include "CLHEP/Random/RandGauss.h"  // for  << 
 41                                                << 
 42 #  include "G4RotationMatrix.hh"  // for G4Rot << 
 43                                                << 
 44 #  include <CLHEP/Units/SystemOfUnits.h>  // f << 
 45 #  include <CLHEP/Vector/Rotation.h>  // for H << 
 46 #  include <CLHEP/Vector/ThreeVector.h>  // fo << 
 47 #  include <G4Exception.hh>  // for G4Exceptio << 
 48 #  include <G4ExceptionSeverity.hh>  // for Fa << 
 49 #  include <G4ThreeVector.hh>  // for G4ThreeV << 
 50 #  include <algorithm>  // for max, copy       << 
 51 #  include <cmath>  // for cos, sin            << 
 52 #  include <string>  // for char_traits, basic << 
 53                                                << 
 54 #  include <ext/alloc_traits.h>  // for __allo << 
 55                                                    38 
 56 //....oooOO0OOooo........oooOO0OOooo........oo     39 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 57                                                    40 
 58 Par04InferenceSetup::Par04InferenceSetup() : f <<  41 Par04InferenceSetup::Par04InferenceSetup()
                                                   >>  42   : fInferenceMessenger(new Par04InferenceMessenger(this))
 59 {}                                                 43 {}
 60                                                    44 
 61 //....oooOO0OOooo........oooOO0OOooo........oo     45 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 62                                                    46 
 63 Par04InferenceSetup::~Par04InferenceSetup() {}     47 Par04InferenceSetup::~Par04InferenceSetup() {}
 64                                                    48 
 65 //....oooOO0OOooo........oooOO0OOooo........oo     49 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 66                                                    50 
 67 G4bool Par04InferenceSetup::IfTrigger(G4double     51 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
 68 {                                                  52 {
 69   /// Energy of electrons used in training dat     53   /// Energy of electrons used in training dataset
 70   if (aEnergy > 1 * CLHEP::GeV || aEnergy < 10 <<  54   if(aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV)
 71   return false;                                <<  55     return true;
 72 }                                                  56 }
 73                                                    57 
 74 //....oooOO0OOooo........oooOO0OOooo........oo     58 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 75                                                    59 
 76 void Par04InferenceSetup::SetInferenceLibrary(     60 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
 77 {                                                  61 {
 78   fInferenceLibrary = aName;                       62   fInferenceLibrary = aName;
 79                                                    63 
 80 #  ifdef USE_INFERENCE_ONNX                    <<  64 #ifdef USE_INFERENCE_ONNX
 81   if (fInferenceLibrary == "ONNX")             <<  65   if(fInferenceLibrary == "ONNX")
 82     fInferenceInterface = std::unique_ptr<Par0 <<  66     fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(
 83       fModelPathName, fProfileFlag, fOptimizat <<  67       new Par04OnnxInference(fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads));
 84       cuda_values, fModelSavePath, fProfilingO <<  68 #endif
 85 #  endif                                       <<  69 #ifdef USE_INFERENCE_LWTNN
 86 #  ifdef USE_INFERENCE_LWTNN                   <<  70   if(fInferenceLibrary == "LWTNN")
 87   if (fInferenceLibrary == "LWTNN")            << 
 88     fInferenceInterface =                          71     fInferenceInterface =
 89       std::unique_ptr<Par04InferenceInterface>     72       std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
 90 #  endif                                       <<  73 #endif
 91 #  ifdef USE_INFERENCE_TORCH                   << 
 92   if (fInferenceLibrary == "TORCH")            << 
 93     fInferenceInterface =                      << 
 94       std::unique_ptr<Par04InferenceInterface> << 
 95 #  endif                                       << 
 96                                                << 
 97   CheckInferenceLibrary();                         74   CheckInferenceLibrary();
 98 }                                                  75 }
 99                                                    76 
100 //....oooOO0OOooo........oooOO0OOooo........oo     77 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
101                                                    78 
102 void Par04InferenceSetup::CheckInferenceLibrar     79 void Par04InferenceSetup::CheckInferenceLibrary()
103 {                                                  80 {
104   G4String msg = "Please choose inference libr     81   G4String msg = "Please choose inference library from available libraries (";
105 #  ifdef USE_INFERENCE_ONNX                    <<  82 #ifdef USE_INFERENCE_ONNX
106   msg += "ONNX,";                                  83   msg += "ONNX,";
107 #  endif                                       <<  84 #endif
108 #  ifdef USE_INFERENCE_LWTNN                   <<  85 #ifdef USE_INFERENCE_LWTNN
109   msg += "LWTNN,";                             <<  86   msg += "LWTNN";
110 #  endif                                       <<  87 #endif
111 #  ifdef USE_INFERENCE_TORCH                   <<  88   if(fInferenceInterface == nullptr)
112   msg += "TORCH";                              << 
113 #  endif                                       << 
114   if (fInferenceInterface == nullptr)          << 
115     G4Exception("Par04InferenceSetup::CheckInf     89     G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
116                 (msg + "). Current name: " + f     90                 (msg + "). Current name: " + fInferenceLibrary).c_str());
117 }                                                  91 }
118                                                    92 
119 //....oooOO0OOooo........oooOO0OOooo........oo     93 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
120                                                    94 
121 void Par04InferenceSetup::GetEnergies(std::vec     95 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
122                                       G4float      96                                       G4float aInitialAngle)
123 {                                                  97 {
124   // First check if inference library was set      98   // First check if inference library was set correctly
125   CheckInferenceLibrary();                         99   CheckInferenceLibrary();
126   // size represents the size of the output ve    100   // size represents the size of the output vector
127   int size = fMeshNumber.x() * fMeshNumber.y()    101   int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
128                                                   102 
129   // randomly sample from a gaussian distribut    103   // randomly sample from a gaussian distribution in the latent space
130   std::vector<G4float> genVector(fSizeLatentVe    104   std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
131   for (int i = 0; i < fSizeLatentVector; ++i)  << 105   for(int i = 0; i < fSizeLatentVector; ++i)
                                                   >> 106   {
132     genVector[i] = CLHEP::RandGauss::shoot(0.,    107     genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
133   }                                               108   }
134                                                   109 
135   // Vector of condition                          110   // Vector of condition
136   // this is application specific it depdens o    111   // this is application specific it depdens on what the model was condition on
137   // and it depends on how the condition value << 112   // and it depends on how the condition values were encoded at the training time
138   // time in this example the energy of each p << 113   // in this example the energy of each particle is normlaized to the highest
139   // highest energy in the considered range (1 << 114   // energy in the considered range (1GeV-500GeV)
140   // normlaized to the highest angle in the co << 115   // the angle is also is normlaized to the highest angle in the considered range
                                                   >> 116   // (0-90 in dergrees)
141   // the model in this example was trained on     117   // the model in this example was trained on two detector geometries PBW04
142   // and SiW  a one hot encoding vector is use    118   // and SiW  a one hot encoding vector is used to represent the geometry with
143   // [0,1] for PBW04 and [1,0] for SiW            119   // [0,1] for PBW04 and [1,0] for SiW
144   // 1. energy                                 << 120   // 1.energy
145   genVector[fSizeLatentVector] = aInitialEnerg    121   genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
146   // 2. angle                                     122   // 2. angle
147   genVector[fSizeLatentVector + 1] = (aInitial    123   genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
148   // 3. geometry                               << 124   // 3.geometry
149   genVector[fSizeLatentVector + 2] = 0;           125   genVector[fSizeLatentVector + 2] = 0;
150   genVector[fSizeLatentVector + 3] = 1;           126   genVector[fSizeLatentVector + 3] = 1;
151                                                   127 
152   // Run the inference                            128   // Run the inference
153   fInferenceInterface->RunInference(genVector,    129   fInferenceInterface->RunInference(genVector, aEnergies, size);
154                                                   130 
155   // After the inference rescale back to the i    131   // After the inference rescale back to the initial energy (in this example the
156   // energies of cells were normalized to the     132   // energies of cells were normalized to the energy of the particle)
157   for (int i = 0; i < size; ++i) {             << 133   for(int i = 0; i < size; ++i)
                                                   >> 134   {
158     aEnergies[i] = aEnergies[i] * aInitialEner    135     aEnergies[i] = aEnergies[i] * aInitialEnergy;
159   }                                               136   }
160 }                                                 137 }
161                                                   138 
162 //....oooOO0OOooo........oooOO0OOooo........oo    139 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
163                                                   140 
164 void Par04InferenceSetup::GetPositions(std::ve    141 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
165                                        G4Three    142                                        G4ThreeVector direction)
166 {                                                 143 {
167   aPositions.resize(fMeshNumber.x() * fMeshNum    144   aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
168                                                   145 
169   // Calculate rotation matrix along the parti    146   // Calculate rotation matrix along the particle momentum direction
170   // It will rotate the shower axes to match t    147   // It will rotate the shower axes to match the incoming particle direction
171   G4RotationMatrix rotMatrix = G4RotationMatri    148   G4RotationMatrix rotMatrix = G4RotationMatrix();
172   double particleTheta = direction.theta();    << 149   double particleTheta       = direction.theta();
173   double particlePhi = direction.phi();        << 150   double particlePhi         = direction.phi();
174   rotMatrix.rotateZ(-particlePhi);                151   rotMatrix.rotateZ(-particlePhi);
175   rotMatrix.rotateY(-particleTheta);              152   rotMatrix.rotateY(-particleTheta);
176   G4RotationMatrix rotMatrixInv = CLHEP::inver    153   G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
177                                                   154 
178   int cpt = 0;                                    155   int cpt = 0;
179   for (G4int iCellR = 0; iCellR < fMeshNumber. << 156   for(G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++)
180     for (G4int iCellPhi = 0; iCellPhi < fMeshN << 157   {
181       for (G4int iCellZ = 0; iCellZ < fMeshNum << 158     for(G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++)
                                                   >> 159     {
                                                   >> 160       for(G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++)
                                                   >> 161       {
182         aPositions[cpt] =                         162         aPositions[cpt] =
183           pos0                                 << 163           pos0 +
184           + rotMatrixInv                       << 164           rotMatrixInv *
185               * G4ThreeVector(                 << 165             G4ThreeVector((iCellR + 0.5) * fMeshSize.x() *
186                 (iCellR + 0.5) * fMeshSize.x() << 166                             std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
187                   * std::cos((iCellPhi + 0.5)  << 167                           (iCellR + 0.5) * fMeshSize.x() *
188                 (iCellR + 0.5) * fMeshSize.x() << 168                             std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
189                   * std::sin((iCellPhi + 0.5)  << 169                           (iCellZ + 0.5) * fMeshSize.z());
190                 (iCellZ + 0.5) * fMeshSize.z() << 
191         cpt++;                                    170         cpt++;
192       }                                           171       }
193     }                                             172     }
194   }                                               173   }
195 }                                                 174 }
196                                                   175 
197 #endif                                            176 #endif
198                                                   177