// Par04InferenceSetup.cc
//
// ********************************************************************
// * License and Disclaimer *
// * *
// * The Geant4 software is copyright of the Copyright Holders of *
// * the Geant4 Collaboration. It is provided under the terms and *
// * conditions of the Geant4 Software License, included in the file *
// * LICENSE and available at http://cern.ch/geant4/license . These *
// * include a list of copyright holders. *
// * *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work make any representation or warranty, express or implied, *
// * regarding this software system or assume any liability for its *
// * use. Please see the license in the file LICENSE and URL above *
// * for the full disclaimer and the limitation of liability. *
// * *
// * This code implementation is the result of the scientific and *
// * technical work of the GEANT4 collaboration. *
// * By using, copying, modifying or distributing the software (or *
// * any work based on the software) you agree to acknowledge its *
// * use in resulting scientific publications, and indicate your *
// * acceptance of all terms of the Geant4 Software license. *
// ********************************************************************
//
26#ifdef USE_INFERENCE
27#include "Par04InferenceInterface.hh" // for Par04InferenceInterface
28#include "Par04InferenceMessenger.hh" // for Par04InferenceMessenger
30#ifdef USE_INFERENCE_ONNX
31#include "Par04OnnxInference.hh" // for Par04OnnxInference
32#endif
33#ifdef USE_INFERENCE_LWTNN
34#include "Par04LwtnnInference.hh" // for Par04LwtnnInference
35#endif
36#ifdef USE_INFERENCE_TORCH
37#include "Par04TorchInference.hh" // for Par04TorchInference
38#endif
39#include "CLHEP/Random/RandGauss.h" // for RandGauss
40#include "G4RotationMatrix.hh" // for G4RotationMatrix
41#include <CLHEP/Units/SystemOfUnits.h> // for pi, GeV, deg
42#include <CLHEP/Vector/Rotation.h> // for HepRotation
43#include <CLHEP/Vector/ThreeVector.h> // for Hep3Vector
44#include <G4Exception.hh> // for G4Exception
45#include <G4ExceptionSeverity.hh> // for FatalException
46#include <G4ThreeVector.hh> // for G4ThreeVector
47#include <algorithm> // for max, copy
48#include <cmath> // for cos, sin
49#include <ext/alloc_traits.h> // for __alloc_traits<>::value_type
50#include <string> // for char_traits, basic_string
51
52//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
53
54Par04InferenceSetup::Par04InferenceSetup()
55 : fInferenceMessenger(new Par04InferenceMessenger(this)) {}
56
57//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
58
59Par04InferenceSetup::~Par04InferenceSetup() {}
60
61//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
62
63G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
64{
65 /// Energy of electrons used in training dataset
66 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV)
67 return true;
68 return false;
69}
70
71//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
72
73void Par04InferenceSetup::SetInferenceLibrary(G4String aName)
74{
75 fInferenceLibrary = aName;
76
77#ifdef USE_INFERENCE_ONNX
78 if (fInferenceLibrary == "ONNX")
79 fInferenceInterface =
80 std::unique_ptr<Par04InferenceInterface>(new Par04OnnxInference(
81 fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads,
82 fCudaFlag, cuda_keys, cuda_values, fModelSavePath,
83 fProfilingOutputSavePath));
84#endif
85#ifdef USE_INFERENCE_LWTNN
86 if (fInferenceLibrary == "LWTNN")
87 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(
88 new Par04LwtnnInference(fModelPathName));
89#endif
90#ifdef USE_INFERENCE_TORCH
91 if (fInferenceLibrary == "TORCH")
92 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(
93 new Par04TorchInference(fModelPathName));
94#endif
95
96 CheckInferenceLibrary();
97}
98
99//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
100
101void Par04InferenceSetup::CheckInferenceLibrary()
102{
103 G4String msg = "Please choose inference library from available libraries (";
104#ifdef USE_INFERENCE_ONNX
105 msg += "ONNX,";
106#endif
107#ifdef USE_INFERENCE_LWTNN
108 msg += "LWTNN,";
109#endif
110#ifdef USE_INFERENCE_TORCH
111 msg += "TORCH";
112#endif
113 if (fInferenceInterface == nullptr)
114 G4Exception("Par04InferenceSetup::CheckInferenceLibrary()", "InvalidSetup",
115 FatalException,
116 (msg + "). Current name: " + fInferenceLibrary).c_str());
117}
118
119//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
120
121void Par04InferenceSetup::GetEnergies(std::vector<G4double> &aEnergies,
122 G4double aInitialEnergy,
123 G4float aInitialAngle)
124{
125 // First check if inference library was set correctly
126 CheckInferenceLibrary();
127 // size represents the size of the output vector
128 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
129
130 // randomly sample from a gaussian distribution in the latent space
131 std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
132 for (int i = 0; i < fSizeLatentVector; ++i)
133 {
134 genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
135 }
136
137 // Vector of condition
138 // this is application specific it depdens on what the model was condition on
139 // and it depends on how the condition values were encoded at the training
140 // time in this example the energy of each particle is normlaized to the
141 // highest energy in the considered range (1GeV-500GeV) the angle is also is
142 // normlaized to the highest angle in the considered range (0-90 in dergrees)
143 // the model in this example was trained on two detector geometries PBW04
144 // and SiW a one hot encoding vector is used to represent the geometry with
145 // [0,1] for PBW04 and [1,0] for SiW
146 // 1. energy
147 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
148 // 2. angle
149 genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
150 // 3. geometry
151 genVector[fSizeLatentVector + 2] = 0;
152 genVector[fSizeLatentVector + 3] = 1;
153
154 // Run the inference
155 fInferenceInterface->RunInference(genVector, aEnergies, size);
156
157 // After the inference rescale back to the initial energy (in this example the
158 // energies of cells were normalized to the energy of the particle)
159 for (int i = 0; i < size; ++i)
160 {
161 aEnergies[i] = aEnergies[i] * aInitialEnergy;
162 }
163}
164
165//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
166
167void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector> &aPositions,
168 G4ThreeVector pos0,
169 G4ThreeVector direction)
170{
171 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
172
173 // Calculate rotation matrix along the particle momentum direction
174 // It will rotate the shower axes to match the incoming particle direction
175 G4RotationMatrix rotMatrix = G4RotationMatrix();
176 double particleTheta = direction.theta();
177 double particlePhi = direction.phi();
178 rotMatrix.rotateZ(-particlePhi);
179 rotMatrix.rotateY(-particleTheta);
180 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
181
182 int cpt = 0;
183 for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++)
184 {
185 for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++)
186 {
187 for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++)
188 {
189 aPositions[cpt] =
190 pos0 +
191 rotMatrixInv *
192 G4ThreeVector((iCellR + 0.5) * fMeshSize.x() *
193 std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi /
194 fMeshNumber.y() -
195 CLHEP::pi),
196 (iCellR + 0.5) * fMeshSize.x() *
197 std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi /
198 fMeshNumber.y() -
199 CLHEP::pi),
200 (iCellZ + 0.5) * fMeshSize.z());
201 cpt++;
202 }
203 }
204 }
205}
206
207#endif

Applications | User Support | Publications | Collaboration