estimate token use before sending openai completions #1112
base: main
Conversation
Signed-off-by: Jeffrey Martin <[email protected]>
The issue was identified when attempting to validate this linked comment.
Many good questions, will respond. We would love this for …
This is implemented in …
The idea is good. Some possible issues around `max_tokens`, `context_len` and `deprefix`.
```python
if (
    hasattr(self, "context_len")
    and self.context_len is not None
    and generation_max_tokens > self.context_len
```
Some models return the prompt in their output. In these cases, `deprefix` should be asserted. Thus, the status of `deprefix` may have implications for the output token budget.
`deprefix` is not passed to the model in OpenAI client `create()` calls, hence it is not included in this evaluation.
```python
# basic token boundary validation to ensure requests are not rejected for exceeding target context length
generation_max_tokens = create_args.get("max_tokens", None)
if generation_max_tokens is not None:
    # count tokens in prompt and ensure max_tokens requested is <= context_len allowed
```
`max_tokens` and `context_len` are only related if `deprefix` is asserted.
The OpenAI client `create()` does not accept `deprefix` as a named param, so it is not passed by the generator call. If support for passing `deprefix` in some way is added to the generator in the future, we can rethink this calculation.
```python
def create(
    self,
    *,
    messages: Iterable[ChatCompletionMessageParam],
    model: Union[str, ChatModel],
    audio: Optional[ChatCompletionAudioParam] | NotGiven = NOT_GIVEN,
    frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
    function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
    functions: Iterable[completion_create_params.Function] | NotGiven = NOT_GIVEN,
    logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
    logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
    max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN,
    max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
    metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN,
    modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN,
    n: Optional[int] | NotGiven = NOT_GIVEN,
    parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
    prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN,
    presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
    reasoning_effort: ChatCompletionReasoningEffort | NotGiven = NOT_GIVEN,
    response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
    seed: Optional[int] | NotGiven = NOT_GIVEN,
    service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN,
    stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
    store: Optional[bool] | NotGiven = NOT_GIVEN,
    stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
    stream_options: Optional[ChatCompletionStreamOptionsParam] | NotGiven = NOT_GIVEN,
    temperature: Optional[float] | NotGiven = NOT_GIVEN,
    tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
    tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
    top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
    top_p: Optional[float] | NotGiven = NOT_GIVEN,
    user: str | NotGiven = NOT_GIVEN,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
```
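Since `deprefix` does not appear among the named parameters above, passing it straight through would fail client-side before any request is sent. A minimal illustration, assuming a standard `OpenAI` client and that `OPENAI_API_KEY` is set in the environment (not code from this PR):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
try:
    client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        deprefix=True,  # not a named parameter of create(), so Python rejects the call
    )
except TypeError as exc:
    print(exc)  # e.g. "got an unexpected keyword argument 'deprefix'"
```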
```python
logging.warning(
    f"Requested max_tokens {generation_max_tokens} exceeds context length {self.context_len}, reducing requested maximum"
)
generation_max_tokens = self.context_len
```
It looks like this disregards `max_tokens` if `context_len` is not `None`, is that right? What's the intuition behind this? The intent is that `max_tokens` constrains generation (which is unbounded for most models, timeouts notwithstanding), and that `context_len` describes the fixed length of the input that's hard-predicated on model architecture.
This PR is based on observed behavior of the OpenAI endpoints; it attempts to ensure a valid request can be made. OpenAI services set an upper bound on `max_tokens` when it is passed as part of the `create()` request, and return a 400 stating that, if the param is passed, `prompt` + `max_tokens` must be less than the `context_length` defined by the service for the model.

Hence, if we know enough about the target model in this runtime, we can make a best-effort estimate to avoid bashing against a brick wall making requests we can predict will return no valid inference response. If the runtime does not know the `context_len` value ahead of time, or `max_tokens` has been suppressed, there is not enough information to make any prediction and execution will make the request.
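To make that concrete, here is a minimal sketch of such a pre-flight check, assuming `tiktoken` for the prompt-token estimate; the helper name and exact logic are illustrative, not the PR's actual implementation:

```python
from typing import Optional

import tiktoken


def clamp_max_tokens(
    prompt: str,
    model_name: str,
    requested_max_tokens: Optional[int],
    context_len: Optional[int],
) -> Optional[int]:
    """Cap max_tokens so prompt tokens + max_tokens fits the model's context window.

    Hypothetical helper for illustration; the PR's code and names may differ.
    """
    if requested_max_tokens is None or context_len is None:
        # not enough information to predict anything; send the request as-is
        return requested_max_tokens

    # never request more than the model's total context window
    max_tokens = min(requested_max_tokens, context_len)

    # estimate prompt tokens; tiktoken is one common choice for OpenAI models
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = len(encoding.encode(prompt))

    # leave room for the prompt; if this drops to zero or below, the prompt alone
    # already exceeds the context window and the request is predicted to fail
    if prompt_tokens + max_tokens > context_len:
        max_tokens = context_len - prompt_tokens
    return max_tokens
```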
OK, that's crucial, good to know. We should document this here with reference to an OpenAI uri. Is variable name usage consistent with elsewhere in garak?
> Is variable name usage consistent with elsewhere in garak?

`self.uri` is the endpoint targeted for all classes that extend `OpenAICompatible`, if that is the question.

As to documenting, I could see adding some context noting that the assumptions made here are based on the OpenAI API spec.

As a future iteration it may be of value to evaluate whether shifting `max_tokens` to `max_completion_tokens` is appropriate. The deprecation of the option by OpenAI may end up causing some fragmentation in the meaning of `max_tokens` for generators in general in `garak`.
I think we're suffering from an overloading of `max_tokens`, which has different semantics within garak and for OpenAI.

With:

> if max_tokens allowed is above the model supported context the context_len is held as the max_tokens for the request

is this saying that if the `max_tokens` value passed in the API call is above the model-supported context length `context_len`, the `context_len` is used as the `max_tokens` value for the call? If so - can you run through the logic behind this in simple, verbose, explicit terms?
When making a request to OpenAI, the `create()` passed parameter `max_tokens` + `prompt_tokens` must be less than the model-defined context length the service supports, or `create()` will cancel before attempting to process the prompt.

If the user configures `garak`'s `max_tokens` to, say, `20000` and the target model is `gpt-3.5-turbo-instruct`, the model context length supported by OpenAI is `4096`. This new check will do the following to allow requests to be made to `gpt-3.5-turbo-instruct` when the prompt is short enough to get at least 1 token back in the response:
- Set the initial value to be passed to `create()` for `max_tokens` to the `garak`-configured value: `20000`.
- Check the value against the model's context length of `4096`; since `20000` is more than `4096`, we constrain the request to call `create()` with at most the context length the model can support, setting it to `4096`.
- Next we estimate the prompt tokens; for this example use `1000` as the estimate.
- Subtract the estimate `1000` from the available token length `4096` and set the max generated additional tokens to `3096`.
- Call `create()` with the `1000`-token prompt and `max_tokens` as `3096`.
Now a scenario where the model has plenty of context length, such as `gpt-4-turbo` with context length support at `128000`:
- Set the initial value to be passed to `create()` for `max_tokens` to the `garak`-configured value: `20000`.
- Check the value against the model's context length of `128000`; since `20000` is less than `128000`, we determine the model can support setting it to the user-requested `20000`.
- Next we estimate the prompt tokens; for this example use `1000` as the estimate.
- Subtract the estimate `1000` from the available requested max token length `20000` and set the max generated additional tokens to `19000`.
- Call `create()` with the `1000`-token prompt and `max_tokens` as `19000`.
This constrains `max_tokens` for `garak` as a maximum budget for the total number of tokens in each request, not the total number of tokens to be generated as output. It also allows any request that would not exceed that threshold to be processed against models that have a maximum context length smaller than the user-requested upper bound.
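As a quick numeric check of the two walkthroughs above, the same budget arithmetic in a few lines (the `1000`-token prompt estimate is just the illustrative figure used above):

```python
def effective_max_tokens(requested: int, context_len: int, prompt_tokens: int) -> int:
    """Reproduce the budget arithmetic from the walkthrough above (illustrative only)."""
    capped = min(requested, context_len)  # never exceed the model's context window
    return capped - prompt_tokens         # leave room for the prompt itself


# gpt-3.5-turbo-instruct: 20000 requested, 4096 context, ~1000 prompt tokens
assert effective_max_tokens(20000, 4096, 1000) == 3096

# gpt-4-turbo: 20000 requested, 128000 context, ~1000 prompt tokens
assert effective_max_tokens(20000, 128000, 1000) == 19000
```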
Another possible approach could be to simply not pass `max_tokens` to a model that has a known context length smaller than the user-provided `max_tokens`. However, this may still result in an error response if the prompt itself were to exceed the model context length. Open to reducing complexity in this way if the team thinks the value trade-off is acceptable.
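For comparison, that simpler alternative might look roughly like the following (hypothetical sketch reusing the `create_args` and `context_len` names from the diff above; not code from this PR):

```python
# Hypothetical sketch: drop max_tokens entirely when it cannot fit within the model's
# context window, and let the service apply its own default. Simpler, but the prompt
# itself can still overflow the context window and trigger a 400.
create_args = {"model": "gpt-3.5-turbo-instruct", "max_tokens": 20000}  # example values
context_len = 4096  # example value

requested = create_args.get("max_tokens")
if context_len is not None and requested is not None and requested > context_len:
    create_args.pop("max_tokens")  # the request now relies on the service default
```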
> max_tokens + prompt_tokens must be less than the model defined context length

Alright, I am going to have to take a moment with the API guide to get on top of this. OpenAI model input capacity must be greater than prompt length plus output length?

Examples make a ton of sense, thanks. This looks like a really helpful PR/feature.

-- I think the results might just be a few variable renaming suggestions. Will get back to this within a day or two.
Signed-off-by: Jeffrey Martin <[email protected]>
LGTM
Signed-off-by: Jeffrey Martin <[email protected]>
When setting `max_tokens` for services compliant with the OpenAI python client, the value passed to the client needs to be reduced to at most the model's supported context length, inclusive of the tokens in the prompt request. This revision validates the available context space before attempting to request inference, with the following behaviors:
Please review with an eye to desired runtime behavior: should the run be terminated if a prompt from a probe exceeds the context length of the target model, or should the run continue and simply log the skipped `Attempt`?

Error reported as a 400 response when the context length of the model is exceeded:
Test example:
high_tokens_config.yaml:
Logged error: