forked from NVIDIA/NeMo-Guardrails
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request NVIDIA#370 from botitai/feature/gotitai-truthchecker
Add Got It AI's Truthchecking service for RAG applications
- Loading branch information
Showing
10 changed files
with
287 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import json | ||
import logging | ||
import os | ||
from typing import Optional | ||
|
||
import aiohttp | ||
|
||
from nemoguardrails.actions import action | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@action(name="call gotitai truthchecker api", is_system_action=True) | ||
async def call_gotitai_truthchecker_api(context: Optional[dict] = None): | ||
api_key = os.environ.get("GOTITAI_API_KEY") | ||
|
||
if api_key is None: | ||
raise ValueError("GOTITAI_API_KEY environment variable not set.") | ||
|
||
if context is None: | ||
raise ValueError( | ||
"Context is empty. `user_message`, `bot_response` and `relevant_chunks` keys are required to call the GotIt AI Truthchecker api." | ||
) | ||
|
||
user_message = context.get("user_message", "") | ||
response = context.get("bot_message", "") | ||
knowledge = context.get("relevant_chunks_sep", []) | ||
|
||
retval = {"hallucination": None} # in case the api call is skipped | ||
|
||
if not isinstance(knowledge, list): | ||
log.error( | ||
"Could not run Got It AI Truthchecker. `relevant_chunks_sep` must be a list of knowledge." | ||
) | ||
return retval | ||
|
||
if not knowledge: | ||
log.error( | ||
"Could not run Got It AI Truthchecker. At least 1 relevant chunk is required." | ||
) | ||
return retval | ||
|
||
url = "https://api.got-it.ai/api/v1/hallucination-manager/truthchecker" | ||
headers = { | ||
"Content-Type": "application/json", | ||
"Authorization": "Bearer " + api_key, | ||
} | ||
data = { | ||
"knowledge": [ | ||
{ | ||
"text": chunk, | ||
} | ||
for chunk in knowledge | ||
], | ||
"prompt": user_message, | ||
"generated_text": response, | ||
# Messages is empty for now since there is no standard way to get them. | ||
# This should be updated once 0.8.0 is released. | ||
# Reference: https://github.com/NVIDIA/NeMo-Guardrails/issues/246 | ||
"messages": [], | ||
} | ||
|
||
async with aiohttp.ClientSession() as session: | ||
async with session.post( | ||
url=url, | ||
headers=headers, | ||
json=data, | ||
) as response: | ||
if response.status != 200: | ||
log.error( | ||
f"GotItAI TruthChecking call failed with status code {response.status}.\n" | ||
f"Details: {await response.json()}" | ||
) | ||
response_json = await response.json() | ||
log.info(json.dumps(response_json, indent=True)) | ||
hallucination = response_json["hallucination"] | ||
retval = {"hallucination": hallucination} | ||
|
||
return retval |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
define subflow gotitai rag truthcheck | ||
"""Guardrail based on the maximum risk score.""" | ||
if $check_facts == True | ||
$check_facts = False | ||
|
||
$result = execute call gotitai truthchecker api | ||
|
||
if $result.hallucination == "yes" | ||
bot inform answer unknown | ||
stop |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
models: | ||
- type: main | ||
engine: openai | ||
model: gpt-3.5-turbo-instruct | ||
rails: | ||
output: | ||
flows: | ||
- gotitai rag truthcheck |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
define user ask general question | ||
"Do you ship within 2 days?" | ||
|
||
define flow | ||
user ask general question | ||
$check_facts = True | ||
bot provide answer | ||
|
||
define bot inform answer unknown | ||
"I don't know the answer to that." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
import pytest | ||
from aioresponses import aioresponses | ||
|
||
from nemoguardrails import RailsConfig | ||
from nemoguardrails.actions.actions import ActionResult, action | ||
from tests.utils import TestChat | ||
|
||
CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs") | ||
|
||
GOTITAI_API_URL = "https://api.got-it.ai/api/v1/hallucination-manager/truthchecker" | ||
|
||
|
||
@action(is_system_action=True) | ||
async def retrieve_relevant_chunks(): | ||
"""Retrieve relevant chunks from the knowledge base and add them to the context.""" | ||
context_updates = {} | ||
context_updates["relevant_chunks_sep"] = ["Shipping takes at least 3 days."] | ||
|
||
return ActionResult( | ||
return_value=context_updates["relevant_chunks_sep"], | ||
context_updates=context_updates, | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_hallucination(monkeypatch): | ||
monkeypatch.setenv("GOTITAI_API_KEY", "xxx") | ||
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "gotitai_truthchecker")) | ||
chat = TestChat( | ||
config, | ||
llm_completions=[ | ||
"user ask general question", # user intent | ||
"Yes, shipping can be done in 2 days.", # bot response that will be intercepted | ||
], | ||
) | ||
|
||
with aioresponses() as m: | ||
chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") | ||
m.post( | ||
GOTITAI_API_URL, | ||
payload={ | ||
"hallucination": "yes", | ||
}, | ||
) | ||
|
||
chat >> "Do you ship within 2 days?" | ||
await chat.bot_async("I don't know the answer to that.") | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_not_hallucination(monkeypatch): | ||
monkeypatch.setenv("GOTITAI_API_KEY", "xxx") | ||
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "gotitai_truthchecker")) | ||
chat = TestChat( | ||
config, | ||
llm_completions=[ | ||
# " express greeting", | ||
"user ask general question", # user intent | ||
"No, shipping takes at least 3 days.", # bot response that will not be intercepted | ||
], | ||
) | ||
|
||
with aioresponses() as m: | ||
chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks") | ||
m.post( | ||
GOTITAI_API_URL, | ||
payload={ | ||
"hallucination": "no", | ||
}, | ||
) | ||
|
||
chat >> "Do you ship within 2 days?" | ||
await chat.bot_async("No, shipping takes at least 3 days.") | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_no_context(monkeypatch): | ||
monkeypatch.setenv("GOTITAI_API_KEY", "xxx") | ||
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "gotitai_truthchecker")) | ||
chat = TestChat( | ||
config, | ||
llm_completions=[ | ||
# " express greeting", | ||
"user ask general question", # user intent | ||
"Yes, shipping can be done in 2 days.", # bot response that will not be intercepted | ||
], | ||
) | ||
|
||
with aioresponses() as m: | ||
m.post( | ||
GOTITAI_API_URL, | ||
payload={ | ||
"hallucination": None, | ||
}, | ||
) | ||
|
||
chat >> "Do you ship within 2 days?" | ||
await chat.bot_async("Yes, shipping can be done in 2 days.") |