26#ifdef USE_INFERENCE_ONNX
31#include <core/session/onnxruntime_cxx_api.h>
36#include "cuda_runtime_api.h"
// Constructor: builds the ONNX Runtime environment, configures the session
// options (intra-op thread count, graph optimization level, optional CUDA
// execution provider, optional profiling) and creates the inference session
// for the model at modelPath.
// NOTE(review): this extract is elided — the tail of the parameter list, the
// #if/#else guards around the optimize/profile/CUDA branches and the closing
// brace are not visible here; verify against the full file.
41Par04OnnxInference::Par04OnnxInference(
G4String modelPath, G4int profileFlag,
43 G4int intraOpNumThreads, G4int cudaFlag,
44 std::vector<const char *> &cuda_keys,
45 std::vector<const char *> &cuda_values,
// Create the ORT environment first; the session below needs it alive, so it
// is moved into the fEnv member to tie its lifetime to this object.
51 auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"ENV");
52 fEnv = std::move(envLocal);
// C API handle, needed for the CUDA-provider-options calls below which have
// no C++ wrapper at this API version.
55 const auto &ortApi = Ort::GetApi();
56 fSessionOptions.SetIntraOpNumThreads(intraOpNumThreads);
// Optimized graph is dumped to disk as "opt-graph" when optimization is on;
// otherwise all graph optimizations are disabled (presumably selected by an
// optimize flag guarded by preprocessor/if lines elided from this view).
61 fSessionOptions.SetOptimizedModelFilePath(
"opt-graph");
62 fSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
65 fSessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
// CUDA execution provider setup (V2 options API), configured from the
// caller-supplied key/value string pairs.
68 OrtCUDAProviderOptionsV2 *fCudaOptions =
nullptr;
// NOTE(review): the OrtStatus* returned by the three calls below is discarded
// via (void) — a failing CUDA setup is silently ignored here. Also,
// fCudaOptions is never passed to ReleaseCUDAProviderOptions in the visible
// code — TODO confirm it is released elsewhere.
71 (void)ortApi.CreateCUDAProviderOptions(&fCudaOptions);
73 (void)ortApi.UpdateCUDAProviderOptions(
74 fCudaOptions, cuda_keys.data(), cuda_values.data(), cuda_keys.size());
77 (void)ortApi.SessionOptionsAppendExecutionProvider_CUDA_V2(fSessionOptions,
// Profiling output (when enabled) is written with the "opt.json" prefix.
83 fSessionOptions.EnableProfiling(
"opt.json");
// Create the session for the given model and hand ownership to the member.
86 std::make_unique<Ort::Session>(*fEnv, modelPath, fSessionOptions);
87 fSession = std::move(sessionLocal);
// CPU memory info used later by RunInference to wrap input buffers as
// ORT tensors without copying.
88 fInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator,
// Runs a single inference: wraps aGenVector as a (1, size) float tensor,
// executes the session, and copies the first aSize output values into
// aEnergies (as G4double).
// NOTE(review): this extract is elided — closing braces, #else/#endif of the
// ORT_API_VERSION guards, and some statements are missing from this view.
94void Par04OnnxInference::RunInference(std::vector<float> aGenVector,
95 std::vector<G4double> &aEnergies,
98 Ort::AllocatorWithDefaultOptions allocator;
// Pre-1.13 ORT: GetInputName/GetOutputName return allocator-owned C strings,
// so wrap them in a unique_ptr with a custom deleter that frees via the
// allocator.
99#if ORT_API_VERSION < 13
101 auto allocDeleter = [&allocator](
char *p) { allocator.Free(p); };
102 using AllocatedStringPtr = std::unique_ptr<char,
decltype(allocDeleter)>;
// --- Input metadata: names and shapes for every model input. ---
104 std::vector<int64_t> input_node_dims;
105 size_t num_input_nodes = fSession->GetInputCount();
106 std::vector<const char *> input_node_names(num_input_nodes);
107 for (std::size_t i = 0; i < num_input_nodes; i++) {
108#if ORT_API_VERSION < 13
109 const auto input_name =
110 AllocatedStringPtr(fSession->GetInputName(i, allocator), allocDeleter)
// NOTE(review): .release() drops ownership of the allocated name and the
// visible code never frees it — this leaks one string per input per call.
// Consider keeping the AllocatedStringPtr alive for the Run() call instead.
113 const auto input_name =
114 fSession->GetInputNameAllocated(i, allocator).release();
// fInames is rebuilt each iteration, so it ends up holding only the last
// input name; input_node_names keeps them all.
116 fInames = {input_name};
117 input_node_names[i] = input_name;
118 Ort::TypeInfo type_info = fSession->GetInputTypeInfo(i);
119 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
120 input_node_dims = tensor_info.GetShape();
// Dynamic (symbolic) dimensions are reported as negative; pin them to 1.
121 for (std::size_t j = 0; j < input_node_dims.size(); j++) {
122 if (input_node_dims[j] < 0)
123 input_node_dims[j] = 1;
// --- Output metadata: same pattern as the inputs above. ---
127 std::vector<int64_t> output_node_dims;
128 size_t num_output_nodes = fSession->GetOutputCount();
129 std::vector<const char *> output_node_names(num_output_nodes);
130 for (std::size_t i = 0; i < num_output_nodes; i++) {
131#if ORT_API_VERSION < 13
132 const auto output_name =
133 AllocatedStringPtr(fSession->GetOutputName(i, allocator), allocDeleter)
// NOTE(review): same .release() leak as for the input names above.
136 const auto output_name =
137 fSession->GetOutputNameAllocated(i, allocator).release();
139 output_node_names[i] = output_name;
140 Ort::TypeInfo type_info = fSession->GetOutputTypeInfo(i);
141 autensor_info = type_info.GetTensorTypeAndShapeInfo();
142 output_node_dims = tensor_info.GetShape();
143 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
144 if (output_node_dims[j] < 0)
145 output_node_dims[j] = 1;
// Build the input tensor over aGenVector's buffer (no copy): batch size is
// hard-coded to 1. NOTE(review): the (unsigned) cast narrows size_t before
// widening to int64_t — harmless for realistic sizes but worth cleaning up.
150 std::vector<int64_t> dims = {1, (unsigned)(aGenVector.size())};
151 Ort::Value Input_noise_tensor = Ort::Value::CreateTensor<float>(
152 fInfo, aGenVector.data(), aGenVector.size(), dims.data(), dims.size());
153 assert(Input_noise_tensor.IsTensor());
154 std::vector<Ort::Value> ort_inputs;
155 ort_inputs.push_back(std::move(Input_noise_tensor));
// Execute the model; fInames supplies the input name(s) collected above.
157 std::vector<Ort::Value> ort_outputs = fSession->Run(
158 Ort::RunOptions{
nullptr}, fInames.data(), ort_inputs.data(),
159 ort_inputs.size(), output_node_names.data(), output_node_names.size());
// Copy the first aSize floats of the first output into aEnergies; the
// pointer is only valid while ort_outputs is alive. Assumes the output
// holds at least aSize elements — TODO confirm against the model.
161 float *floatarr = ort_outputs.front().GetTensorMutableData<
float>();
162 aEnergies.assign(aSize, 0);
163 for (
int i = 0; i < aSize; ++i)
164 aEnergies[i] = floatarr[i];