Par04OnnxInference.cc
//
// ********************************************************************
// * License and Disclaimer                                           *
// *                                                                  *
// * The Geant4 software is copyright of the Copyright Holders of     *
// * the Geant4 Collaboration. It is provided under the terms and     *
// * conditions of the Geant4 Software License, included in the file  *
// * LICENSE and available at http://cern.ch/geant4/license . These   *
// * include a list of copyright holders.                             *
// *                                                                  *
// * Neither the authors of this software system, nor their employing *
// * institutes,nor the agencies providing financial support for this *
// * work make any representation or warranty, express or implied,    *
// * regarding this software system or assume any liability for its   *
// * use. Please see the license in the file LICENSE and URL above    *
// * for the full disclaimer and the limitation of liability.         *
// *                                                                  *
// * This code implementation is the result of the scientific and     *
// * technical work of the GEANT4 collaboration.                      *
// * By using, copying, modifying or distributing the software (or    *
// * any work based on the software) you agree to acknowledge its     *
// * use in resulting scientific publications, and indicate your      *
// * acceptance of all terms of the Geant4 Software license.          *
// ********************************************************************
//
#ifdef USE_INFERENCE_ONNX
#include "Par04InferenceInterface.hh"  // for Par04InferenceInterface
#include "Par04OnnxInference.hh"
#include <algorithm>  // for copy, max
#include <cassert>    // for assert
#include <core/session/onnxruntime_cxx_api.h>  // for Value, Session, Env
#include <cstddef>    // for size_t
#include <cstdint>    // for int64_t
#include <memory>     // for make_unique, unique_ptr
#include <utility>    // for move
#ifdef USE_CUDA
#include "cuda_runtime_api.h"
#endif

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

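// Sets up the ONNX Runtime environment and inference session for the model at
// modelPath: intraOpNumThreads controls session threading, optimizeFlag enables
// all graph optimizations (the optimized graph is written to "opt-graph"),
// profileFlag enables profiling output ("opt.json"), and cudaFlag selects the
// CUDA execution provider (configured via the cuda_keys/cuda_values options)
// when built with USE_CUDA. ModelSavePath and profilingOutputSavePath are
// accepted but not used in this implementation.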
Par04OnnxInference::Par04OnnxInference(G4String modelPath, G4int profileFlag,
                                       G4int optimizeFlag,
                                       G4int intraOpNumThreads, G4int cudaFlag,
                                       std::vector<const char *> &cuda_keys,
                                       std::vector<const char *> &cuda_values,
                                       G4String ModelSavePath,
                                       G4String profilingOutputSavePath)

{
  // initialization of the environment and inference session
  auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ENV");
  fEnv = std::move(envLocal);
  // Creating an OrtApi handle for access to the C API, necessary for CUDA
  const auto &ortApi = Ort::GetApi();
  fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
  // graph optimizations of the model:
  // if the flag is set all optimizations are applied, otherwise none
  // (intermediate levels: ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED)
  if (optimizeFlag) {
    fSessionOptions.SetOptimizedModelFilePath("opt-graph");
    fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
  } else
    fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
#ifdef USE_CUDA
  if (cudaFlag) {
    OrtCUDAProviderOptionsV2 *fCudaOptions = nullptr;
    // Initialize the CUDA provider options; fCudaOptions now points to a
    // valid CUDA configuration
    (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
    // Update the CUDA provider options with the user-supplied key/value pairs
    (void)ortApi.UpdateCUDAProviderOptions(
        fCudaOptions, cuda_keys.data(), cuda_values.data(), cuda_keys.size());
    // Append the CUDA execution provider to the session options, indicating to
    // use CUDA for execution
    (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions,
                                                               fCudaOptions);
  }
#endif
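  // Illustration (not in the original file): cuda_keys/cuda_values are plain
  // string options of the ONNX Runtime CUDA execution provider, e.g.
  //   cuda_keys   = {"device_id", "gpu_mem_limit"};
  //   cuda_values = {"0",         "2147483648"};
  // Any option accepted by OrtCUDAProviderOptionsV2 can be passed this way.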
  // save json file for model execution profiling
  if (profileFlag)
    fSessionOptions.EnableProfiling("opt.json");

  auto sessionLocal =
      std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
  fSession = std::move(sessionLocal);
  // memory descriptor for CPU-resident tensors, used later to create the input tensor
  fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator,
                                     OrtMemTypeDefault);
}

//....oooOO0OOooo........oooOO0OOooo........oooOO0OOooo........oooOO0OOooo......

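// Runs a single inference pass: aGenVector is wrapped into a 1 x N input
// tensor, the session is executed, and the first aSize values of the output
// tensor are copied into aEnergies.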
void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
                                      std::vector<G4double> &aEnergies,
                                      int aSize) {
  // input nodes
  Ort::AllocatorWithDefaultOptions allocator;
#if ORT_API_VERSION < 13
  // Before 1.13 we have to roll our own unique_ptr wrapper here
  auto allocDeleter = [&allocator](char *p) { allocator.Free(p); };
  using AllocatedStringPtr = std::unique_ptr<char, decltype(allocDeleter)>;
#endif
  std::vector<int64_t> input_node_dims;
  size_t num_input_nodes = fSession->GetInputCount();
  std::vector<const char *> input_node_names(num_input_nodes);
  for (std::size_t i = 0; i < num_input_nodes; i++) {
    // release() hands back the raw name pointer without freeing it, so the
    // pointer stored below stays valid for the Run() call
#if ORT_API_VERSION < 13
    const auto input_name =
        AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter)
            .release();
#else
    const auto input_name =
        fSession->GetInputNameAllocated(i, allocator).release();
#endif
    fInames = {input_name};
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_dims = tensor_info.GetShape();
    // dynamic (negative) dimensions are set to 1 for a single-event input
    for (std::size_t j = 0; j < input_node_dims.size(); j++) {
      if (input_node_dims[j] < 0)
        input_node_dims[j] = 1;
    }
  }
  // output nodes
  std::vector<int64_t> output_node_dims;
  size_t num_output_nodes = fSession->GetOutputCount();
  std::vector<const char *> output_node_names(num_output_nodes);
  for (std::size_t i = 0; i < num_output_nodes; i++) {
#if ORT_API_VERSION < 13
    const auto output_name =
        AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter)
            .release();
#else
    const auto output_name =
        fSession->GetOutputNameAllocated(i, allocator).release();
#endif
    output_node_names[i] = output_name;
    Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims = tensor_info.GetShape();
    for (std::size_t j = 0; j < output_node_dims.size(); j++) {
      if (output_node_dims[j] < 0)
        output_node_dims[j] = 1;
    }
  }

  // create input tensor object from data values
  std::vector<int64_t> dims = {1, (int64_t)(aGenVector.size())};
  Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
      fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
  assert(Input_noise_tensor.IsTensor());
  std::vector<Ort::Value> ort_inputs;
  ort_inputs.push_back(std::move(Input_noise_tensor));
  // run the inference session
  std::vector<Ort::Value> ort_outputs = fSession->Run(
      Ort::RunOptions{nullptr}, fInames.data(), ort_inputs.data(),
      ort_inputs.size(), output_node_names.data(), output_node_names.size());
  // get pointer to output tensor float values
  float *floatarr = ort_outputs.front().GetTensorMutableData<float>();
  // copy the first aSize output values into the energy vector
  aEnergies.assign(aSize, 0);
  for (int i = 0; i < aSize; ++i)
    aEnergies[i] = floatarr[i];
}

#endif
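For orientation, a minimal usage sketch of this class (not part of the file above; the model name, vector length and cell count are placeholders, and elsewhere in the Par04 example the input vector and readout size are prepared by the surrounding inference setup code):

  std::vector<const char *> cudaKeys, cudaValues;  // left empty for a CPU-only session
  Par04OnnxInference inference("model.onnx", /*profileFlag*/ 0, /*optimizeFlag*/ 1,
                               /*intraOpNumThreads*/ 1, /*cudaFlag*/ 0,
                               cudaKeys, cudaValues, "", "");
  std::vector<float> genVector(10, 0.f);  // e.g. latent noise + conditions (placeholder size)
  std::vector<G4double> energies;
  inference.RunInference(genVector, energies, /*aSize*/ 100);  // placeholder cell count
  // energies[i] now holds the i-th predicted energy value from the network output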
