Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GetOutput for C API #755

Merged
merged 36 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
abfd3e0
add GetOutput for C API
ajindal1 Aug 2, 2024
fe42340
update return var
ajindal1 Aug 2, 2024
1b10033
fix typo
ajindal1 Aug 3, 2024
b042eab
change gen name
ajindal1 Aug 3, 2024
a658c9d
change state
ajindal1 Aug 3, 2024
2f9c734
change var name
ajindal1 Aug 5, 2024
2134e24
make a copy of buffer
ajindal1 Aug 6, 2024
65fd9c0
modify copy of data
ajindal1 Aug 6, 2024
545034f
add use cuda
ajindal1 Aug 6, 2024
ec4f6cf
modify std copy
ajindal1 Aug 6, 2024
842ca97
Merge branch 'main' of github.com:microsoft/onnxruntime-genai into ab…
ajindal1 Aug 7, 2024
471d52b
change allocator device and add for DML
ajindal1 Aug 8, 2024
eaa9211
change allocator cpu
ajindal1 Aug 8, 2024
c2c2816
update dml code
ajindal1 Aug 8, 2024
63ece2b
update dml device type
ajindal1 Aug 8, 2024
cbeaefe
add test for GetOutput
ajindal1 Aug 8, 2024
5eabc4d
update copy type
ajindal1 Aug 8, 2024
b446022
use output data
ajindal1 Aug 8, 2024
3ac992d
cast to float
ajindal1 Aug 9, 2024
3df46f1
fix typo
ajindal1 Aug 9, 2024
20115b9
logging
ajindal1 Aug 9, 2024
b4aaf5b
add logging
ajindal1 Aug 9, 2024
6873693
typo fix
ajindal1 Aug 9, 2024
5fa23a0
add missing code for compute logits
ajindal1 Aug 9, 2024
c022412
add more logs and conditions
ajindal1 Aug 9, 2024
df482ad
restructure
ajindal1 Aug 9, 2024
5f63d00
use float and restructure code
ajindal1 Aug 9, 2024
533911d
typo
ajindal1 Aug 9, 2024
19e9774
fix typo
ajindal1 Aug 9, 2024
175a5c5
try more loggin
ajindal1 Aug 12, 2024
692fa3c
modify type
ajindal1 Aug 12, 2024
ac84157
logits separate ptr
ajindal1 Aug 12, 2024
38d9e68
convert double to float
ajindal1 Aug 12, 2024
5f8d9bf
add info about the fn
ajindal1 Aug 13, 2024
cf2d2e7
updata comments
ajindal1 Aug 15, 2024
2887997
test for DML
ajindal1 Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ struct OgaGenerator : OgaAbstract {
return OgaGenerator_GetSequenceData(this, index);
}

std::unique_ptr<OgaTensor> GetOutput(const char* name) {
OgaTensor* out;
OgaCheckResult(OgaGenerator_GetOutput(this, name, &out));
return std::unique_ptr<OgaTensor>(out);
}

#if __cplusplus >= 202002L
std::span<const int32_t> GetSequence(size_t index) const {
return {GetSequenceData(index), GetSequenceCount(index)};
Expand Down
44 changes: 44 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,50 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator)
OGA_CATCH
}

/*
 * Copies the named model output into a freshly allocated CPU tensor and exposes it as an OgaTensor.
 * The readback is device-aware: CUDA outputs come back via cudaMemcpy, DML outputs via the model's
 * readback heap, and CPU outputs via a plain byte copy. Returns nullptr on success, or an OgaResult
 * describing the failure (OGA_TRY/OGA_CATCH convert thrown exceptions into OgaResults).
 */
OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) {
  OGA_TRY
  auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
  auto* ortvalue_output = generator.state_->GetOutput(name);
  auto type_info = ortvalue_output->GetTensorTypeAndShapeInfo();
  // Always clone into CPU memory so the caller can read the data regardless of the execution device.
  std::unique_ptr<OrtValue> ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_,
                                                                    type_info->GetShape(),
                                                                    type_info->GetElementType());
  // Copy data to ortvalue_clone
  auto element_size = Generators::SizeOf(type_info->GetElementType());
  auto data_size = type_info->GetElementCount() * element_size;
  auto device_type = ortvalue_output->GetTensorMemoryInfo().GetDeviceType();
  // Tracks whether any branch actually performed the copy; previously an unhandled device fell
  // through silently and returned an uninitialized tensor (the old "Add else statement" TODO).
  bool copied = false;
  if (device_type == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) {
#if USE_CUDA
    // Device-to-host copy; also check the status instead of ignoring it.
    if (cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(),
                   data_size, cudaMemcpyDeviceToHost) != cudaSuccess)
      throw std::runtime_error("OgaGenerator_GetOutput: cudaMemcpy from device to host failed");
    copied = true;
