Add GetOutput for C API #755

Merged
merged 36 commits into from
Aug 15, 2024
Changes from 6 commits
Commits
abfd3e0
add GetOutput for C API
ajindal1 Aug 2, 2024
fe42340
update return var
ajindal1 Aug 2, 2024
1b10033
fix typo
ajindal1 Aug 3, 2024
b042eab
change gen name
ajindal1 Aug 3, 2024
a658c9d
change state
ajindal1 Aug 3, 2024
2f9c734
change var name
ajindal1 Aug 5, 2024
2134e24
make a copy of buffer
ajindal1 Aug 6, 2024
65fd9c0
modify copy of data
ajindal1 Aug 6, 2024
545034f
add use cuda
ajindal1 Aug 6, 2024
ec4f6cf
modify std copy
ajindal1 Aug 6, 2024
842ca97
Merge branch 'main' of github.com:microsoft/onnxruntime-genai into ab…
ajindal1 Aug 7, 2024
471d52b
change allocator device and add for DML
ajindal1 Aug 8, 2024
eaa9211
change allocator cpu
ajindal1 Aug 8, 2024
c2c2816
update dml code
ajindal1 Aug 8, 2024
63ece2b
update dml device type
ajindal1 Aug 8, 2024
cbeaefe
add test for GetOutput
ajindal1 Aug 8, 2024
5eabc4d
update copy type
ajindal1 Aug 8, 2024
b446022
use output data
ajindal1 Aug 8, 2024
3ac992d
cast to float
ajindal1 Aug 9, 2024
3df46f1
fix typo
ajindal1 Aug 9, 2024
20115b9
logging
ajindal1 Aug 9, 2024
b4aaf5b
add logging
ajindal1 Aug 9, 2024
6873693
typo fix
ajindal1 Aug 9, 2024
5fa23a0
add missing code for compute logits
ajindal1 Aug 9, 2024
c022412
add more logs and conditions
ajindal1 Aug 9, 2024
df482ad
restructure
ajindal1 Aug 9, 2024
5f63d00
use float and restructure code
ajindal1 Aug 9, 2024
533911d
typo
ajindal1 Aug 9, 2024
19e9774
fix typo
ajindal1 Aug 9, 2024
175a5c5
try more loggin
ajindal1 Aug 12, 2024
692fa3c
modify type
ajindal1 Aug 12, 2024
ac84157
logits separate ptr
ajindal1 Aug 12, 2024
38d9e68
convert double to float
ajindal1 Aug 12, 2024
5f8d9bf
add info about the fn
ajindal1 Aug 13, 2024
cf2d2e7
updata comments
ajindal1 Aug 15, 2024
2887997
test for DML
ajindal1 Aug 15, 2024
6 changes: 6 additions & 0 deletions src/ort_genai.h
@@ -232,6 +232,12 @@ struct OgaGenerator : OgaAbstract {
    return OgaGenerator_GetSequenceData(this, index);
  }

  OgaTensor* GetOutput(const char* name) {
    OgaTensor* out;
    OgaCheckResult(OgaGenerator_GetOutput(this, name, &out));
    return out;
  }

#if __cplusplus >= 202002L
  std::span<const int32_t> GetSequence(size_t index) const {
    return {GetSequenceData(index), GetSequenceCount(index)};
9 changes: 9 additions & 0 deletions src/ort_genai_c.cpp
@@ -208,6 +208,15 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator)
  OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) {
  OGA_TRY
  auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
  auto* ortValueOutput = generator.state_->GetOutput(name);
  *out = reinterpret_cast<OgaTensor*>(ortValueOutput);
  return nullptr;
  OGA_CATCH
}

size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* oga_generator, size_t index) {
  auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
  return generator.GetSequence(static_cast<int>(index)).GetCPU().size();
2 changes: 2 additions & 0 deletions src/ort_genai_c.h
@@ -224,6 +224,8 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out);

/*
 * \brief Returns the number of tokens in the sequence at the given index.
 * \param[in] generator The generator to get the count of the tokens for the sequence at the given index.
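
The three hunks above follow one pattern: a C entry point returns a result pointer (null on success) and writes the tensor through an out-parameter, and a thin C++ method turns a non-null result into an exception. A minimal, self-contained sketch of that pattern is below. Every name in it (`MockTensor`, `MockResult`, `MockGenerator`, `MockGenerator_GetOutput`, `MockCheckResult`, `GetOutput`) is a hypothetical stand-in invented for illustration and is not part of onnxruntime-genai. The sketch also assumes the callee hands back a copy that the caller owns, which later commit messages ("make a copy of buffer") hint at but the six-commit diff shown here does not confirm.

```cpp
#include <exception>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins; NOT the real OgaTensor / OgaResult / OgaGenerator.
struct MockTensor {
  std::vector<float> data;
};

struct MockResult {
  std::string message;
};

struct MockGenerator {
  std::map<std::string, MockTensor> outputs;  // named model outputs
};

// C-style entry point: returns nullptr on success or an error object on
// failure, and writes the tensor through the out-parameter.
// Assumption: the tensor is copied and the caller takes ownership.
MockResult* MockGenerator_GetOutput(const MockGenerator* generator,
                                    const char* name, MockTensor** out) {
  try {
    auto it = generator->outputs.find(name);
    if (it == generator->outputs.end())
      throw std::runtime_error(std::string("unknown output: ") + name);
    *out = new MockTensor(it->second);
    return nullptr;
  } catch (const std::exception& e) {
    return new MockResult{e.what()};
  }
}

// Converts a non-null error result into a C++ exception and frees it.
void MockCheckResult(MockResult* result) {
  if (result) {
    std::string message = result->message;
    delete result;
    throw std::runtime_error(message);
  }
}

// C++ convenience wrapper in the spirit of the GetOutput method in the diff.
// Unlike the diff, it initializes `out` and returns an owning smart pointer.
std::unique_ptr<MockTensor> GetOutput(const MockGenerator& generator,
                                      const char* name) {
  MockTensor* out = nullptr;
  MockCheckResult(MockGenerator_GetOutput(&generator, name, &out));
  return std::unique_ptr<MockTensor>(out);
}
```

With this mock, a caller would populate `outputs` with an entry such as `"logits"`, call `GetOutput(generator, "logits")`, and read `data` from the returned tensor; asking for a name that is not present throws `std::runtime_error` rather than returning a null pointer.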