diff --git a/council/llm/llm_answer.py b/council/llm/llm_answer.py index 33fc5422..ace088a8 100644 --- a/council/llm/llm_answer.py +++ b/council/llm/llm_answer.py @@ -1,6 +1,6 @@ from __future__ import annotations import inspect -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml @@ -127,6 +127,17 @@ def parse_yaml(self, bloc: str) -> Dict[str, Any]: raise LLMParsingException(f"Missing {missing_keys} in response.") return properties_dict + def parse_yaml_list(self, bloc: str) -> List[Dict[str, Any]]: + result = [] + d = yaml.safe_load(bloc) + for item in d: + properties_dict = {**item} + missing_keys = [key.name for key in self._properties if key.name not in properties_dict.keys()] + if len(missing_keys) > 0: + raise LLMParsingException(f"Missing {missing_keys} in response.") + result.append(properties_dict) + return result + def parse_yaml_bloc(self, bloc: str) -> Dict[str, Any]: code_bloc = CodeParser.find_first(language="yaml", text=bloc) if code_bloc is not None: diff --git a/council/utils/code_parser.py b/council/utils/code_parser.py index 683bdf7f..ef18bc74 100644 --- a/council/utils/code_parser.py +++ b/council/utils/code_parser.py @@ -44,7 +44,7 @@ def find_last(language: Optional[str] = None, text: str = "") -> Optional[CodeBl @staticmethod def _get_pattern(language: Optional[str]): - return r"```(\w*)\n(.*?)\n```" if language is None else rf"```({language})\n(.*?)\n```" + return r"```(\w*) *\n(.*?)\n```" if language is None else rf"```({language})\n(.*?)\n```" @staticmethod def _build_generator(language: Optional[str], text: str = "") -> Iterable[CodeBlock]: diff --git a/tests/unit/llm/test_llm_answer.py b/tests/unit/llm/test_llm_answer.py index e66102b1..6bc2cf4e 100644 --- a/tests/unit/llm/test_llm_answer.py +++ b/tests/unit/llm/test_llm_answer.py @@ -26,11 +26,25 @@ def test_llm_parse_yaml_answer(self): llma = LLMAnswer(Specialist) bloc = """ - ```yaml - ControllerScore: - name: first - score: 10 - instructions: do this - justification: because - ```""" - print(llma.parse_yaml_bloc(bloc)) +```yaml +name: first +score: 10 +instructions: do this +justification: because +``` +""" + result = llma.parse_yaml_bloc(bloc) + instance = Specialist(**result) + self.assertEqual(instance.name, "first") + + def test_parse_dict(self): + instance = LLMAnswer(Specialist) + bloc = """ +- name: first + score: 10 + instructions: + - do + - this + justification: because +""" + print(instance.parse_yaml_list(bloc))