#endif
  } else if (device_type == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) {
#if USE_DML
    // Resolve the D3D12 resource backing the DML allocation, then read it back into the CPU clone.
    ComPtr<ID3D12Resource> gpu_resource;
    Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation(
        generator.model_->allocator_device_,
        ortvalue_output->GetTensorMutableRawData(),
        &gpu_resource));
    auto cpu_tensor = ortvalue_clone->GetTensorMutableRawData();
    generator.model_->GetDmlReadbackHeap()->ReadbackFromGpu(
        std::span(reinterpret_cast<uint8_t*>(cpu_tensor), data_size),
        gpu_resource.Get(),
        0,
        D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
    copied = true;
#endif
  } else if (device_type == OrtMemoryInfoDeviceType_CPU) {
    std::copy(static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()),
              static_cast<uint8_t*>(ortvalue_output->GetTensorMutableRawData()) + data_size,
              static_cast<uint8_t*>(ortvalue_clone->GetTensorMutableRawData()));
    copied = true;
  }

  // Surface an explicit error when the output lives on a device this build cannot read back
  // (e.g. GPU memory in a build compiled without USE_CUDA/USE_DML, or an unknown device type).
  if (!copied)
    throw std::runtime_error("OgaGenerator_GetOutput: output tensor is on a device this build cannot read back");

  // The Tensor keeps itself alive via external_owner_ until the caller destroys the OgaTensor.
  auto tensor = std::make_shared<Generators::Tensor>(std::move(ortvalue_clone));
  tensor->external_owner_ = tensor;
  *out = reinterpret_cast<OgaTensor*>(tensor.get());
  return nullptr;
  OGA_CATCH
}

size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* oga_generator, size_t index) {
auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
return generator.GetSequence(static_cast<int>(index)).GetCPU().size();
Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,13 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);

/*
 * \brief Fetches the model output with the given name and returns a copy of it as a newly
 *        created OgaTensor backed by CPU memory (device outputs are read back to the host).
 * \param[in] oga_generator The generator whose current output to read.
 * \param[in] name The name of the model output to fetch, e.g. "logits".
 * \param[out] out Receives the newly created OgaTensor holding the CPU copy of the data.
 * \return OgaResult containing the error message if the computation failed.
 */
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out);
ajindal1 marked this conversation as resolved.
Show resolved Hide resolved

/*
* \brief Returns the number of tokens in the sequence at the given index.
* \param[in] generator The generator to get the count of the tokens for the sequence at the given index.
Expand Down
61 changes: 61 additions & 0 deletions test/c_api_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,67 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
}
#endif

#if !USE_DML
// Exercises OgaGenerator::GetOutput by fetching the "logits" output after ComputeLogits
// and comparing sparsely sampled values (1 of every 200) against pre-recorded expectations,
// first for the prompt pass and then for the first generated token.
TEST(CAPITests, GetOutputCAPI) {
std::vector<int64_t> input_ids_shape{2, 4};
std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};

auto input_sequence_length = input_ids_shape[1];
auto batch_size = input_ids_shape[0];
int max_length = 10;

// To generate this file:
// python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
// And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2)

auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");

auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", max_length);
params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size);

auto generator = OgaGenerator::Create(*model, *params);

// check prompt
// full logits has shape [2, 4, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 4, 5]
std::vector<float> expected_sampled_logits_prompt{0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f,
0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f,
0.30323744f, 0.0545997f, 0.03894716f, 0.11702324f, 0.0410665f,
-0.12675379f, -0.04443946f, 0.14492269f, 0.03021223f, -0.03212897f,
0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f,
0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f,
-0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f,
0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f};

// Logits must be materialized before GetOutput("logits") has anything to return.
generator->ComputeLogits();
auto prompt_logits_ptr = generator->GetOutput("logits");
// GetOutput returns a CPU copy; Data() is the raw buffer, reinterpreted as float.
auto prompt_logits = static_cast<float*>(prompt_logits_ptr->Data());
int num_prompt_outputs_to_check = 40;
int sample_size = 200;
float tolerance = 0.001f;
// Verify outputs match expected outputs
for (int i = 0; i < num_prompt_outputs_to_check; i++) {
EXPECT_NEAR(expected_sampled_logits_prompt[i], prompt_logits[i*sample_size], tolerance);
}

generator->GenerateNextToken();
// check for the 1st token generation
// full logits has shape [2, 1, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 1, 5]
std::vector<float> expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f,
0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f};

generator->ComputeLogits();
auto token_gen_logits_ptr = generator->GetOutput("logits");
auto token_gen_logits = static_cast<float*>(token_gen_logits_ptr->Data());
int num_token_gen_outputs_to_check = 10;

for (int i = 0; i < num_token_gen_outputs_to_check; i++) {
EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance);
}
generator->GenerateNextToken();
}
#endif

#if TEST_PHI2

struct Phi2Test {
Expand Down
Loading