Skip to content

Commit

Permalink
Correct the types of registry and cli.oaievalset (openai#1027)
Browse files Browse the repository at this point in the history
Currently, the `registry` and `cli.oaievalset` modules won't be type-checked as
expected, due to a malformed wildcard in the Additional Sections of mypy.ini.

This PR:

- re-enables the type checks for these modules
- corrects and strengthens the type annotations of these modules
  • Loading branch information
pan93412 authored Jun 5, 2023
1 parent 723f669 commit c2c8abe
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 36 deletions.
37 changes: 28 additions & 9 deletions evals/cli/oaievalset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
"""
import argparse
import json
import logging
import subprocess
from pathlib import Path
from typing import Optional
from typing import Optional, cast

from evals.registry import Registry

Task = list[str]
logger = logging.getLogger(__name__)


class Progress:
Expand Down Expand Up @@ -61,17 +63,34 @@ def get_parser() -> argparse.ArgumentParser:
return parser


class OaiEvalSetArguments(argparse.Namespace):
model: str
eval_set: str
resume: bool
exit_on_error: bool


def run(
args, unknown_args, registry: Optional[Registry] = None, run_command: str = "oaieval"
args: OaiEvalSetArguments,
unknown_args: list[str],
registry: Optional[Registry] = None,
run_command: str = "oaieval",
) -> None:
registry = registry or Registry()
commands: list[Task] = []
eval_set = registry.get_eval_set(args.eval_set)
for eval in registry.get_evals(eval_set.evals):
command = [run_command, args.model, eval.key] + unknown_args
if command in commands:
continue
commands.append(command)
eval_set = registry.get_eval_set(args.eval_set) if args.eval_set else None
if eval_set:
for index, eval in enumerate(registry.get_evals(eval_set.evals)):
if not eval or not eval.key:
logger.debug("The eval #%d in eval_set is not valid", index)

command = [run_command, args.model, eval.key] + unknown_args
if command in commands:
continue
commands.append(command)
else:
logger.warning("No eval set found for %s", args.eval_set)

num_evals = len(commands)

progress = Progress(f"/tmp/oaievalset/{args.model}.{args.eval_set}.progress.txt")
Expand Down Expand Up @@ -100,7 +119,7 @@ def run(
def main() -> None:
parser = get_parser()
args, unknown_args = parser.parse_known_args()
run(args, unknown_args)
run(cast(OaiEvalSetArguments, args), unknown_args)


if __name__ == "__main__":
Expand Down
54 changes: 30 additions & 24 deletions evals/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
from functools import cached_property
from pathlib import Path
from typing import Any, Iterator, Optional, Sequence, Type, Union
from typing import Any, Iterator, Optional, Sequence, Type, TypeVar, Union

import openai
import yaml
Expand Down Expand Up @@ -73,15 +73,19 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]:
}


T = TypeVar("T")
RawRegistry = dict[str, Any]


class Registry:
def __init__(self, registry_paths: Sequence[Union[str, Path]] = DEFAULT_PATHS):
self._registry_paths = [Path(p) if isinstance(p, str) else p for p in registry_paths]

def add_registry_paths(self, paths: list[Union[str, Path]]):
def add_registry_paths(self, paths: list[Union[str, Path]]) -> None:
self._registry_paths.extend([Path(p) if isinstance(p, str) else p for p in paths])

@cached_property
def api_model_ids(self):
def api_model_ids(self) -> list[str]:
try:
return [m["id"] for m in openai.Model.list()["data"]]
except openai.error.OpenAIError as err:
Expand Down Expand Up @@ -119,10 +123,12 @@ def make_completion_fn(self, name: str) -> CompletionFn:
assert isinstance(instance, CompletionFn), f"{name} must be a CompletionFn"
return instance

def get_class(self, spec: dict) -> Any:
def get_class(self, spec: EvalSpec) -> Any:
return make_object(spec.cls, **(spec.args if spec.args else {}))

def _dereference(self, name: str, d: dict, object: str, type: Type, **kwargs: dict) -> dict:
def _dereference(
self, name: str, d: RawRegistry, object: str, type: Type[T], **kwargs: dict
) -> Optional[T]:
if not name in d:
logger.warning(
(
Expand All @@ -132,7 +138,7 @@ def _dereference(self, name: str, d: dict, object: str, type: Type, **kwargs: di
)
return None

def get_alias():
def get_alias() -> Optional[str]:
if isinstance(d[name], str):
return d[name]
if isinstance(d[name], dict) and "id" in d[name]:
Expand All @@ -157,7 +163,7 @@ def get_alias():
except TypeError as e:
raise TypeError(f"Error while processing {object} '{name}': {e}")

def get_modelgraded_spec(self, name: str, **kwargs: dict) -> dict[str, Any]:
def get_modelgraded_spec(self, name: str, **kwargs: dict) -> Optional[ModelGradedSpec]:
assert name in self._modelgraded_specs, (
f"Modelgraded spec {name} not found. "
f"Closest matches: {difflib.get_close_matches(name, self._modelgraded_specs.keys(), n=5)}"
Expand All @@ -166,18 +172,18 @@ def get_modelgraded_spec(self, name: str, **kwargs: dict) -> dict[str, Any]:
name, self._modelgraded_specs, "modelgraded spec", ModelGradedSpec, **kwargs
)

def get_completion_fn(self, name: str) -> CompletionFnSpec:
def get_completion_fn(self, name: str) -> Optional[CompletionFnSpec]:
return self._dereference(name, self._completion_fns, "completion_fn", CompletionFnSpec)

def get_eval(self, name: str) -> EvalSpec:
def get_eval(self, name: str) -> Optional[EvalSpec]:
return self._dereference(name, self._evals, "eval", EvalSpec)

def get_eval_set(self, name: str) -> EvalSetSpec:
def get_eval_set(self, name: str) -> Optional[EvalSetSpec]:
return self._dereference(name, self._eval_sets, "eval set", EvalSetSpec)

def get_evals(self, patterns: Sequence[str]) -> Iterator[EvalSpec]:
def get_evals(self, patterns: Sequence[str]) -> Iterator[Optional[EvalSpec]]:
# valid patterns: hello, hello.dev*, hello.dev.*-v1
def get_regexp(pattern):
def get_regexp(pattern: str) -> re.Pattern[str]:
pattern = pattern.replace(".", "\\.")
pattern = pattern.replace("*", ".*")
return re.compile(f"^{pattern}$")
Expand All @@ -188,14 +194,14 @@ def get_regexp(pattern):
if any(map(lambda regexp: regexp.match(name), regexps)):
yield self.get_eval(name)

def get_base_evals(self) -> list[BaseEvalSpec]:
base_evals = []
def get_base_evals(self) -> list[Optional[BaseEvalSpec]]:
base_evals: list[Optional[BaseEvalSpec]] = []
for name, spec in self._evals.items():
if name.count(".") == 0:
base_evals.append(self.get_base_eval(name))
return base_evals

def get_base_eval(self, name: str) -> BaseEvalSpec:
def get_base_eval(self, name: str) -> Optional[BaseEvalSpec]:
if not name in self._evals:
return None

Expand All @@ -210,11 +216,11 @@ def get_base_eval(self, name: str) -> BaseEvalSpec:
alias = spec_or_alias
return BaseEvalSpec(id=alias)

def _process_file(self, registry, path):
def _process_file(self, registry: RawRegistry, path: Path) -> None:
with open(path, "r", encoding="utf-8") as f:
d = yaml.safe_load(f)

if d is None:
if d is None or not isinstance(d, dict):
# no entries in the file
return

Expand All @@ -241,17 +247,17 @@ def _process_file(self, registry, path):
del spec["class"]
registry[name] = spec

def _process_directory(self, registry, path):
def _process_directory(self, registry: RawRegistry, path: Path) -> None:
files = Path(path).glob("*.yaml")
for file in files:
self._process_file(registry, file)

def _load_registry(self, paths):
def _load_registry(self, paths: Sequence[Path]) -> RawRegistry:
"""Load registry from a list of paths.
Each path or yaml specifies a dictionary of name -> spec.
"""
registry = {}
registry: RawRegistry = {}
for path in paths:
logging.info(f"Loading registry from {path}")
if os.path.exists(path):
Expand All @@ -262,19 +268,19 @@ def _load_registry(self, paths):
return registry

@functools.cached_property
def _completion_fns(self):
def _completion_fns(self) -> RawRegistry:
return self._load_registry([p / "completion_fns" for p in self._registry_paths])

@functools.cached_property
def _eval_sets(self):
def _eval_sets(self) -> RawRegistry:
return self._load_registry([p / "eval_sets" for p in self._registry_paths])

@functools.cached_property
def _evals(self):
def _evals(self) -> RawRegistry:
return self._load_registry([p / "evals" for p in self._registry_paths])

@functools.cached_property
def _modelgraded_specs(self):
def _modelgraded_specs(self) -> RawRegistry:
return self._load_registry([p / "modelgraded" for p in self._registry_paths])


Expand Down
10 changes: 8 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[mypy]
python_version=3.9

mypy_path=$MYPY_CONFIG_FILE_DIR/typings

; Not all dependencies have type annotations; ignore this.
ignore_missing_imports=True
namespace_packages=True
Expand All @@ -22,11 +24,11 @@ disallow_untyped_defs=False
; However, some directories that are fully type-annotated and don't have type errors have opted in
; to type checking.

[mypy-registry.*]
[mypy-evals.registry]
ignore_errors=False
disallow_untyped_defs=True

[mypy-oaievalset.*]
[mypy-evals.cli.oaievalset]
ignore_errors=False
disallow_untyped_defs=True

Expand All @@ -38,4 +40,8 @@ disallow_untyped_defs=True
ignore_errors=False
disallow_untyped_defs=True

[mypy-openai.*]
ignore_errors=False
disallow_untyped_defs=True

; TODO: Add the other modules here
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ dependencies = [
"matplotlib",
"pytest",
"setuptools_scm",
"langchain"
"langchain",
"types-PyYAML",
]

[project.scripts]
oaieval = "evals.cli.oaieval:main"
oaievalset = "evals.cli.oaievalset:main"

[tool.setuptools]
packages = ["evals"]
1 change: 1 addition & 0 deletions typings/openai/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model import Model as Model
15 changes: 15 additions & 0 deletions typings/openai/model.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Optional

from .response import ListResponse

class Model:
@classmethod
def list(
cls,
api_key: Optional[str] = ...,
request_id: Optional[str] = ...,
api_version: Optional[str] = ...,
organization: Optional[str] = ...,
api_base: Optional[str] = ...,
api_type: Optional[str] = ...,
) -> ListResponse: ...
15 changes: 15 additions & 0 deletions typings/openai/response.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any, Literal, TypedDict

class ListResponse(TypedDict):
"""Response from Model.list
Reference: https://platform.openai.com/docs/api-reference/models"""

object: Literal["list"]
data: list[Model]

class Model(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permission: list[Any] # TODO

0 comments on commit c2c8abe

Please sign in to comment.