Skip to content

Commit

Permalink
fix review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
varunshankar committed Oct 10, 2023
1 parent 642aea1 commit e278ed8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion data/test_prompt_templates/test_prompt.jinja2
Original file line number Diff line number Diff line change
@@ -1 +1 @@
This is a test prompt.
This is a test prompt.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def __init__(
) -> None:
self.config = {**self.get_default_config(), **config}
self.prompt_template = get_prompt_template(
DEFAULT_COMMAND_PROMPT_TEMPLATE,
config.get("prompt"),
DEFAULT_COMMAND_PROMPT_TEMPLATE,
)
self._model_storage = model_storage
self._resource = resource
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def embedder_factory(


def get_prompt_template(
default_prompt_template: Text, jinja_file_path: Optional[Text]
jinja_file_path: Optional[Text], default_prompt_template: Text
) -> Text:
"""Returns the prompt template.
Expand Down
36 changes: 18 additions & 18 deletions tests/cdu/generator/test_llm_command_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ def model_storage(tmp_path_factory: TempPathFactory) -> ModelStorage:
return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))


@pytest.mark.parametrize(
    "config, expected",
    [
        (
            # Case 1: a custom jinja2 prompt file is configured via the
            # "prompt" key — the file's content should become the template.
            {"prompt": "data/test_prompt_templates/test_prompt.jinja2"},
            "This is a test prompt.",
        ),
        (
            # Case 2: empty config — the generator should fall back to the
            # built-in default command prompt template.
            {},
            "Your task is to analyze the current conversation",
        ),
    ],
)
async def test_llm_command_generator_prompt_initialisation(
    model_storage: ModelStorage, resource: Resource, config: Dict, expected: str
) -> None:
    """Check the prompt template resolved by `LLMCommandGenerator.__init__`.

    For each parametrized case, the generator's `prompt_template` must start
    with the text expected for that configuration (custom file vs. default).
    """
    generator = LLMCommandGenerator(config, model_storage, resource)
    assert generator.prompt_template.startswith(expected)
async def test_llm_command_generator_prompt_init_custom(
    model_storage: ModelStorage, resource: Resource
) -> None:
    """A custom jinja2 file configured under "prompt" overrides the default template."""
    custom_config = {"prompt": "data/test_prompt_templates/test_prompt.jinja2"}
    command_generator = LLMCommandGenerator(custom_config, model_storage, resource)
    assert command_generator.prompt_template.startswith("This is a test prompt.")


async def test_llm_command_generator_prompt_init_default(
    model_storage: ModelStorage, resource: Resource
) -> None:
    """An empty config falls back to the built-in default command prompt template."""
    command_generator = LLMCommandGenerator({}, model_storage, resource)
    expected_prefix = "Your task is to analyze the current conversation"
    assert command_generator.prompt_template.startswith(expected_prefix)

0 comments on commit e278ed8

Please sign in to comment.