Skip to content

Commit

Permalink
processor registry
Browse files · Browse the repository at this point in the history
  • Loading branch information
yanxi0830 committed Oct 14, 2024
1 parent 95fd53d commit a22c31b
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 10 deletions.
13 changes: 9 additions & 4 deletions llama_stack/apis/evals/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,21 @@ class EvaluatePostprocessConfig(BaseModel):
kwargs: Optional[Dict[str, Any]] = None


@json_schema_type
class EvaluateProcessorConfig(BaseModel):
processor_identifier: str
preprocess_config: Optional[EvaluatePreprocessConfig] = None
postprocess_config: Optional[EvaluatePostprocessConfig] = None


@json_schema_type
class EvaluateJudgeScoringConfig(BaseModel): ...


@json_schema_type
class LLMJudgeConfig(BaseModel):
judge_preprocess_config: EvaluatePreprocessConfig
judge_processor_config: EvaluateProcessorConfig
judge_model_generation_config: EvaluateModelGenerationConfig
judge_postprocess_config: EvaluatePostprocessConfig
judge_scoring_config: EvaluateJudgeScoringConfig


Expand All @@ -116,9 +122,8 @@ class EvaluateScoringConfig(BaseModel):
@json_schema_type
class EvaluateTaskConfig(BaseModel):
dataset_config: EvaluateDatasetConfig
preprocess_config: Optional[EvaluatePreprocessConfig] = None
processor_config: EvaluateProcessorConfig
generation_config: EvaluateModelGenerationConfig
postprocess_config: Optional[EvaluatePostprocessConfig] = None
scoring_config: EvaluateScoringConfig


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.processor import * # noqa: F403

from ..registry import Registry

# TODO: decide whether we should group dataset+processor together via Tasks
GeneratorProcessorRegistry = Registry[BaseGeneratorProcessor]()

class GeneratorProcessorRegistry(Registry[BaseGeneratorProcessor]):
_REGISTRY: Dict[str, BaseGeneratorProcessor] = {}
PROCESSOR_REGISTRY = {
"mmlu": MMLUProcessor,
}

for k, v in PROCESSOR_REGISTRY.items():
GeneratorProcessorRegistry.register(k, v)
3 changes: 3 additions & 0 deletions llama_stack/providers/impls/meta_reference/evals/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ async def run_eval_task(
dataset_name=dataset,
row_limit=3,
),
processor_config=EvaluateProcessorConfig(
processor_identifier="mmlu",
),
generation_config=EvaluateModelGenerationConfig(
model=model,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mmlu_processor import MMLUProcessor # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .basic_scorers import * # noqa: F401 F403
from .aggregate_scorer import * # noqa: F401 F403
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.registry.datasets import DatasetRegistry
from llama_stack.distribution.registry.generator_processors import (
GeneratorProcessorRegistry,
)
from llama_stack.distribution.registry.scorers import ScorerRegistry

from llama_stack.providers.impls.meta_reference.evals.scorer.aggregate_scorer import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
InferenceGenerator,
)
from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
MMLUProcessor,
)


from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
Expand Down Expand Up @@ -46,7 +48,10 @@ async def run(
print(f"Running on {len(dataset)} samples")

# F1
processor = MMLUProcessor()
print(GeneratorProcessorRegistry.names())
processor = GeneratorProcessorRegistry.get(
eval_task_config.processor_config.processor_identifier
)()
preprocessed = processor.preprocess(dataset)

# Generation
Expand Down

0 comments on commit a22c31b

Please sign in to comment.