Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc (Version 11.3.0) and /examples/extended/parameterisations/Par04/src/Par04InferenceSetup.cc (Version 11.1.1)


  1 //                                                  1 //
  2 // *******************************************      2 // ********************************************************************
  3 // * License and Disclaimer                         3 // * License and Disclaimer                                           *
  4 // *                                                4 // *                                                                  *
  5 // * The  Geant4 software  is  copyright of th      5 // * The  Geant4 software  is  copyright of the Copyright Holders  of *
  6 // * the Geant4 Collaboration.  It is provided      6 // * the Geant4 Collaboration.  It is provided  under  the terms  and *
  7 // * conditions of the Geant4 Software License      7 // * conditions of the Geant4 Software License,  included in the file *
  8 // * LICENSE and available at  http://cern.ch/      8 // * LICENSE and available at  http://cern.ch/geant4/license .  These *
  9 // * include a list of copyright holders.           9 // * include a list of copyright holders.                             *
 10 // *                                               10 // *                                                                  *
 11 // * Neither the authors of this software syst     11 // * Neither the authors of this software system, nor their employing *
 12 // * institutes,nor the agencies providing fin     12 // * institutes,nor the agencies providing financial support for this *
 13 // * work  make  any representation or  warran     13 // * work  make  any representation or  warranty, express or implied, *
 14 // * regarding  this  software system or assum     14 // * regarding  this  software system or assume any liability for its *
 15 // * use.  Please see the license in the file      15 // * use.  Please see the license in the file  LICENSE  and URL above *
 16 // * for the full disclaimer and the limitatio     16 // * for the full disclaimer and the limitation of liability.         *
 17 // *                                               17 // *                                                                  *
 18 // * This  code  implementation is the result      18 // * This  code  implementation is the result of  the  scientific and *
 19 // * technical work of the GEANT4 collaboratio     19 // * technical work of the GEANT4 collaboration.                      *
 20 // * By using,  copying,  modifying or  distri     20 // * By using,  copying,  modifying or  distributing the software (or *
 21 // * any work based  on the software)  you  ag     21 // * any work based  on the software)  you  agree  to acknowledge its *
 22 // * use  in  resulting  scientific  publicati     22 // * use  in  resulting  scientific  publications,  and indicate your *
 23 // * acceptance of all terms of the Geant4 Sof     23 // * acceptance of all terms of the Geant4 Software license.          *
 24 // *******************************************     24 // ********************************************************************
 25 //                                                 25 //
 26 #ifdef USE_INFERENCE                               26 #ifdef USE_INFERENCE
 27 #  include "Par04InferenceSetup.hh"            <<  27 #include "Par04InferenceSetup.hh"
 28                                                <<  28 #include "Par04InferenceInterface.hh"   // for Par04InferenceInterface
 29 #  include "Par04InferenceInterface.hh"  // fo <<  29 #include "Par04InferenceMessenger.hh"   // for Par04InferenceMessenger
 30 #  include "Par04InferenceMessenger.hh"  // fo <<  30 #ifdef USE_INFERENCE_ONNX
 31 #  ifdef USE_INFERENCE_ONNX                    <<  31 #include "Par04OnnxInference.hh"        // for Par04OnnxInference
 32 #    include "Par04OnnxInference.hh"  // for P <<  32 #endif
 33 #  endif                                       <<  33 #ifdef USE_INFERENCE_LWTNN
 34 #  ifdef USE_INFERENCE_LWTNN                   <<  34 #include "Par04LwtnnInference.hh"       // for Par04LwtnnInference
 35 #    include "Par04LwtnnInference.hh"  // for  <<  35 #endif
 36 #  endif                                       <<  36 #ifdef USE_INFERENCE_TORCH
 37 #  ifdef USE_INFERENCE_TORCH                   <<  37 #include "Par04TorchInference.hh"    // for Par04TorchInference
 38 #    include "Par04TorchInference.hh"  // for  <<  38 #endif
 39 #  endif                                       <<  39 #include <CLHEP/Units/SystemOfUnits.h>  // for pi, GeV, deg
 40 #  include "CLHEP/Random/RandGauss.h"  // for  <<  40 #include <CLHEP/Vector/Rotation.h>      // for HepRotation
 41                                                <<  41 #include <CLHEP/Vector/ThreeVector.h>   // for Hep3Vector
 42 #  include "G4RotationMatrix.hh"  // for G4Rot <<  42 #include <G4Exception.hh>               // for G4Exception
 43                                                <<  43 #include <G4ExceptionSeverity.hh>       // for FatalException
 44 #  include <CLHEP/Units/SystemOfUnits.h>  // f <<  44 #include <G4ThreeVector.hh>             // for G4ThreeVector
 45 #  include <CLHEP/Vector/Rotation.h>  // for H <<  45 #include "CLHEP/Random/RandGauss.h"     // for RandGauss
 46 #  include <CLHEP/Vector/ThreeVector.h>  // fo <<  46 #include "G4RotationMatrix.hh"          // for G4RotationMatrix
 47 #  include <G4Exception.hh>  // for G4Exceptio <<  47 #include <algorithm>                    // for max, copy
 48 #  include <G4ExceptionSeverity.hh>  // for Fa <<  48 #include <string>                       // for char_traits, basic_string
 49 #  include <G4ThreeVector.hh>  // for G4ThreeV <<  49 #include <ext/alloc_traits.h>           // for __alloc_traits<>::value_type
 50 #  include <algorithm>  // for max, copy       <<  50 #include <cmath>                        // for cos, sin
 51 #  include <cmath>  // for cos, sin            << 
 52 #  include <string>  // for char_traits, basic << 
 53                                                << 
 54 #  include <ext/alloc_traits.h>  // for __allo << 
 55                                                    51 
 56 //....oooOO0OOooo........oooOO0OOooo........oo     52 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 57                                                    53 
 58 Par04InferenceSetup::Par04InferenceSetup() : f <<  54 Par04InferenceSetup::Par04InferenceSetup()
                                                   >>  55   : fInferenceMessenger(new Par04InferenceMessenger(this))
 59 {}                                                 56 {}
 60                                                    57 
 61 //....oooOO0OOooo........oooOO0OOooo........oo     58 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 62                                                    59 
 63 Par04InferenceSetup::~Par04InferenceSetup() {}     60 Par04InferenceSetup::~Par04InferenceSetup() {}
 64                                                    61 
 65 //....oooOO0OOooo........oooOO0OOooo........oo     62 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 66                                                    63 
 67 G4bool Par04InferenceSetup::IfTrigger(G4double     64 G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
 68 {                                                  65 {
 69   /// Energy of electrons used in training dat     66   /// Energy of electrons used in training dataset
 70   if (aEnergy > 1 * CLHEP::GeV || aEnergy < 10 <<  67   if(aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV)
                                                   >>  68     return true;
 71   return false;                                    69   return false;
 72 }                                                  70 }
 73                                                    71 
 74 //....oooOO0OOooo........oooOO0OOooo........oo     72 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
 75                                                    73 
 76 void Par04InferenceSetup::SetInferenceLibrary(     74 void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
 77 {                                                  75 {
 78   fInferenceLibrary = aName;                       76   fInferenceLibrary = aName;
 79                                                    77 
 80 #  ifdef USE_INFERENCE_ONNX                    <<  78 #ifdef USE_INFERENCE_ONNX
 81   if (fInferenceLibrary == "ONNX")             <<  79   if(fInferenceLibrary == "ONNX")
 82     fInferenceInterface = std::unique_ptr<Par0 <<  80     fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(
 83       fModelPathName, fProfileFlag, fOptimizat <<  81       new Par04OnnxInference(fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads));
 84       cuda_values, fModelSavePath, fProfilingO <<  82 #endif
 85 #  endif                                       <<  83 #ifdef USE_INFERENCE_LWTNN
 86 #  ifdef USE_INFERENCE_LWTNN                   <<  84   if(fInferenceLibrary == "LWTNN")
 87   if (fInferenceLibrary == "LWTNN")            << 
 88     fInferenceInterface =                          85     fInferenceInterface =
 89       std::unique_ptr<Par04InferenceInterface>     86       std::unique_ptr<Par04InferenceInterface>(new Par04LwtnnInference(fModelPathName));
 90 #  endif                                       <<  87 #endif
 91 #  ifdef USE_INFERENCE_TORCH                   <<  88 #ifdef USE_INFERENCE_TORCH
 92   if (fInferenceLibrary == "TORCH")            <<  89   if(fInferenceLibrary == "TORCH")
 93     fInferenceInterface =                          90     fInferenceInterface =
 94       std::unique_ptr<Par04InferenceInterface>     91       std::unique_ptr<Par04InferenceInterface>(new Par04TorchInference(fModelPathName));
 95 #  endif                                       <<  92 #endif
 96                                                    93 
 97   CheckInferenceLibrary();                         94   CheckInferenceLibrary();
 98 }                                                  95 }
 99                                                    96 
100 //....oooOO0OOooo........oooOO0OOooo........oo     97 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
101                                                    98 
102 void Par04InferenceSetup::CheckInferenceLibrar     99 void Par04InferenceSetup::CheckInferenceLibrary()
103 {                                                 100 {
104   G4String msg = "Please choose inference libr    101   G4String msg = "Please choose inference library from available libraries (";
105 #  ifdef USE_INFERENCE_ONNX                    << 102 #ifdef USE_INFERENCE_ONNX
106   msg += "ONNX,";                                 103   msg += "ONNX,";
107 #  endif                                       << 104 #endif
108 #  ifdef USE_INFERENCE_LWTNN                   << 105 #ifdef USE_INFERENCE_LWTNN
109   msg += "LWTNN,";                                106   msg += "LWTNN,";
110 #  endif                                       << 107 #endif
111 #  ifdef USE_INFERENCE_TORCH                   << 108 #ifdef USE_INFERENCE_TORCH
112   msg += "TORCH";                                 109   msg += "TORCH";
113 #  endif                                       << 110 #endif
114   if (fInferenceInterface == nullptr)          << 111   if(fInferenceInterface == nullptr)
115     G4Exception("Par04InferenceSetup::CheckInf    112     G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup", FatalException,
116                 (msg + "). Current name: " + f    113                 (msg + "). Current name: " + fInferenceLibrary).c_str());
117 }                                                 114 }
118                                                   115 
119 //....oooOO0OOooo........oooOO0OOooo........oo    116 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
120                                                   117 
121 void Par04InferenceSetup::GetEnergies(std::vec    118 void Par04InferenceSetup::GetEnergies(std::vector<G4double>& aEnergies, G4double aInitialEnergy,
122                                       G4float     119                                       G4float aInitialAngle)
123 {                                                 120 {
124   // First check if inference library was set     121   // First check if inference library was set correctly
125   CheckInferenceLibrary();                        122   CheckInferenceLibrary();
126   // size represents the size of the output ve    123   // size represents the size of the output vector
127   int size = fMeshNumber.x() * fMeshNumber.y()    124   int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
128                                                   125 
129   // randomly sample from a gaussian distribut    126   // randomly sample from a gaussian distribution in the latent space
130   std::vector<G4float> genVector(fSizeLatentVe    127   std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
131   for (int i = 0; i < fSizeLatentVector; ++i)  << 128   for(int i = 0; i < fSizeLatentVector; ++i)
                                                   >> 129   {
132     genVector[i] = CLHEP::RandGauss::shoot(0.,    130     genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
133   }                                               131   }
134                                                   132 
135   // Vector of condition                          133   // Vector of condition
136   // this is application specific it depdens o    134   // this is application specific it depdens on what the model was condition on
137   // and it depends on how the condition value << 135   // and it depends on how the condition values were encoded at the training time
138   // time in this example the energy of each p << 136   // in this example the energy of each particle is normlaized to the highest
139   // highest energy in the considered range (1 << 137   // energy in the considered range (1GeV-500GeV)
140   // normlaized to the highest angle in the co << 138   // the angle is also is normlaized to the highest angle in the considered range
                                                   >> 139   // (0-90 in dergrees)
141   // the model in this example was trained on     140   // the model in this example was trained on two detector geometries PBW04
142   // and SiW  a one hot encoding vector is use    141   // and SiW  a one hot encoding vector is used to represent the geometry with
143   // [0,1] for PBW04 and [1,0] for SiW            142   // [0,1] for PBW04 and [1,0] for SiW
144   // 1. energy                                 << 143   // 1.energy
145   genVector[fSizeLatentVector] = aInitialEnerg    144   genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
146   // 2. angle                                     145   // 2. angle
147   genVector[fSizeLatentVector + 1] = (aInitial    146   genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
148   // 3. geometry                               << 147   // 3.geometry
149   genVector[fSizeLatentVector + 2] = 0;           148   genVector[fSizeLatentVector + 2] = 0;
150   genVector[fSizeLatentVector + 3] = 1;           149   genVector[fSizeLatentVector + 3] = 1;
151                                                   150 
152   // Run the inference                            151   // Run the inference
153   fInferenceInterface->RunInference(genVector,    152   fInferenceInterface->RunInference(genVector, aEnergies, size);
154                                                   153 
155   // After the inference rescale back to the i    154   // After the inference rescale back to the initial energy (in this example the
156   // energies of cells were normalized to the     155   // energies of cells were normalized to the energy of the particle)
157   for (int i = 0; i < size; ++i) {             << 156   for(int i = 0; i < size; ++i)
                                                   >> 157   {
158     aEnergies[i] = aEnergies[i] * aInitialEner    158     aEnergies[i] = aEnergies[i] * aInitialEnergy;
159   }                                               159   }
160 }                                                 160 }
161                                                   161 
162 //....oooOO0OOooo........oooOO0OOooo........oo    162 //....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
163                                                   163 
164 void Par04InferenceSetup::GetPositions(std::ve    164 void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector>& aPositions, G4ThreeVector pos0,
165                                        G4Three    165                                        G4ThreeVector direction)
166 {                                                 166 {
167   aPositions.resize(fMeshNumber.x() * fMeshNum    167   aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
168                                                   168 
169   // Calculate rotation matrix along the parti    169   // Calculate rotation matrix along the particle momentum direction
170   // It will rotate the shower axes to match t    170   // It will rotate the shower axes to match the incoming particle direction
171   G4RotationMatrix rotMatrix = G4RotationMatri    171   G4RotationMatrix rotMatrix = G4RotationMatrix();
172   double particleTheta = direction.theta();    << 172   double particleTheta       = direction.theta();
173   double particlePhi = direction.phi();        << 173   double particlePhi         = direction.phi();
174   rotMatrix.rotateZ(-particlePhi);                174   rotMatrix.rotateZ(-particlePhi);
175   rotMatrix.rotateY(-particleTheta);              175   rotMatrix.rotateY(-particleTheta);
176   G4RotationMatrix rotMatrixInv = CLHEP::inver    176   G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
177                                                   177 
178   int cpt = 0;                                    178   int cpt = 0;
179   for (G4int iCellR = 0; iCellR < fMeshNumber. << 179   for(G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++)
180     for (G4int iCellPhi = 0; iCellPhi < fMeshN << 180   {
181       for (G4int iCellZ = 0; iCellZ < fMeshNum << 181     for(G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++)
                                                   >> 182     {
                                                   >> 183       for(G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++)
                                                   >> 184       {
182         aPositions[cpt] =                         185         aPositions[cpt] =
183           pos0                                 << 186           pos0 +
184           + rotMatrixInv                       << 187           rotMatrixInv *
185               * G4ThreeVector(                 << 188             G4ThreeVector((iCellR + 0.5) * fMeshSize.x() *
186                 (iCellR + 0.5) * fMeshSize.x() << 189                             std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
187                   * std::cos((iCellPhi + 0.5)  << 190                           (iCellR + 0.5) * fMeshSize.x() *
188                 (iCellR + 0.5) * fMeshSize.x() << 191                             std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi / fMeshNumber.y() - CLHEP::pi),
189                   * std::sin((iCellPhi + 0.5)  << 192                           (iCellZ + 0.5) * fMeshSize.z());
190                 (iCellZ + 0.5) * fMeshSize.z() << 
191         cpt++;                                    193         cpt++;
192       }                                           194       }
193     }                                             195     }
194   }                                               196   }
195 }                                                 197 }
196                                                   198 
197 #endif                                            199 #endif
198                                                   200