Skip to content

Commit

Permalink
test for new parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
aflament committed Jan 29, 2024
1 parent 63f02a3 commit 04287f0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 20 deletions.
38 changes: 27 additions & 11 deletions council/skills/google/google_context/google_news.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
import logging

from typing import Optional, List, Any
Expand All @@ -20,19 +21,29 @@ class GoogleNewsSearchEngine(ContextProvider):

suffix: str = ""

def __init__(self, period: str, suffix: str):
def __init__(
self, period: Optional[str], suffix: str, start: Optional[datetime] = None, end: Optional[datetime] = None
):
super().__init__("google name")
self.google_news = GoogleNews(period=period)
self.google_news = GoogleNews()
if period is not None:
self.google_news.set_period(period)
elif start is not None:
if end is None:
end = datetime.now()
self.google_news.set_time_range(start.strftime("%Y-%m-%d"), end.strftime("%Y-%m-%d"))

self.suffix = suffix
self.google_news.enableException(enable=True)
self._max_page_num = 5

def execute_impl(self, query: str, nb_results: int) -> list[ResponseReference]:
self.google_news.clear()
results: List[Any] = []
try:
self.google_news.search(f"{query} {self.suffix}".replace(" ", "+"))
page_num = 1
while len(results) < nb_results:
while len(results) < nb_results and page_num <= self._max_page_num:
self.google_news.get_page(page_num)
google_news_result = self.google_news.results()
if len(google_news_result) == 0:
Expand All @@ -57,11 +68,16 @@ def execute_impl(self, query: str, nb_results: int) -> list[ResponseReference]:

@staticmethod
def from_result(result: dict) -> Optional[ResponseReference]:
title = result.get("title", None)
url = result.get("link", None)
if title is not None and url is not None:
snippet = result.get("desc", None)
date = result.get("date", None)
return ResponseReference(title=title, url=url, snippet=snippet, date=date)

return None
title: Optional[str] = result.get("title", None)
if title is None or title == "":
return None
url: Optional[str] = result.get("link", None)
if url is None or url == "":
return None

date = result.get("date", None)
if date is None or date == "":
return None

snippet = result.get("desc", None)
return ResponseReference(title=title, url=url, snippet=snippet, date=date)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ anthropic>=0.5.0
# Skills
## Google
google-api-python-client==2.106.0
GoogleNews==1.6.10
GoogleNews>=1.6.10
google-api-python-client-stubs==1.18.0
pymediawiki~=0.7.3
beautifulsoup4~=4.12.2
Expand Down
52 changes: 44 additions & 8 deletions tests/integration/skills/test_google_skills.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from datetime import datetime, timedelta

import dotenv

Expand All @@ -14,13 +15,13 @@ class TestBase(unittest.TestCase):
def setUp(self) -> None:
dotenv.load_dotenv()

def test_gnews(self):
def test_gnews_engine(self):
expected = 8
gn = GoogleNewsSearchEngine(period="90d", suffix="Finance")
resp = gn.execute(query="USD", nb_results=expected)
self.assertEqual(len(resp), expected)

def test_gsearch(self):
def test_gsearch_engine(self):
expected = 8
gn = GoogleSearchEngine.from_env()
resp = gn.execute(query="USD", nb_results=expected)
Expand All @@ -30,24 +31,32 @@ def test_gnews_skill(self):
context = ChainContext.from_user_message("USD", Budget(duration=10))
context.chat_history.add_user_message("EUR")

skill = GoogleNewsSkill(suffix="Finance")
expected_result_count = 4
skill = GoogleNewsSkill(suffix="Finance", nb_results=expected_result_count, period="15d")
result = skill.execute(SkillContext.from_chain_context(context, Option.none()))
self.assertTrue(result.is_ok)
self.assertIn("EUR", result.message)
for d in json.loads(result.data):

