30#ifdef USE_INFERENCE_ONNX
33#ifdef USE_INFERENCE_LWTNN
36#ifdef USE_INFERENCE_TORCH
39#include "CLHEP/Random/RandGauss.h"
40#include "G4RotationMatrix.hh"
41#include <CLHEP/Units/SystemOfUnits.h>
42#include <CLHEP/Vector/Rotation.h>
43#include <CLHEP/Vector/ThreeVector.h>
44#include <G4Exception.hh>
45#include <G4ExceptionSeverity.hh>
46#include <G4ThreeVector.hh>
49#include <ext/alloc_traits.h>
// Constructor of Par04InferenceSetup.
// NOTE(review): the member-initializer list and the body are not visible in
// this excerpt, so their effects are intentionally not documented here.
54Par04InferenceSetup::Par04InferenceSetup()
59Par04InferenceSetup::~Par04InferenceSetup() {}
63G4bool Par04InferenceSetup::IfTrigger(G4double aEnergy)
66 if (aEnergy > 1 * CLHEP::GeV || aEnergy < 1024 * CLHEP::GeV)
// Records the requested inference library name in fInferenceLibrary and, for
// each backend compiled in (USE_INFERENCE_ONNX / USE_INFERENCE_LWTNN /
// USE_INFERENCE_TORCH), creates that backend when the name matches exactly
// ("ONNX", "LWTNN", "TORCH"), storing it in fInferenceInterface.
// Ends with CheckInferenceLibrary(), which raises a G4Exception when
// fInferenceInterface is still null.
// NOTE(review): no reset of fInferenceInterface is visible for a name that
// matches no branch. If none exists, switching from a valid library to an
// unknown or not-compiled one keeps the previously created backend, and the
// null check at the end will not catch it.
73void Par04InferenceSetup::SetInferenceLibrary(
G4String aName)
75 fInferenceLibrary = aName;
77#ifdef USE_INFERENCE_ONNX
78 if (fInferenceLibrary ==
"ONNX")
// Constructor arguments for the ONNX backend (the constructor call itself is
// not visible in this excerpt).
81 fModelPathName, fProfileFlag, fOptimizationFlag, fIntraOpNumThreads,
82 fCudaFlag, cuda_keys, cuda_values, fModelSavePath,
83 fProfilingOutputSavePath));
85#ifdef USE_INFERENCE_LWTNN
86 if (fInferenceLibrary ==
"LWTNN")
87 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(
90#ifdef USE_INFERENCE_TORCH
91 if (fInferenceLibrary ==
"TORCH")
// The Torch backend is built from the model path only.
92 fInferenceInterface = std::unique_ptr<Par04InferenceInterface>(
93 new Par04TorchInference(fModelPathName));
// Validate that some backend is now available.
96 CheckInferenceLibrary();
// Verifies that an inference backend has been created. If fInferenceInterface
// is null, raises a G4Exception (code "InvalidSetup") whose message lists the
// available libraries and the currently requested name (fInferenceLibrary).
// NOTE(review): only a null backend is detected. The text appended to msg in
// each #ifdef section and the severity argument passed to G4Exception are not
// visible in this excerpt.
101void Par04InferenceSetup::CheckInferenceLibrary()
103 G4String msg =
"Please choose inference library from available libraries (";
// One section per compiled-in library; presumably each appends its name to msg.
104#ifdef USE_INFERENCE_ONNX
107#ifdef USE_INFERENCE_LWTNN
110#ifdef USE_INFERENCE_TORCH
113 if (fInferenceInterface ==
nullptr)
114 G4Exception(
"Par04InferenceSetup::CheckInferenceLibrary()",
"InvalidSetup",
116 (msg +
"). Current name: " + fInferenceLibrary).c_str());
121void Par04InferenceSetup::GetEnergies(std::vector<G4double> &aEnergies,
122 G4double aInitialEnergy,
123 G4float aInitialAngle)
126 CheckInferenceLibrary();
128 int size = fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z();
131 std::vector<G4float> genVector(fSizeLatentVector + fSizeConditionVector, 0);
132 for (
int i = 0; i < fSizeLatentVector; ++i)
134 genVector[i] = CLHEP::RandGauss::shoot(0., 1.);
147 genVector[fSizeLatentVector] = aInitialEnergy / fMaxEnergy;
149 genVector[fSizeLatentVector + 1] = (aInitialAngle / (CLHEP::deg)) / fMaxAngle;
151 genVector[fSizeLatentVector + 2] = 0;
152 genVector[fSizeLatentVector + 3] = 1;
155 fInferenceInterface->RunInference(genVector, aEnergies, size);
159 for (
int i = 0; i < size; ++i)
161 aEnergies[i] = aEnergies[i] * aInitialEnergy;
167void Par04InferenceSetup::GetPositions(std::vector<G4ThreeVector> &aPositions,
169 G4ThreeVector direction)
171 aPositions.resize(fMeshNumber.x() * fMeshNumber.y() * fMeshNumber.z());
175 G4RotationMatrix rotMatrix = G4RotationMatrix();
176 double particleTheta = direction.theta();
177 double particlePhi = direction.phi();
178 rotMatrix.rotateZ(-particlePhi);
179 rotMatrix.rotateY(-particleTheta);
180 G4RotationMatrix rotMatrixInv = CLHEP::inverseOf(rotMatrix);
183 for (G4int iCellR = 0; iCellR < fMeshNumber.x(); iCellR++)
185 for (G4int iCellPhi = 0; iCellPhi < fMeshNumber.y(); iCellPhi++)
187 for (G4int iCellZ = 0; iCellZ < fMeshNumber.z(); iCellZ++)
192 G4ThreeVector((iCellR + 0.5) * fMeshSize.x() *
193 std::cos((iCellPhi + 0.5) * 2 * CLHEP::pi /
196 (iCellR + 0.5) * fMeshSize.x() *
197 std::sin((iCellPhi + 0.5) * 2 * CLHEP::pi /
200 (iCellZ + 0.5) * fMeshSize.z());