Skip to content

Commit

Permalink
streamline openai param validation & suppression - predicate it on ap…
Browse files Browse the repository at this point in the history
…i signature
  • Loading branch information
leondz committed Sep 23, 2024
1 parent ad501ef commit 578a7d3
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* https://platform.openai.com/docs/model-index-for-researchers
"""

import inspect
import json
import logging
import re
Expand Down Expand Up @@ -206,15 +207,15 @@ def _call_model(
"seed": self.seed,
}

create_args = {
k: v
for k, v in create_args.items()
if v is not None and k not in self.suppressed_params
}

for k, v in self.custom_params.items():
if k not in self.suppressed_params:
create_args[k] = v
create_args = {}
if "n" not in self.suppressed_params:
create_args["n"] = generations_this_call
for arg in inspect.signature(self.generator.create).parameters:
if arg == "model":
create_args[arg] = self.name
continue
if hasattr(self, arg) and arg not in self.suppressed_params:
create_args[arg] = getattr(self, arg)

if self.generator == self.client.completions:
if not isinstance(prompt, str):
Expand Down Expand Up @@ -297,7 +298,7 @@ def _load_client(self):
)

if self.__class__.__name__ == "OpenAIGenerator" and self.name.startswith("o1-"):
msg = "o1 models should use openai.ReasoningGenerator"
msg = "'o1'-class models should use openai.OpenAIReasoningGenerator. Try e.g. `-m openai.OpenAIReasoningGenerator` instead of `-m openai`"
logging.error(msg)
raise garak.exception.BadGeneratorException("🛑 " + msg)

Expand All @@ -314,7 +315,7 @@ def __init__(self, name="", config_root=_config):
super().__init__(self.name, config_root=config_root)


class ReasoningGenerator(OpenAIGenerator):
class OpenAIReasoningGenerator(OpenAIGenerator):
"""Generator wrapper for OpenAI reasoning models, e.g. `o1` family."""

supports_multiple_generations = False
Expand All @@ -327,9 +328,7 @@ class ReasoningGenerator(OpenAIGenerator):
"stop": ["#", ";"],
"suppressed_params": set(["n", "temperature", "max_tokens", "stop"]),
"retry_json": True,
"custom_params": {
"max_completion_tokens": 1500,
},
"max_completion_tokens": 1500,
}


Expand Down

0 comments on commit 578a7d3

Please sign in to comment.