Par04OnnxInference.hh
//
// ********************************************************************
// * License and Disclaimer *
// * *
// * The Geant4 software is copyright of the Copyright Holders of *
// * the Geant4 Collaboration. It is provided under the terms and *
// * conditions of the Geant4 Software License, included in the file *
// * LICENSE and available at http://cern.ch/geant4/license . These *
// * include a list of copyright holders. *
// * *
// * Neither the authors of this software system, nor their employing *
// * institutes, nor the agencies providing financial support for this *
// * work make any representation or warranty, express or implied, *
// * regarding this software system or assume any liability for its *
// * use. Please see the license in the file LICENSE and URL above *
// * for the full disclaimer and the limitation of liability. *
// * *
// * This code implementation is the result of the scientific and *
// * technical work of the GEANT4 collaboration. *
// * By using, copying, modifying or distributing the software (or *
// * any work based on the software) you agree to acknowledge its *
// * use in resulting scientific publications, and indicate your *
// * acceptance of all terms of the Geant4 Software license. *
// ********************************************************************
//

#ifdef USE_INFERENCE_ONNX
#ifndef PAR04ONNXINFERENCE_HH
#define PAR04ONNXINFERENCE_HH
#include <core/session/onnxruntime_c_api.h> // for OrtMemoryInfo
#include <G4String.hh> // for G4String
#include <G4Types.hh> // for G4int, G4double
#include <memory> // for unique_ptr
#include <vector> // for vector
#include "Par04InferenceInterface.hh" // for Par04InferenceInterface
#include "core/session/onnxruntime_cxx_api.h" // for Env, Session, SessionO...

/**
 * @brief Inference using the ONNX runtime.
 *
 * Creates an environment which manages an internal thread pool and creates an
 * inference session for the model saved as an ONNX file.
 * Runs the inference in the session using the input vector from Par04InferenceSetup.
 *
 **/

class Par04OnnxInference : public Par04InferenceInterface
{
  public:
    Par04OnnxInference(G4String, G4int, G4int, G4int,
                       G4int, // For Execution Provider Runtime Flags (for now only CUDA)
                       std::vector<const char *> &cuda_keys,
                       std::vector<const char *> &cuda_values,
                       G4String, G4String);

    Par04OnnxInference();

    /// Run inference
    /// @param[in] aGenVector Input latent space and conditions
    /// @param[out] aEnergies Model output = generated shower energies
    /// @param[in] aSize Size of the output
    void RunInference(std::vector<float> aGenVector, std::vector<G4double>& aEnergies, int aSize);

  private:
    /// Pointer to the ONNX environment
    std::unique_ptr<Ort::Env> fEnv;
    /// Pointer to the ONNX inference session
    std::unique_ptr<Ort::Session> fSession;
    /// ONNX settings
    Ort::SessionOptions fSessionOptions;
    /// ONNX memory info
    const OrtMemoryInfo* fInfo;
    struct MemoryInfo;
    /// The input names are the names given to the model's inputs
    /// when defining the model's architecture (if applicable);
    /// they can also be retrieved from model.summary()
    std::vector<const char*> fInames;
};

#endif /* PAR04ONNXINFERENCE_HH */
#endif
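
The implementation file (Par04OnnxInference.cc) is not reproduced on this page. The following is a minimal, self-contained sketch of how a class like this typically drives the ONNX Runtime C++ API, matching what the class documentation above describes (environment, session, inference run). The class name OnnxRunner, the tensor names "input" and "output", and the fixed batch size of 1 are assumptions for illustration; they are not taken from the actual Par04 code.

// Hedged sketch: environment + session creation and a single inference call
// with the ONNX Runtime C++ API. Not the actual Par04 implementation.
#include "core/session/onnxruntime_cxx_api.h"

#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

class OnnxRunner // hypothetical stand-in for Par04OnnxInference
{
  public:
    OnnxRunner(const std::string& aModelPath, int aIntraOpNumThreads)
    {
      // The environment owns the logger and the internal thread pools.
      fEnv = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "onnx-inference");
      fSessionOptions.SetIntraOpNumThreads(aIntraOpNumThreads);
      fSessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
      // The session parses and optimizes the model saved in the .onnx file.
      fSession = std::make_unique<Ort::Session>(*fEnv, aModelPath.c_str(), fSessionOptions);
    }

    void RunInference(std::vector<float> aGenVector, std::vector<double>& aEnergies, int aSize)
    {
      // Wrap the input (latent space + conditions) in a rank-2 tensor, batch size 1.
      Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
      std::array<int64_t, 2> shape{1, static_cast<int64_t>(aGenVector.size())};
      Ort::Value input = Ort::Value::CreateTensor<float>(
        memInfo, aGenVector.data(), aGenVector.size(), shape.data(), shape.size());

      // "input"/"output" are assumed tensor names; the real names come from the
      // model architecture and could instead be queried from the session.
      const char* inputNames[] = { "input" };
      const char* outputNames[] = { "output" };
      std::vector<Ort::Value> output =
        fSession->Run(Ort::RunOptions{nullptr}, inputNames, &input, 1, outputNames, 1);

      // Copy the generated shower energies into the caller's buffer.
      const float* data = output.front().GetTensorMutableData<float>();
      aEnergies.assign(data, data + aSize);
    }

  private:
    std::unique_ptr<Ort::Env> fEnv;
    std::unique_ptr<Ort::Session> fSession;
    Ort::SessionOptions fSessionOptions;
};

In the example, Par04InferenceSetup supplies aGenVector (the sampled latent vector plus the energy and angle conditions) and reads the generated shower energies back from aEnergies, as stated in the RunInference documentation above.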
