27#ifdef USE_INFERENCE_TORCH
35#include <torch/torch.h>
39Par04TorchInference::Par04TorchInference(
G4String modelPath)
42 fModule = torch::jit::load( modelPath );
47void Par04TorchInference::RunInference(std::vector<float> aGenVector,
48 std::vector<G4double>& aEnergies,
53 int latentSize = aGenVector.size() - 4;
55 std::vector<float> latent;
56 for (
int i=0;i<latentSize;i++) {
57 latent.push_back(aGenVector[i]);
59 std::vector<float> energy;
60 energy.push_back(aGenVector[latentSize+1]);
61 std::vector<float> angle;
62 energy.push_back(aGenVector[latentSize+2]);
63 std::vector<float> geo;
64 for (
int i=latentSize+2;i<latentSize+4;i++) {
65 geo.push_back(aGenVector[i]);
69 torch::Tensor latentVector = torch::tensor(latent);
70 torch::Tensor eTensor = torch::tensor(energy);
71 torch::Tensor angleTensor = torch::tensor(angle);
72 torch::Tensor geoTensor = torch::tensor(geo);
74 std::vector<torch::jit::IValue> genInput;
76 genInput.push_back( latentVector );
77 genInput.push_back( eTensor );
78 genInput.push_back( angleTensor );
79 genInput.push_back( geoTensor );
81 at::Tensor outTensor = fModule.forward( genInput).toTensor().contiguous();
83 std::vector<G4double> output( outTensor.data_ptr<
float>(),
84 outTensor.data_ptr<
float>() + outTensor.numel() );
86 aEnergies.assign(aSize, 0);
87 for(
int i = 0; i < aSize; i++) {
88 aEnergies[i] = output[i];