Loading...
Searching...
No Matches
Par04TorchInference.cc
Go to the documentation of this file.
1//
2// ********************************************************************
3// * License and Disclaimer *
4// * *
5// * The Geant4 software is copyright of the Copyright Holders of *
6// * the Geant4 Collaboration. It is provided under the terms and *
7// * conditions of the Geant4 Software License, included in the file *
8// * LICENSE and available at http://cern.ch/geant4/license . These *
9// * include a list of copyright holders. *
10// * *
11// * Neither the authors of this software system, nor their employing *
12// * institutes,nor the agencies providing financial support for this *
13// * work make any representation or warranty, express or implied, *
14// * regarding this software system or assume any liability for its *
15// * use. Please see the license in the file LICENSE and URL above *
16// * for the full disclaimer and the limitation of liability. *
17// * *
18// * This code implementation is the result of the scientific and *
19// * technical work of the GEANT4 collaboration. *
20// * By using, copying, modifying or distributing the software (or *
21// * any work based on the software) you agree to acknowledge its *
22// * use in resulting scientific publications, and indicate your *
23// * acceptance of all terms of the Geant4 Software license. *
24// ********************************************************************
25//
26
27#ifdef USE_INFERENCE_TORCH
29#include <algorithm> // for copy, max
30#include <cassert> // for assert
31#include <cstddef> // for size_t
32#include <cstdint> // for int64_t
33#include <utility> // for move
34#include "Par04InferenceInterface.hh" // for Par04InferenceInterface
35#include <torch/torch.h>
36
37//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
38
39Par04TorchInference::Par04TorchInference(G4String modelPath)
41{
42 fModule = torch::jit::load( modelPath );
43}
44
45//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
46
47void Par04TorchInference::RunInference(std::vector<float> aGenVector,
48 std::vector<G4double>& aEnergies,
49 int aSize)
50{
51 // latentSize : size of the latent space
52 // 4 is the size of the condition vector
53 int latentSize = aGenVector.size() - 4;
54 // split into latent and condition vectors
55 std::vector<float> latent;
56 for ( int i=0;i<latentSize;i++) {
57 latent.push_back(aGenVector[i]);
58 }
59 std::vector<float> energy;
60 energy.push_back(aGenVector[latentSize+1]);
61 std::vector<float> angle;
62 energy.push_back(aGenVector[latentSize+2]);
63 std::vector<float> geo;
64 for ( int i=latentSize+2;i<latentSize+4;i++) {
65 geo.push_back(aGenVector[i]);
66 }
67
68 // convert vectors to tensors
69 torch::Tensor latentVector = torch::tensor(latent);
70 torch::Tensor eTensor = torch::tensor(energy);
71 torch::Tensor angleTensor = torch::tensor(angle);
72 torch::Tensor geoTensor = torch::tensor(geo);
73
74 std::vector<torch::jit::IValue> genInput;
75
76 genInput.push_back( latentVector );
77 genInput.push_back( eTensor );
78 genInput.push_back( angleTensor );
79 genInput.push_back( geoTensor );
80
81 at::Tensor outTensor = fModule.forward( genInput).toTensor().contiguous();
82
83 std::vector<G4double> output( outTensor.data_ptr<float>(),
84 outTensor.data_ptr<float>() + outTensor.numel() );
85
86 aEnergies.assign(aSize, 0);
87 for(int i = 0; i < aSize; i++) {
88 aEnergies[i] = output[i];
89 }
90}
91
92#endif

Applications | User Support | Publications | Collaboration