diff --git a/src/ort_genai.h b/src/ort_genai.h
index d0c1d0c75..4a83b69e2 100644
--- a/src/ort_genai.h
+++ b/src/ort_genai.h
@@ -232,6 +232,12 @@ struct OgaGenerator : OgaAbstract {
     return OgaGenerator_GetSequenceData(this, index);
   }
 
+  std::unique_ptr<OgaTensor> GetOutput(const char* name) {
+    OgaTensor* out;
+    OgaCheckResult(OgaGenerator_GetOutput(this, name, &out));
+    return std::unique_ptr<OgaTensor>(out);
+  }
+
 #if __cplusplus >= 202002L
   std::span<const int32_t> GetSequence(size_t index) const {
     return {GetSequenceData(index), GetSequenceCount(index)};
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 6f26d2857..a40807845 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -208,6 +208,50 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator)
   OGA_CATCH
 }
 
+OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) {
+  OGA_TRY
+  auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
+  auto* ortvalue_output = generator.state_->GetOutput(name);
+  auto type_info = ortvalue_output->GetTensorTypeAndShapeInfo();
+  std::unique_ptr<OrtValue> ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_,
+                                                                    type_info->GetShape(),
+                                                                    type_info->GetElementType());
+  // Copy data to ortvalue_clone
+  auto element_size = Generators::SizeOf(type_info->GetElementType());
+  auto data_size = type_info->GetElementCount() * element_size;
+  if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) {
+#if USE_CUDA
+    cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost);
+#endif
+  } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) {
+#if USE_DML
+    ComPtr<ID3D12Resource> gpu_resource;
+    Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
+        generator.model_->allocator_device_,
+        ortvalue_output->GetTensorMutableRawData(),
+        &gpu_resource));
+    auto cpu_tensor = ortvalue_clone->GetTensorMutableRawData();
+    generator.model_->GetDmlReadbackHeap()->ReadbackFromGpu(
+        std::span(reinterpret_cast<uint8_t*>(cpu_tensor), data_size),
+        gpu_resource.Get(),
+        0,
+        D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
+#endif
+  } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
+    std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
+              static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
+              static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
+  } else {
+    throw std::runtime_error("Unsupported Device type: " + std::to_string(static_cast<int>(ortvalue_output->GetTensorMemoryInfo().GetDeviceType())));
+  }
+
+  auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));
+  tensor->external_owner_ = tensor;
+  *out = reinterpret_cast<OgaTensor*>(tensor.get());
+  return nullptr;
+  OGA_CATCH
+}
+
 size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* oga_generator, size_t index) {
   auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
   return generator.GetSequence(static_cast<int>(index)).GetCPU().size();
diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h
index ec97ce4e5..7b1f084c2 100644
--- a/src/ort_genai_c.h
+++ b/src/ort_genai_c.h
@@ -224,6 +224,14 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
 OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);
 OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);
 
+/*
+ * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by the returned OgaTensor
+ * and will be released when the OgaTensor is destroyed
+ * \param[in] oga_generator The generator to run the GetOutput on the name provided and the out pointer to store the output
+ * \return OgaResult containing the error message if the computation failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out);
+
 /*
  * \brief Returns the number of tokens in the sequence at the given index.
  * \param[in] generator The generator to get the count of the tokens for the sequence at the given index.
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index eba5aff15..8e8cc13cb 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -173,6 +173,65 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
 }
 #endif
 
+TEST(CAPITests, GetOutputCAPI) {
+  std::vector<int64_t> input_ids_shape{2, 4};
+  std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};
+
+  auto input_sequence_length = input_ids_shape[1];
+  auto batch_size = input_ids_shape[0];
+  int max_length = 10;
+
+  // To generate this file:
+  // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
+  // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2)
+
+  auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
+
+  auto params = OgaGeneratorParams::Create(*model);
+  params->SetSearchOption("max_length", max_length);
+  params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size);
+
+  auto generator = OgaGenerator::Create(*model, *params);
+
+  // check prompt
+  // full logits has shape [2, 4, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 4, 5]
+  std::vector<float> expected_sampled_logits_prompt{0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f,
+                                                    0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f,
+                                                    0.30323744f, 0.0545997f, 0.03894716f, 0.11702324f, 0.0410665f,
+                                                    -0.12675379f, -0.04443946f, 0.14492269f, 0.03021223f, -0.03212897f,
+                                                    0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f,
+                                                    0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f,
+                                                    -0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f,
+                                                    0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f};
+
+  generator->ComputeLogits();
+  auto prompt_logits_ptr = generator->GetOutput("logits");
+  auto prompt_logits = static_cast<float*>(prompt_logits_ptr->Data());
+  int num_prompt_outputs_to_check = 40;
+  int sample_size = 200;
+  float tolerance = 0.001f;
+  // Verify outputs match expected outputs
+  for (int i = 0; i < num_prompt_outputs_to_check; i++) {
+    EXPECT_NEAR(expected_sampled_logits_prompt[i], prompt_logits[i*sample_size], tolerance);
+  }
+
+  generator->GenerateNextToken();
+  // check for the 1st token generation
+  // full logits has shape [2, 1, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 1, 5]
+  std::vector<float> expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f,
+                                                       0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f};
+
+  generator->ComputeLogits();
+  auto token_gen_logits_ptr = generator->GetOutput("logits");
+  auto token_gen_logits = static_cast<float*>(token_gen_logits_ptr->Data());
+  int num_token_gen_outputs_to_check = 10;
+
+  for (int i = 0; i < num_token_gen_outputs_to_check; i++) {
+    EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance);
+  }
+  generator->GenerateNextToken();
+}
+
 #if TEST_PHI2
 struct Phi2Test {