From 3a5366e793adaf499f585665134e649146ebfe34 Mon Sep 17 00:00:00 2001
From: Kartik Ohri
Date: Wed, 8 Jan 2025 21:54:50 +0530
Subject: [PATCH] refactor user stats class

---
 listenbrainz_spark/path.py                    |   4 +-
 .../stats/incremental/user/artist.py          |  12 +-
 .../stats/incremental/user/entity.py          | 147 +++++-------------
 .../stats/incremental/user/recording.py       |  26 ++--
 .../stats/incremental/user/release.py         |  22 +--
 .../stats/incremental/user/release_group.py   |  22 +--
 6 files changed, 79 insertions(+), 154 deletions(-)

diff --git a/listenbrainz_spark/path.py b/listenbrainz_spark/path.py
index 7f9c2d48bb..e83d30b574 100644
--- a/listenbrainz_spark/path.py
+++ b/listenbrainz_spark/path.py
@@ -5,11 +5,9 @@
 
 LISTENBRAINZ_INTERMEDIATE_STATS_DIRECTORY = os.path.join('/', 'data', 'stats-new')
 
-LISTENBRAINZ_USER_STATS_AGG_DIRECTORY = os.path.join('/', 'user_stats_aggregates')
-LISTENBRAINZ_USER_STATS_BOOKKEEPING_DIRECTORY = os.path.join('/', 'user_stats_bookkeeping')
-
 LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
 LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
+LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'user')
 
 # MLHD+ dump files
 MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")
diff --git a/listenbrainz_spark/stats/incremental/user/artist.py b/listenbrainz_spark/stats/incremental/user/artist.py
index c892854680..1e334db22d 100644
--- a/listenbrainz_spark/stats/incremental/user/artist.py
+++ b/listenbrainz_spark/stats/incremental/user/artist.py
@@ -9,18 +9,18 @@
 
 class ArtistUserEntity(UserEntity):
 
-    def __init__(self):
-        super().__init__(entity="artists")
+    def __init__(self, stats_range):
+        super().__init__(entity="artists", stats_range=stats_range)
 
     def get_cache_tables(self) -> List[str]:
         return [ARTIST_COUNTRY_CODE_DATAFRAME]
 
     def get_partial_aggregate_schema(self):
         return StructType([
-            StructField('user_id', IntegerType(), nullable=False),
-            StructField('artist_name', StringType(), nullable=False),
-            StructField('artist_mbid', StringType(), nullable=True),
-            StructField('listen_count', IntegerType(), nullable=False),
+            StructField("user_id", IntegerType(), nullable=False),
+            StructField("artist_name", StringType(), nullable=False),
+            StructField("artist_mbid", StringType(), nullable=True),
+            StructField("listen_count", IntegerType(), nullable=False),
         ])
 
     def aggregate(self, table, cache_tables):
diff --git a/listenbrainz_spark/stats/incremental/user/entity.py b/listenbrainz_spark/stats/incremental/user/entity.py
index 18fa39b4b0..ab4758811f 100644
--- a/listenbrainz_spark/stats/incremental/user/entity.py
+++ b/listenbrainz_spark/stats/incremental/user/entity.py
@@ -1,46 +1,22 @@
 import abc
 import logging
-from datetime import datetime
-from pathlib import Path
-from typing import List
-
-from pyspark.errors import AnalysisException
-from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType, StructField, TimestampType
-
-import listenbrainz_spark
-from listenbrainz_spark import hdfs_connection
-from listenbrainz_spark.config import HDFS_CLUSTER_URI
-from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, \
-    LISTENBRAINZ_USER_STATS_AGG_DIRECTORY, LISTENBRAINZ_USER_STATS_BOOKKEEPING_DIRECTORY
+
+from listenbrainz_spark.path import LISTENBRAINZ_USER_STATS_DIRECTORY
 from listenbrainz_spark.stats import run_query
-from listenbrainz_spark.utils import read_files_from_HDFS, get_listens_from_dump
+from listenbrainz_spark.stats.incremental import IncrementalStats
+from listenbrainz_spark.utils import read_files_from_HDFS logger = logging.getLogger(__name__) -BOOKKEEPING_SCHEMA = StructType([ - StructField('from_date', TimestampType(), nullable=False), - StructField('to_date', TimestampType(), nullable=False), - StructField('created', TimestampType(), nullable=False), -]) - -class UserEntity(abc.ABC): - - def __init__(self, entity): - self.entity = entity - - def get_existing_aggregate_path(self, stats_range) -> str: - return f"{LISTENBRAINZ_USER_STATS_AGG_DIRECTORY}/{self.entity}/{stats_range}" - def get_bookkeeping_path(self, stats_range) -> str: - return f"{LISTENBRAINZ_USER_STATS_BOOKKEEPING_DIRECTORY}/{self.entity}/{stats_range}" +class UserEntity(IncrementalStats, abc.ABC): - def get_partial_aggregate_schema(self) -> StructType: - raise NotImplementedError() + def get_base_path(self) -> str: + return LISTENBRAINZ_USER_STATS_DIRECTORY - def aggregate(self, table, cache_tables) -> DataFrame: - raise NotImplementedError() + def get_table_prefix(self) -> str: + return f"user_{self.entity}_{self.stats_range}" def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate): query = f""" @@ -53,88 +29,39 @@ def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate): """ return run_query(query) - def combine_aggregates(self, existing_aggregate, incremental_aggregate) -> DataFrame: - raise NotImplementedError() - - def get_top_n(self, final_aggregate, N) -> DataFrame: - raise NotImplementedError() - - def get_cache_tables(self) -> List[str]: - raise NotImplementedError() - - def generate_stats(self, stats_range: str, from_date: datetime, - to_date: datetime, top_entity_limit: int): - cache_tables = [] - for idx, df_path in enumerate(self.get_cache_tables()): - df_name = f"entity_data_cache_{idx}" - cache_tables.append(df_name) - read_files_from_HDFS(df_path).createOrReplaceTempView(df_name) - - metadata_path = self.get_bookkeeping_path(stats_range) - try: - metadata = listenbrainz_spark \ - .session \ - .read \ - .schema(BOOKKEEPING_SCHEMA) \ - .json(f"{HDFS_CLUSTER_URI}{metadata_path}") \ - .collect()[0] - existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"] - existing_aggregate_usable = existing_from_date.date() == from_date.date() - except AnalysisException: - existing_aggregate_usable = False - logger.info("Existing partial aggregate not found!") - - prefix = f"user_{self.entity}_{stats_range}" - existing_aggregate_path = self.get_existing_aggregate_path(stats_range) - - only_inc_users = True - - if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable: - table = f"{prefix}_full_listens" - get_listens_from_dump(from_date, to_date, include_incremental=False).createOrReplaceTempView(table) - - logger.info("Creating partial aggregate from full dump listens") - hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent) - full_df = self.aggregate(table, cache_tables) - full_df.write.mode("overwrite").parquet(existing_aggregate_path) - - hdfs_connection.client.makedirs(Path(metadata_path).parent) - metadata_df = listenbrainz_spark.session.createDataFrame( - [(from_date, to_date, datetime.now())], - schema=BOOKKEEPING_SCHEMA - ) - metadata_df.write.mode("overwrite").json(metadata_path) - only_inc_users = False - - full_df = read_files_from_HDFS(existing_aggregate_path) + def generate_stats(self, top_entity_limit: int): + self.setup_cache_tables() + prefix = self.get_table_prefix() - if 
hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False): - table = f"{prefix}_incremental_listens" - read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \ - .createOrReplaceTempView(table) - inc_df = self.aggregate(table, cache_tables) - else: - inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema()) + if not self.partial_aggregate_usable(): + self.create_partial_aggregate() only_inc_users = False + else: + only_inc_users = True - full_table = f"{prefix}_existing_aggregate" - full_df.createOrReplaceTempView(full_table) + partial_df = read_files_from_HDFS(self.get_existing_aggregate_path()) + partial_table = f"{prefix}_existing_aggregate" + partial_df.createOrReplaceTempView(partial_table) - inc_table = f"{prefix}_incremental_aggregate" - inc_df.createOrReplaceTempView(inc_table) + if self.incremental_dump_exists(): + inc_df = self.create_incremental_aggregate() + inc_table = f"{prefix}_incremental_aggregate" + inc_df.createOrReplaceTempView(inc_table) - if only_inc_users: - existing_table = f"{prefix}_filtered_aggregate" - filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table) - filtered_aggregate_df.createOrReplaceTempView(existing_table) + if only_inc_users: + filtered_aggregate_df = self.filter_existing_aggregate(partial_table, inc_table) + filtered_table = f"{prefix}_filtered_aggregate" + filtered_aggregate_df.createOrReplaceTempView(filtered_table) + else: + filtered_table = partial_table + + final_df = self.combine_aggregates(filtered_table, inc_table) else: - existing_table = full_table + final_df = partial_df + only_inc_users = False - combined_df = self.combine_aggregates(existing_table, inc_table) - - combined_table = f"{prefix}_combined_aggregate" - combined_df.createOrReplaceTempView(combined_table) - results_df = self.get_top_n(combined_table, top_entity_limit) + final_table = f"{prefix}_final_aggregate" + final_df.createOrReplaceTempView(final_table) - return only_inc_users, results_df.toLocalIterator() - \ No newline at end of file + results_df = self.get_top_n(final_table, top_entity_limit) + return self.from_date, self.to_date, only_inc_users, results_df.toLocalIterator() diff --git a/listenbrainz_spark/stats/incremental/user/recording.py b/listenbrainz_spark/stats/incremental/user/recording.py index a04840ac6b..08bac15b68 100644 --- a/listenbrainz_spark/stats/incremental/user/recording.py +++ b/listenbrainz_spark/stats/incremental/user/recording.py @@ -11,25 +11,25 @@ class RecordingUserEntity(UserEntity): - def __init__(self): - super().__init__(entity="recordings") + def __init__(self, stats_range): + super().__init__(entity="recordings", stats_range=stats_range) def get_cache_tables(self) -> List[str]: return [RECORDING_ARTIST_DATAFRAME, RELEASE_METADATA_CACHE_DATAFRAME] def get_partial_aggregate_schema(self): return StructType([ - StructField('user_id', IntegerType(), nullable=False), - StructField('recording_name', StringType(), nullable=False), - StructField('recording_mbid', StringType(), nullable=True), - StructField('artist_name', StringType(), nullable=False), - StructField('artist_credit_mbids', ArrayType(StringType()), nullable=True), - StructField('release_name', StringType(), nullable=True), - StructField('release_mbid', StringType(), nullable=True), - StructField('artists', artists_column_schema, nullable=True), - StructField('caa_id', IntegerType(), nullable=True), - StructField('caa_release_mbid', StringType(), nullable=True), - StructField('listen_count', IntegerType(), 
nullable=False), + StructField("user_id", IntegerType(), nullable=False), + StructField("recording_name", StringType(), nullable=False), + StructField("recording_mbid", StringType(), nullable=True), + StructField("artist_name", StringType(), nullable=False), + StructField("artist_credit_mbids", ArrayType(StringType()), nullable=True), + StructField("release_name", StringType(), nullable=True), + StructField("release_mbid", StringType(), nullable=True), + StructField("artists", artists_column_schema, nullable=True), + StructField("caa_id", IntegerType(), nullable=True), + StructField("caa_release_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), ]) def aggregate(self, table, cache_tables): diff --git a/listenbrainz_spark/stats/incremental/user/release.py b/listenbrainz_spark/stats/incremental/user/release.py index fd0886b662..f5e290f4fe 100644 --- a/listenbrainz_spark/stats/incremental/user/release.py +++ b/listenbrainz_spark/stats/incremental/user/release.py @@ -10,23 +10,23 @@ class ReleaseUserEntity(UserEntity): - def __init__(self): - super().__init__(entity="releases") + def __init__(self, stats_range): + super().__init__(entity="releases", stats_range=stats_range) def get_cache_tables(self) -> List[str]: return [RELEASE_METADATA_CACHE_DATAFRAME] def get_partial_aggregate_schema(self): return StructType([ - StructField('user_id', IntegerType(), nullable=False), - StructField('release_name', StringType(), nullable=False), - StructField('release_mbid', StringType(), nullable=False), - StructField('artist_name', StringType(), nullable=False), - StructField('artist_credit_mbids', ArrayType(StringType()), nullable=False), - StructField('artists', artists_column_schema, nullable=True), - StructField('caa_id', IntegerType(), nullable=True), - StructField('caa_release_mbid', StringType(), nullable=True), - StructField('listen_count', IntegerType(), nullable=False), + StructField("user_id", IntegerType(), nullable=False), + StructField("release_name", StringType(), nullable=False), + StructField("release_mbid", StringType(), nullable=False), + StructField("artist_name", StringType(), nullable=False), + StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False), + StructField("artists", artists_column_schema, nullable=True), + StructField("caa_id", IntegerType(), nullable=True), + StructField("caa_release_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), ]) def aggregate(self, table, cache_tables): diff --git a/listenbrainz_spark/stats/incremental/user/release_group.py b/listenbrainz_spark/stats/incremental/user/release_group.py index 115834913f..aa448d8f0d 100644 --- a/listenbrainz_spark/stats/incremental/user/release_group.py +++ b/listenbrainz_spark/stats/incremental/user/release_group.py @@ -11,23 +11,23 @@ class ReleaseGroupUserEntity(UserEntity): - def __init__(self): - super().__init__(entity="release_groups") + def __init__(self, stats_range): + super().__init__(entity="release_groups", stats_range=stats_range) def get_cache_tables(self) -> List[str]: return [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME] def get_partial_aggregate_schema(self): return StructType([ - StructField('user_id', IntegerType(), nullable=False), - StructField('release_group_name', StringType(), nullable=False), - StructField('release_group_mbid', StringType(), nullable=False), - StructField('artist_name', StringType(), nullable=False), - StructField('artist_credit_mbids', 
ArrayType(StringType()), nullable=False), - StructField('artists', artists_column_schema, nullable=True), - StructField('caa_id', IntegerType(), nullable=True), - StructField('caa_release_mbid', StringType(), nullable=True), - StructField('listen_count', IntegerType(), nullable=False), + StructField("user_id", IntegerType(), nullable=False), + StructField("release_group_name", StringType(), nullable=False), + StructField("release_group_mbid", StringType(), nullable=False), + StructField("artist_name", StringType(), nullable=False), + StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False), + StructField("artists", artists_column_schema, nullable=True), + StructField("caa_id", IntegerType(), nullable=True), + StructField("caa_release_mbid", StringType(), nullable=True), + StructField("listen_count", IntegerType(), nullable=False), ]) def aggregate(self, table, cache_tables):
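
Reviewer note: after this refactor, UserEntity delegates its bookkeeping (reusing a
partial aggregate, detecting incremental dumps, resolving HDFS paths) to the
IncrementalStats base class it now inherits from. For review context, a minimal
sketch of the contract this patch appears to rely on follows. The method names are
taken from the calls in generate_stats() above; the signatures, docstrings, and the
driver snippet at the end are illustrative assumptions, not the actual contents of
listenbrainz_spark/stats/incremental/__init__.py.

import abc
from typing import List

from pyspark.sql import DataFrame
from pyspark.sql.types import StructType


class IncrementalStats(abc.ABC):
    """Hypothetical sketch of the base class, for review context only.

    The concrete helpers that UserEntity.generate_stats() calls, i.e.
    setup_cache_tables(), partial_aggregate_usable(), create_partial_aggregate(),
    incremental_dump_exists(), create_incremental_aggregate() and
    get_existing_aggregate_path(), are assumed to be implemented here in terms of
    the abstract methods below; from_date and to_date are assumed to be resolved
    from stats_range during construction.
    """

    def __init__(self, entity: str, stats_range: str):
        # e.g. entity="artists", stats_range="week"
        self.entity = entity
        self.stats_range = stats_range

    @abc.abstractmethod
    def get_base_path(self) -> str:
        """HDFS directory under which aggregates and bookkeeping data live."""

    @abc.abstractmethod
    def get_table_prefix(self) -> str:
        """Prefix used to namespace the temporary Spark views."""

    @abc.abstractmethod
    def get_cache_tables(self) -> List[str]:
        """HDFS paths of the metadata caches to register as temp views."""

    @abc.abstractmethod
    def get_partial_aggregate_schema(self) -> StructType:
        """Schema of the partial aggregate persisted to HDFS."""

    @abc.abstractmethod
    def aggregate(self, table: str, cache_tables: List[str]) -> DataFrame:
        """Aggregate the listens in `table`, joining the given cache tables."""

    @abc.abstractmethod
    def combine_aggregates(self, existing_aggregate: str, incremental_aggregate: str) -> DataFrame:
        """Merge an existing aggregate with an incremental aggregate."""

    @abc.abstractmethod
    def get_top_n(self, final_aggregate: str, n: int) -> DataFrame:
        """Select the top N entries per user from the final aggregate."""


Illustrative usage (hypothetical driver; "week" and 1000 are example inputs, and a
configured Spark session plus the HDFS layout above are assumed):

from listenbrainz_spark.stats.incremental.user.artist import ArtistUserEntity

stats = ArtistUserEntity(stats_range="week")
from_date, to_date, only_inc_users, results = stats.generate_stats(top_entity_limit=1000)
for row in results:
    ...  # rows of the top-N results, shaped by get_top_n()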