Loading...
Searching...
No Matches
Par04LwtnnInference.cc
Go to the documentation of this file.
//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The  Geant4 software  is  copyright of the Copyright Holders  of *
// * the Geant4 Collaboration.  It is provided  under  the terms  and *
// * conditions of the Geant4 Software License,  included in the file *
// * LICENSE and available at  http://cern.ch/geant4/license .  These *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work  make  any representation or  warranty, express or implied, *
// * regarding  this  software system or assume any liability for its *
// * use.  Please see the license in the file  LICENSE  and URL above *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This  code  implementation is the result of  the  scientific and *
// * technical work of the GEANT4 collaboration.                      *
// * By using,  copying,  modifying or  distributing the software (or *
// * any work based  on the software)  you  agree  to acknowledge its *
// * use  in  resulting  scientific  publications,  and indicate your *
// * acceptance of all terms of the Geant4 Software license.          *
// ********************************************************************
//
26#ifdef USE_INFERENCE_LWTNN
#include <fstream>    // for ifstream
#include <stdexcept>  // for runtime_error
#include <lwtnn/parse_json.hh>  // for parse_json_graph
#include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
31
//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
33
34Par04LwtnnInference::Par04LwtnnInference(G4String modelPath)
36{
37 // file to read
38 std::ifstream input(modelPath);
39 // build the graph
40 fGraph = std::make_unique<lwt::LightweightGraph>(lwt::parse_json_graph(input));
41 input.close();
42}
43
//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......
45
46void Par04LwtnnInference::RunInference(std::vector<float> aGenVector,
47 std::vector<G4double>& aEnergies,
48 int aSize)
49{
50 // generation vector
51 fNetworkInputs inputs;
52 for(std::size_t i = 0; i < aGenVector.size(); ++i)
53 {
54 inputs["node_0"]["variable_" + std::to_string(i)] = aGenVector[i];
55 }
56
57 // run the inference
58 fNetworkOutputs outputs = fGraph->compute(inputs);
59 aEnergies.assign(aSize, 0);
60 for(int i = 0; i < aSize; i++)
61 aEnergies[i] = outputs["out_" + std::to_string(i)];
62}
63
64#endif

Applications | User Support | Publications | Collaboration