refactor user stats class
amCap1712 committed Jan 8, 2025
1 parent 2d7b294 commit 3a5366e
Showing 6 changed files with 79 additions and 154 deletions.
4 changes: 1 addition & 3 deletions listenbrainz_spark/path.py
@@ -5,11 +5,9 @@

LISTENBRAINZ_INTERMEDIATE_STATS_DIRECTORY = os.path.join('/', 'data', 'stats-new')

LISTENBRAINZ_USER_STATS_AGG_DIRECTORY = os.path.join('/', 'user_stats_aggregates')
LISTENBRAINZ_USER_STATS_BOOKKEEPING_DIRECTORY = os.path.join('/', 'user_stats_bookkeeping')

LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY, 'user')

# MLHD+ dump files
MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")
12 changes: 6 additions & 6 deletions listenbrainz_spark/stats/incremental/user/artist.py
@@ -9,18 +9,18 @@

class ArtistUserEntity(UserEntity):

def __init__(self):
super().__init__(entity="artists")
def __init__(self, stats_range):
super().__init__(entity="artists", stats_range=stats_range)

def get_cache_tables(self) -> List[str]:
return [ARTIST_COUNTRY_CODE_DATAFRAME]

def get_partial_aggregate_schema(self):
return StructType([
StructField('user_id', IntegerType(), nullable=False),
StructField('artist_name', StringType(), nullable=False),
StructField('artist_mbid', StringType(), nullable=True),
StructField('listen_count', IntegerType(), nullable=False),
StructField("user_id", IntegerType(), nullable=False),
StructField("artist_name", StringType(), nullable=False),
StructField("artist_mbid", StringType(), nullable=True),
StructField("listen_count", IntegerType(), nullable=False),
])

def aggregate(self, table, cache_tables):
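The practical effect of this change for callers is that the stats range now travels through the constructor instead of generate_stats(). A minimal usage sketch, assuming a weekly range and the same top-entity limit that generate_stats() previously received as an argument (the real call site lives elsewhere in listenbrainz_spark and is not part of this diff):

from listenbrainz_spark.stats.incremental.user.artist import ArtistUserEntity

# Hypothetical invocation, for illustration only: the subclass is constructed with
# its stats_range, and generate_stats() now takes just the top-entity limit.
entity_obj = ArtistUserEntity(stats_range="week")
from_date, to_date, only_inc_users, results = entity_obj.generate_stats(top_entity_limit=1000)

The four-element return value matches the new generate_stats() in entity.py below: the date window now comes from the object itself rather than being passed in.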
147 changes: 37 additions & 110 deletions listenbrainz_spark/stats/incremental/user/entity.py
@@ -1,46 +1,22 @@
import abc
import logging
from datetime import datetime
from pathlib import Path
from typing import List

from pyspark.errors import AnalysisException
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, TimestampType

import listenbrainz_spark
from listenbrainz_spark import hdfs_connection
from listenbrainz_spark.config import HDFS_CLUSTER_URI
from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, \
LISTENBRAINZ_USER_STATS_AGG_DIRECTORY, LISTENBRAINZ_USER_STATS_BOOKKEEPING_DIRECTORY

from listenbrainz_spark.path import LISTENBRAINZ_USER_STATS_DIRECTORY
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.utils import read_files_from_HDFS, get_listens_from_dump
from listenbrainz_spark.stats.incremental import IncrementalStats
from listenbrainz_spark.utils import read_files_from_HDFS


logger = logging.getLogger(__name__)
BOOKKEEPING_SCHEMA = StructType([
StructField('from_date', TimestampType(), nullable=False),
StructField('to_date', TimestampType(), nullable=False),
StructField('created', TimestampType(), nullable=False),
])


class UserEntity(abc.ABC):

def __init__(self, entity):
self.entity = entity

def get_existing_aggregate_path(self, stats_range) -> str:
return f"{LISTENBRAINZ_USER_STATS_AGG_DIRECTORY}/{self.entity}/{stats_range}"

def get_bookkeeping_path(self, stats_range) -> str:
return f"{LISTENBRAINZ_USER_STATS_BOOKKEEPING_DIRECTORY}/{self.entity}/{stats_range}"
class UserEntity(IncrementalStats, abc.ABC):

def get_partial_aggregate_schema(self) -> StructType:
raise NotImplementedError()
def get_base_path(self) -> str:
return LISTENBRAINZ_USER_STATS_DIRECTORY

def aggregate(self, table, cache_tables) -> DataFrame:
raise NotImplementedError()
def get_table_prefix(self) -> str:
return f"user_{self.entity}_{self.stats_range}"

def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
query = f"""
@@ -53,88 +53,29 @@ def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
"""
return run_query(query)

def combine_aggregates(self, existing_aggregate, incremental_aggregate) -> DataFrame:
raise NotImplementedError()

def get_top_n(self, final_aggregate, N) -> DataFrame:
raise NotImplementedError()

def get_cache_tables(self) -> List[str]:
raise NotImplementedError()

def generate_stats(self, stats_range: str, from_date: datetime,
to_date: datetime, top_entity_limit: int):
cache_tables = []
for idx, df_path in enumerate(self.get_cache_tables()):
df_name = f"entity_data_cache_{idx}"
cache_tables.append(df_name)
read_files_from_HDFS(df_path).createOrReplaceTempView(df_name)

metadata_path = self.get_bookkeeping_path(stats_range)
try:
metadata = listenbrainz_spark \
.session \
.read \
.schema(BOOKKEEPING_SCHEMA) \
.json(f"{HDFS_CLUSTER_URI}{metadata_path}") \
.collect()[0]
existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"]
existing_aggregate_usable = existing_from_date.date() == from_date.date()
except AnalysisException:
existing_aggregate_usable = False
logger.info("Existing partial aggregate not found!")

prefix = f"user_{self.entity}_{stats_range}"
existing_aggregate_path = self.get_existing_aggregate_path(stats_range)

only_inc_users = True

if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable:
table = f"{prefix}_full_listens"
get_listens_from_dump(from_date, to_date, include_incremental=False).createOrReplaceTempView(table)

logger.info("Creating partial aggregate from full dump listens")
hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent)
full_df = self.aggregate(table, cache_tables)
full_df.write.mode("overwrite").parquet(existing_aggregate_path)

hdfs_connection.client.makedirs(Path(metadata_path).parent)
metadata_df = listenbrainz_spark.session.createDataFrame(
[(from_date, to_date, datetime.now())],
schema=BOOKKEEPING_SCHEMA
)
metadata_df.write.mode("overwrite").json(metadata_path)
only_inc_users = False

full_df = read_files_from_HDFS(existing_aggregate_path)
def generate_stats(self, top_entity_limit: int):
self.setup_cache_tables()
prefix = self.get_table_prefix()

if hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False):
table = f"{prefix}_incremental_listens"
read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \
.createOrReplaceTempView(table)
inc_df = self.aggregate(table, cache_tables)
else:
inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema())
if not self.partial_aggregate_usable():
self.create_partial_aggregate()
only_inc_users = False
else:
only_inc_users = True

full_table = f"{prefix}_existing_aggregate"
full_df.createOrReplaceTempView(full_table)
partial_df = read_files_from_HDFS(self.get_existing_aggregate_path())
partial_table = f"{prefix}_existing_aggregate"
partial_df.createOrReplaceTempView(partial_table)

inc_table = f"{prefix}_incremental_aggregate"
inc_df.createOrReplaceTempView(inc_table)
if self.incremental_dump_exists():
inc_df = self.create_incremental_aggregate()
inc_table = f"{prefix}_incremental_aggregate"
inc_df.createOrReplaceTempView(inc_table)

if only_inc_users:
existing_table = f"{prefix}_filtered_aggregate"
filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table)
filtered_aggregate_df.createOrReplaceTempView(existing_table)
if only_inc_users:
filtered_aggregate_df = self.filter_existing_aggregate(partial_table, inc_table)
filtered_table = f"{prefix}_filtered_aggregate"
filtered_aggregate_df.createOrReplaceTempView(filtered_table)
else:
filtered_table = partial_table

final_df = self.combine_aggregates(filtered_table, inc_table)
else:
existing_table = full_table
final_df = partial_df
only_inc_users = False

combined_df = self.combine_aggregates(existing_table, inc_table)

combined_table = f"{prefix}_combined_aggregate"
combined_df.createOrReplaceTempView(combined_table)
results_df = self.get_top_n(combined_table, top_entity_limit)
final_table = f"{prefix}_final_aggregate"
final_df.createOrReplaceTempView(final_table)

return only_inc_users, results_df.toLocalIterator()

results_df = self.get_top_n(final_table, top_entity_limit)
return self.from_date, self.to_date, only_inc_users, results_df.toLocalIterator()
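
The new UserEntity delegates bookkeeping, partial-aggregate creation, and incremental-dump handling to the IncrementalStats base class imported above. The outline below is inferred solely from the attributes and methods the new generate_stats() relies on; it is not the actual listenbrainz_spark.stats.incremental implementation, just a sketch of the interface this refactor assumes:

import abc
from datetime import datetime

from pyspark.sql import DataFrame


class IncrementalStats(abc.ABC):
    """Interface sketch, inferred only from the calls made by UserEntity.generate_stats() above."""

    def __init__(self, entity: str, stats_range: str):
        self.entity = entity
        self.stats_range = stats_range
        # Assumed: the real base class also derives the date window for stats_range
        # and exposes it as from_date/to_date, which generate_stats() returns above.
        self.from_date: datetime
        self.to_date: datetime

    @abc.abstractmethod
    def get_base_path(self) -> str:
        """Root HDFS directory for this stats family (UserEntity returns LISTENBRAINZ_USER_STATS_DIRECTORY)."""

    @abc.abstractmethod
    def get_table_prefix(self) -> str:
        """Prefix for the temporary Spark SQL views (UserEntity returns user_{entity}_{stats_range})."""

    def get_existing_aggregate_path(self) -> str:
        """Location of the stored partial aggregate for this entity and range."""
        ...

    def setup_cache_tables(self) -> None:
        """Register the dataframes listed by get_cache_tables() as temp views."""
        ...

    def partial_aggregate_usable(self) -> bool:
        """Whether a stored partial aggregate exists and covers the expected date window."""
        ...

    def create_partial_aggregate(self) -> None:
        """Build the partial aggregate from full-dump listens and persist it."""
        ...

    def incremental_dump_exists(self) -> bool:
        """Whether incremental listen dumps are available on HDFS."""
        ...

    def create_incremental_aggregate(self) -> DataFrame:
        """Aggregate only the incremental listens."""
        ...

The concrete behaviour the old generate_stats() inlined, such as reading the bookkeeping JSON, writing the parquet partial aggregate, and checking INCREMENTAL_DUMPS_SAVE_PATH, presumably moves behind these methods, which accounts for the 110 lines deleted from this file.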
26 changes: 13 additions & 13 deletions listenbrainz_spark/stats/incremental/user/recording.py
@@ -11,25 +11,25 @@

class RecordingUserEntity(UserEntity):

def __init__(self):
super().__init__(entity="recordings")
def __init__(self, stats_range):
super().__init__(entity="recordings", stats_range=stats_range)

def get_cache_tables(self) -> List[str]:
return [RECORDING_ARTIST_DATAFRAME, RELEASE_METADATA_CACHE_DATAFRAME]

def get_partial_aggregate_schema(self):
return StructType([
StructField('user_id', IntegerType(), nullable=False),
StructField('recording_name', StringType(), nullable=False),
StructField('recording_mbid', StringType(), nullable=True),
StructField('artist_name', StringType(), nullable=False),
StructField('artist_credit_mbids', ArrayType(StringType()), nullable=True),
StructField('release_name', StringType(), nullable=True),
StructField('release_mbid', StringType(), nullable=True),
StructField('artists', artists_column_schema, nullable=True),
StructField('caa_id', IntegerType(), nullable=True),
StructField('caa_release_mbid', StringType(), nullable=True),
StructField('listen_count', IntegerType(), nullable=False),
StructField("user_id", IntegerType(), nullable=False),
StructField("recording_name", StringType(), nullable=False),
StructField("recording_mbid", StringType(), nullable=True),
StructField("artist_name", StringType(), nullable=False),
StructField("artist_credit_mbids", ArrayType(StringType()), nullable=True),
StructField("release_name", StringType(), nullable=True),
StructField("release_mbid", StringType(), nullable=True),
StructField("artists", artists_column_schema, nullable=True),
StructField("caa_id", IntegerType(), nullable=True),
StructField("caa_release_mbid", StringType(), nullable=True),
StructField("listen_count", IntegerType(), nullable=False),
])

def aggregate(self, table, cache_tables):
22 changes: 11 additions & 11 deletions listenbrainz_spark/stats/incremental/user/release.py
@@ -10,23 +10,23 @@

class ReleaseUserEntity(UserEntity):

def __init__(self):
super().__init__(entity="releases")
def __init__(self, stats_range):
super().__init__(entity="releases", stats_range=stats_range)

def get_cache_tables(self) -> List[str]:
return [RELEASE_METADATA_CACHE_DATAFRAME]

def get_partial_aggregate_schema(self):
return StructType([
StructField('user_id', IntegerType(), nullable=False),
StructField('release_name', StringType(), nullable=False),
StructField('release_mbid', StringType(), nullable=False),
StructField('artist_name', StringType(), nullable=False),
StructField('artist_credit_mbids', ArrayType(StringType()), nullable=False),
StructField('artists', artists_column_schema, nullable=True),
StructField('caa_id', IntegerType(), nullable=True),
StructField('caa_release_mbid', StringType(), nullable=True),
StructField('listen_count', IntegerType(), nullable=False),
StructField("user_id", IntegerType(), nullable=False),
StructField("release_name", StringType(), nullable=False),
StructField("release_mbid", StringType(), nullable=False),
StructField("artist_name", StringType(), nullable=False),
StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False),
StructField("artists", artists_column_schema, nullable=True),
StructField("caa_id", IntegerType(), nullable=True),
StructField("caa_release_mbid", StringType(), nullable=True),
StructField("listen_count", IntegerType(), nullable=False),
])

def aggregate(self, table, cache_tables):
22 changes: 11 additions & 11 deletions listenbrainz_spark/stats/incremental/user/release_group.py
@@ -11,23 +11,23 @@

class ReleaseGroupUserEntity(UserEntity):

def __init__(self):
super().__init__(entity="release_groups")
def __init__(self, stats_range):
super().__init__(entity="release_groups", stats_range=stats_range)

def get_cache_tables(self) -> List[str]:
return [RELEASE_METADATA_CACHE_DATAFRAME, RELEASE_GROUP_METADATA_CACHE_DATAFRAME]

def get_partial_aggregate_schema(self):
return StructType([
StructField('user_id', IntegerType(), nullable=False),
StructField('release_group_name', StringType(), nullable=False),
StructField('release_group_mbid', StringType(), nullable=False),
StructField('artist_name', StringType(), nullable=False),
StructField('artist_credit_mbids', ArrayType(StringType()), nullable=False),
StructField('artists', artists_column_schema, nullable=True),
StructField('caa_id', IntegerType(), nullable=True),
StructField('caa_release_mbid', StringType(), nullable=True),
StructField('listen_count', IntegerType(), nullable=False),
StructField("user_id", IntegerType(), nullable=False),
StructField("release_group_name", StringType(), nullable=False),
StructField("release_group_mbid", StringType(), nullable=False),
StructField("artist_name", StringType(), nullable=False),
StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False),
StructField("artists", artists_column_schema, nullable=True),
StructField("caa_id", IntegerType(), nullable=True),
StructField("caa_release_mbid", StringType(), nullable=True),
StructField("listen_count", IntegerType(), nullable=False),
])

def aggregate(self, table, cache_tables):