json_loads = json.loads(result.data)
self.assertLessEqual(expected_result_count, len(json_loads))
for d in json_loads:
self.assertGreater(len(d["title"]), 0)
self.assertGreater(len(d["url"]), 0)
self.assertEqual(len(d["snippet"]), 0)
self.assertIsNotNone(d["date"])
self.assertTrue(is_within_period(d["date"], 15))

def test_gsearch_skill(self):
context = ChainContext.from_user_message("USD", budget=Budget(duration=10))

skill = GoogleSearchSkill()
expected_result_count = 7
skill = GoogleSearchSkill(nb_results=expected_result_count)
result = skill.execute(SkillContext.from_chain_context(context, Option.none()))

self.assertTrue(result.is_ok)
self.assertIn("USD", result.message)
for d in json.loads(result.data):

json_loads = json.loads(result.data)
self.assertEqual(expected_result_count, len(json_loads))
for d in json_loads:
self.assertGreater(len(d["title"]), 0)
self.assertGreater(len(d["url"]), 0)
self.assertGreater(len(d["snippet"]), 0)
Expand All @@ -61,3 +70,30 @@ def test_skill_no_message(self):
self.assertTrue(False)
except OptionException as oe:
self.assertTrue(oe)


def parse_relative_timestamp(timestamp: str) -> datetime:
    """Convert a relative timestamp such as ``"3 hours ago"`` into a datetime.

    The first whitespace-separated token is the amount and the unit is
    detected by substring match anywhere in ``timestamp``.

    Args:
        timestamp: Text of the form ``"<amount> <unit> ago"``. ``<amount>`` is
            an integer, or the article ``"a"``/``"an"`` meaning one. Supported
            units are minutes, hours, days and weeks, plus the abbreviated
            forms ``"min(s)"`` and ``"sec(s)"``.

    Returns:
        A naive local ``datetime`` equal to ``datetime.now()`` minus the
        described duration.

    Raises:
        ValueError: If no supported unit is found (e.g. months or years, which
            have no fixed length) or the amount cannot be interpreted.
    """
    # (keyword to look for, timedelta keyword argument). Order matters: the
    # first four entries reproduce the historical matching priority, so any
    # input that was accepted before resolves to exactly the same unit. The
    # abbreviated forms come last and therefore only apply to inputs that
    # used to be rejected.
    units = (
        ("minute", "minutes"),
        ("hour", "hours"),
        ("day", "days"),
        ("week", "weeks"),
        ("min", "minutes"),
        ("sec", "seconds"),
    )
    for keyword, delta_arg in units:
        if keyword in timestamp:
            break
    else:
        raise ValueError("Unsupported relative timestamp format")

    # A keyword matched, so the string has at least one non-blank token.
    amount_token = timestamp.split()[0]
    if amount_token.lower() in ("a", "an"):
        # "an hour ago" / "a day ago" mean exactly one unit.
        amount = 1
    else:
        try:
            amount = int(amount_token)
        except ValueError:
            raise ValueError(f"Unsupported relative timestamp format: {timestamp!r}") from None

    return datetime.now() - timedelta(**{delta_arg: amount})


def is_within_period(result: str, period: int) -> bool:
    """Tell whether a publication date is at most ``period`` days old.

    Args:
        result: Publication date text. If it contains ``"ago"`` it is parsed
            by ``parse_relative_timestamp``; otherwise it must match the
            ``%Y-%m-%d`` format.
        period: Maximum accepted age, in days.

    Returns:
        ``True`` when the time elapsed since the publication date does not
        exceed ``period`` days, ``False`` otherwise.

    Raises:
        ValueError: If ``result`` is neither a supported relative timestamp
            nor a ``%Y-%m-%d`` date.
    """
    published_at = (
        parse_relative_timestamp(result)
        if "ago" in result
        else datetime.strptime(result, "%Y-%m-%d")
    )
    age = datetime.now() - published_at
    max_age = timedelta(days=period)
    return max_age >= age

0 comments on commit 04287f0

Please sign in to comment.