fix: no trend analysis possible when no data is found
nreinartz committed Nov 27, 2023
1 parent 18d20ff commit 69b858e
Showing 5 changed files with 46 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/models/models.py
@@ -51,6 +51,7 @@ class SearchResults:
     raw_per_year: list[float]
     adjusted: list[float]
     pub_types: dict[str, int]
+    adjusted_cutoff: float | None = None


 @dataclass
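For context, a minimal sketch of the dataclass after this commit. All fields above adjusted_cutoff come from the diff's context lines; the leading raw field and its type are inferred from the SearchResults(...) call in query_worker.py below and are an assumption:

from dataclasses import dataclass

@dataclass
class SearchResults:
    raw: list[float]  # assumed: mean per-year values, see query_worker.py
    raw_per_year: list[float]
    adjusted: list[float]
    pub_types: dict[str, int]
    # None when the configured cutoff already matched enough publications,
    # otherwise the lowered cutoff that was actually used
    adjusted_cutoff: float | None = None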
38 changes: 29 additions & 9 deletions src/query_worker.py
@@ -43,17 +43,28 @@ async def process_query(uuid: str, query_repo: QueryRepository,
     await query_repo.update_query_progress(entry.uuid, QueryProgress.FINISHED)


-async def __fetch_data(entry: QueryEntry, weaviate_accessor: WeaviateAccessor, data_statistics: DataStatistics):
+async def __fetch_data(query_repo: QueryRepository, entry: QueryEntry, weaviate_accessor: WeaviateAccessor, data_statistics: DataStatistics):
+    num_pubs_found = 0
+    adjusted_cutoff = entry.cutoff
+
+    while num_pubs_found < 500:
+        per_year = await run_in_threadpool(
+            lambda: weaviate_accessor.get_publications_per_year(
+                entry.topics, adjusted_cutoff, entry.start_year, entry.end_year)
+        )
+        num_pubs_found = sum(per_year.values())
+
+        if num_pubs_found < 500:
+            adjusted_cutoff = adjusted_cutoff - 0.01
+            print("Adjusting cutoff to ", adjusted_cutoff)
+
+    await query_repo.update_query_entry(entry)
+
     pub_objects = await run_in_threadpool(
         lambda: weaviate_accessor.get_publications_per_year_adjusted(
             entry.topics, data_statistics.publications_per_year, entry.start_year, entry.end_year)
     )

-    per_year = await run_in_threadpool(
-        lambda: weaviate_accessor.get_publications_per_year(
-            entry.topics, entry.cutoff, entry.start_year, entry.end_year)
-    )
-
     year_value_pairs = {year: []
                         for year in range(entry.start_year, entry.end_year + 1)}
     pub_type_count = {}
@@ -69,7 +80,7 @@ async def __fetch_data(entry: QueryEntry, weaviate_accessor: WeaviateAccessor, data_statistics: DataStatistics):
         np.mean(year_value_pairs[year]) for year in range(entry.start_year, entry.end_year + 1)
     ]

-    clamped_values = np.maximum(raw_values, entry.cutoff)
+    clamped_values = np.maximum(raw_values, adjusted_cutoff)

     if np.max(clamped_values) > np.min(clamped_values):
         adjusted_values = np.round(100 * (np.array(clamped_values) - np.min(clamped_values)) / (
@@ -81,7 +92,16 @@ async def __fetch_data(entry: QueryEntry, weaviate_accessor: WeaviateAccessor, data_statistics: DataStatistics):
     per_year_values = [per_year[year]
                        for year in range(entry.start_year, entry.end_year + 1)]

-    return SearchResults(raw=raw_values, raw_per_year=per_year_values, adjusted=adjusted_values, pub_types=pub_type_count)
+    entry.results.search_results = SearchResults(
+        raw=raw_values,
+        raw_per_year=per_year_values,
+        adjusted=adjusted_values,
+        pub_types=pub_type_count,
+        adjusted_cutoff=adjusted_cutoff if adjusted_cutoff != entry.cutoff else None
+    )
+
+    # Return entry here since we updated its properties
+    return entry


 async def __analyse_trends(query_repo: QueryRepository, entry: QueryEntry, trend_analyser: TrendAnalyser,
@@ -92,7 +112,7 @@ async def __analyse_trends(query_repo: QueryRepository, entry: QueryEntry, trend_analyser: TrendAnalyser,
     await query_repo.update_query_entry(entry)

     try:
-        entry.results.search_results = await __fetch_data(entry, weaviate_accessor, data_statistics)
+        entry = await __fetch_data(query_repo, entry, weaviate_accessor, data_statistics)
     except Exception as e:
         await query_repo.update_query_progress(entry.uuid, QueryProgress.FAILED)
         raise e
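The heart of the fix is the retry loop in __fetch_data: when fewer than 500 publications match at the configured similarity cutoff, the cutoff is lowered in 0.01 steps and the per-year counts are re-fetched, and the clamping later in the function uses the effective cutoff instead of entry.cutoff. A standalone sketch of the loop's logic, where fetch_counts is a stand-in for WeaviateAccessor.get_publications_per_year; the floor guard is an assumption added here for safety and is not in the committed loop, which has no lower bound on the cutoff:

from typing import Callable

def lower_cutoff_until_enough(
        fetch_counts: Callable[[float], dict[int, int]],
        cutoff: float,
        min_pubs: int = 500,
        step: float = 0.01,
        floor: float = 0.0) -> tuple[float, dict[int, int]]:
    # Re-fetch the per-year counts, lowering the cutoff until enough
    # publications match (or the assumed floor is reached).
    per_year = fetch_counts(cutoff)
    while sum(per_year.values()) < min_pubs and cutoff - step >= floor:
        cutoff -= step
        per_year = fetch_counts(cutoff)
    return cutoff, per_year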
5 changes: 4 additions & 1 deletion src/trend/analysis/mlr_time_series_segmenter.py
@@ -10,7 +10,10 @@ class MlrTimeSeriesSegmenter(BaseTimeSeriesSegmenter):
     def __init__(self, min_segment_length: int = 4):
         super().__init__(min_segment_length)

-    def segment(self, x, y: list[float] | list[int]) -> (list[tuple], list[dict]):
+    def segment(self, x, y: list[float] | list[int]) -> list[int]:
+        if all(val == 0 for val in y):
+            return []
+
         x_copy, y_copy = x.copy(), y.copy()
         while y_copy[0] == 0 and y_copy[1] == 0:
             y_copy.pop(0)
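The corrected annotation list[int] matches what the method actually yields (the cut positions), and the new guard matters because the zero-trimming loop right below it cannot cope with an all-zero series: the pops drain the list until y_copy[1] raises an IndexError. A minimal reproduction of that failure mode, isolated from the class:

def trim_leading_zeros(y: list[float]) -> list[float]:
    # Mirrors the zero-trimming loop in MlrTimeSeriesSegmenter.segment
    while y[0] == 0 and y[1] == 0:
        y.pop(0)
    return y

print(trim_leading_zeros([0, 0, 1, 2]))  # [0, 1, 2]
trim_leading_zeros([0, 0, 0, 0])         # raises IndexError: list index out of range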
11 changes: 9 additions & 2 deletions src/trend/analysis/trend_analyser.py
@@ -13,14 +13,21 @@ def get_trend_analyser():
 class TrendAnalyser:
     def analyse(self, x, y) -> (list[int], list[Trend]):
         time_series_segmenter = MlrTimeSeriesSegmenter(min_segment_length=4)
+
+        if all(val == 0 for val in y):
+            return [], self.__get_trends_for_segments(x, y, [(0, len(x) - 1)])
+
+        y_adjusted = (y / np.max(y)) * 100
+
         cuts = time_series_segmenter.segment(x, y)

+        if len(cuts) == 0:
+            return [], self.__get_trends_for_segments(x, y_adjusted, [(0, len(x) - 1)])
+
         cuts_i = [0] + [x.index(cut) for cut in cuts] + [len(x) - 1]

         # Sub-trends
         segments = [(cuts_i[i], cut) for i, cut in enumerate(cuts_i[1:])]

-        y_adjusted = (y / np.max(y)) * 100
-
         trends = self.__get_trends_for_segments(x, y_adjusted, segments)
         trend_slopes = [trends[i].type.value for i in range(len(trends))]
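Both early returns in analyse() guard the same hazard: the normalization (y / np.max(y)) * 100 rescales the series so its maximum maps to 100, but for an all-zero series np.max(y) is 0 and the division yields NaNs, so that case must exit before it; moving the normalization above segment() additionally lets the empty-cuts case reuse y_adjusted. A short numpy illustration:

import numpy as np

y = np.array([2.0, 4.0, 8.0])
print((y / np.max(y)) * 100)   # [ 25.  50. 100.] -- maximum maps to 100

zeros = np.array([0.0, 0.0, 0.0])
print(zeros / np.max(zeros))   # [nan nan nan], with a RuntimeWarning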
3 changes: 3 additions & 0 deletions src/trend/descriptor/rule_based_descriptor.py
@@ -281,6 +281,9 @@ def generate_description(self, topics: list[str], start_year: int,
                              end_year: int, values: list[int],
                              global_trend: Trend, sub_trends: list[Trend]) -> str:

+        if len(sub_trends) == 0:
+            return "No trends were detected."
+
         segments = [Segment(t.start, t.end, t.type, t.slope, values[t.start - start_year:t.end - start_year + 1])
                     for t in sub_trends]
         descriptions = [segments[0].describe()]
