Skip to content

Commit

Permalink
Refactor create messages and stats validation into class
Browse files Browse the repository at this point in the history
  • Loading branch information
amCap1712 committed Jan 10, 2025
1 parent ba1f8b6 commit f376920
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 246 deletions.
4 changes: 1 addition & 3 deletions listenbrainz_spark/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
LISTENBRAINZ_BASE_STATS_DIRECTORY = os.path.join('/', 'stats')
LISTENBRAINZ_SITEWIDE_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'sitewide')
LISTENBRAINZ_USER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'user')

LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY = os.path.join('/', 'listener_stats_aggregates')
LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY = os.path.join('/', 'listener_stats_bookkeeping')
LISTENBRAINZ_LISTENER_STATS_DIRECTORY = os.path.join(LISTENBRAINZ_BASE_STATS_DIRECTORY, 'listener')

# MLHD+ dump files
MLHD_PLUS_RAW_DATA_DIRECTORY = os.path.join("/", "mlhd-raw")
Expand Down
25 changes: 8 additions & 17 deletions listenbrainz_spark/stats/incremental/listener/artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,26 @@
from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.stats.incremental.listener.entity import EntityListener
from listenbrainz_spark.stats.incremental.user.entity import UserEntity


class ArtistEntityListener(EntityListener):

def __init__(self):
super().__init__(entity="artists")
def __init__(self, stats_range, database):
    """ Configure listener stats for the "artists" entity.

    Forwards ``stats_range`` and ``database`` unchanged to the parent
    constructor, fixing the entity to "artists" and the message type to
    "entity_listener".
    """
    super().__init__(
        entity="artists",
        stats_range=stats_range,
        database=database,
        message_type="entity_listener",
    )

def get_cache_tables(self) -> List[str]:
    """ Return the cache dataframe paths this stat depends on.

    Only the artist country code dataframe is listed here.
    """
    cache_paths = [ARTIST_COUNTRY_CODE_DATAFRAME]
    return cache_paths

def get_partial_aggregate_schema(self):
return StructType([
StructField('artist_name', StringType(), nullable=False),
StructField('artist_mbid', StringType(), nullable=True),
StructField('user_id', IntegerType(), nullable=False),
StructField('listen_count', IntegerType(), nullable=False),
StructField("artist_name", StringType(), nullable=False),
StructField("artist_mbid", StringType(), nullable=True),
StructField("user_id", IntegerType(), nullable=False),
StructField("listen_count", IntegerType(), nullable=False),
])

def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
query = f"""
WITH incremental_artists AS (
SELECT DISTINCT artist_mbid FROM {incremental_aggregate}
)
SELECT *
FROM {existing_aggregate} ea
WHERE EXISTS(SELECT 1 FROM incremental_artists iu WHERE iu.artist_mbid = ea.artist_mbid)
"""
return run_query(query)
def get_entity_id(self):
    """ Name of the column that identifies an artist in the aggregates. """
    entity_id_column = "artist_mbid"
    return entity_id_column

def aggregate(self, table, cache_tables):
cache_table = cache_tables[0]
Expand Down
136 changes: 18 additions & 118 deletions listenbrainz_spark/stats/incremental/listener/entity.py
Original file line number Diff line number Diff line change
@@ -1,132 +1,32 @@
import abc
import logging
from datetime import datetime
from pathlib import Path
from typing import List

from pyspark.errors import AnalysisException
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, TimestampType

import listenbrainz_spark
from listenbrainz_spark import hdfs_connection
from listenbrainz_spark.config import HDFS_CLUSTER_URI
from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, \
LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY, LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.utils import read_files_from_HDFS, get_listens_from_dump
from datetime import date
from typing import Optional

from listenbrainz_spark.path import LISTENBRAINZ_LISTENER_STATS_DIRECTORY
from listenbrainz_spark.stats.incremental.user.entity import UserEntity

logger = logging.getLogger(__name__)
BOOKKEEPING_SCHEMA = StructType([
StructField('from_date', TimestampType(), nullable=False),
StructField('to_date', TimestampType(), nullable=False),
StructField('created', TimestampType(), nullable=False),
])


class EntityListener(abc.ABC):

def __init__(self, entity):
self.entity = entity

def get_existing_aggregate_path(self, stats_range) -> str:
return f"{LISTENBRAINZ_LISTENER_STATS_AGG_DIRECTORY}/{self.entity}/{stats_range}"

def get_bookkeeping_path(self, stats_range) -> str:
return f"{LISTENBRAINZ_LISTENER_STATS_BOOKKEEPING_DIRECTORY}/{self.entity}/{stats_range}"
class EntityListener(UserEntity, abc.ABC):

def get_partial_aggregate_schema(self) -> StructType:
raise NotImplementedError()
def __init__(self, entity: str, stats_range: str, database: Optional[str], message_type: Optional[str]):
    """ Initialise listener stats for one entity and stats range.

    Args:
        entity: name of the entity, e.g. "artists" or "release_groups".
        stats_range: the stats range the listeners are computed for.
        database: name of the target database; when empty or None a default of
            the form ``<entity>_listeners_<stats_range>_<YYYYMMDD>`` (today's
            date) is generated.
        message_type: message type forwarded to the parent constructor.
    """
    if not database:
        # Build the default from the constructor arguments: self.entity and
        # self.stats_range are not available yet because the parent
        # constructor (called below) has not run at this point.
        database = f"{entity}_listeners_{stats_range}_{date.today().strftime('%Y%m%d')}"
    super().__init__(entity, stats_range, database, message_type)

def aggregate(self, table, cache_tables) -> DataFrame:
raise NotImplementedError()
def get_table_prefix(self) -> str:
    """ Build the ``<entity>_listener_<stats_range>`` prefix for this stat.

    NOTE(review): presumably used to name intermediate tables; confirm
    against the base class.
    """
    entity, stats_range = self.entity, self.stats_range
    return f"{entity}_listener_{stats_range}"

def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
raise NotImplementedError()
def get_base_path(self) -> str:
    """ Return the base directory under which listener stats are kept. """
    base_path = LISTENBRAINZ_LISTENER_STATS_DIRECTORY
    return base_path

def combine_aggregates(self, existing_aggregate, incremental_aggregate) -> DataFrame:
def get_entity_id(self):
raise NotImplementedError()

def get_top_n(self, final_aggregate, N) -> DataFrame:
raise NotImplementedError()

def get_cache_tables(self) -> List[str]:
raise NotImplementedError()

def generate_stats(self, stats_range: str, from_date: datetime,
to_date: datetime, top_entity_limit: int):
cache_tables = []
for idx, df_path in enumerate(self.get_cache_tables()):
df_name = f"entity_data_cache_{idx}"
cache_tables.append(df_name)
read_files_from_HDFS(df_path).createOrReplaceTempView(df_name)

metadata_path = self.get_bookkeeping_path(stats_range)
try:
metadata = listenbrainz_spark \
.session \
.read \
.schema(BOOKKEEPING_SCHEMA) \
.json(f"{HDFS_CLUSTER_URI}{metadata_path}") \
.collect()[0]
existing_from_date, existing_to_date = metadata["from_date"], metadata["to_date"]
existing_aggregate_usable = existing_from_date.date() == from_date.date()
except AnalysisException:
existing_aggregate_usable = False
logger.info("Existing partial aggregate not found!")

prefix = f"entity_listener_{self.entity}_{stats_range}"
existing_aggregate_path = self.get_existing_aggregate_path(stats_range)

only_inc_entities = True

if not hdfs_connection.client.status(existing_aggregate_path, strict=False) or not existing_aggregate_usable:
table = f"{prefix}_full_listens"
get_listens_from_dump(from_date, to_date, include_incremental=False).createOrReplaceTempView(table)

logger.info("Creating partial aggregate from full dump listens")
hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent)
full_df = self.aggregate(table, cache_tables)
full_df.write.mode("overwrite").parquet(existing_aggregate_path)

hdfs_connection.client.makedirs(Path(metadata_path).parent)
metadata_df = listenbrainz_spark.session.createDataFrame(
[(from_date, to_date, datetime.now())],
schema=BOOKKEEPING_SCHEMA
)
metadata_df.write.mode("overwrite").json(metadata_path)
only_inc_entities = False

full_df = read_files_from_HDFS(existing_aggregate_path)

if hdfs_connection.client.status(INCREMENTAL_DUMPS_SAVE_PATH, strict=False):
table = f"{prefix}_incremental_listens"
read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH) \
.createOrReplaceTempView(table)
inc_df = self.aggregate(table, cache_tables)
else:
inc_df = listenbrainz_spark.session.createDataFrame([], schema=self.get_partial_aggregate_schema())
only_inc_entities = False

full_table = f"{prefix}_existing_aggregate"
full_df.createOrReplaceTempView(full_table)

inc_table = f"{prefix}_incremental_aggregate"
inc_df.createOrReplaceTempView(inc_table)

if only_inc_entities:
existing_table = f"{prefix}_filtered_aggregate"
filtered_aggregate_df = self.filter_existing_aggregate(full_table, inc_table)
filtered_aggregate_df.createOrReplaceTempView(existing_table)
else:
existing_table = full_table

combined_df = self.combine_aggregates(existing_table, inc_table)

combined_table = f"{prefix}_combined_aggregate"
combined_df.createOrReplaceTempView(combined_table)
results_df = self.get_top_n(combined_table, top_entity_limit)
def items_per_message(self):
    """ Get the number of items to chunk per message for listener stats """
    chunk_size = 10000
    return chunk_size

return only_inc_entities, results_df.toLocalIterator()

def parse_one_user_stats(self, entry: dict):
    """ Return one stats entry as-is.

    The previous body did ``raise entry``; raising a dict fails with
    ``TypeError: exceptions must derive from BaseException``, so every call
    crashed instead of yielding the entry.

    NOTE(review): assumes listener entries need no further conversion before
    being put into a message; confirm against the message consumers.

    Args:
        entry: a single stats row converted to a dict.

    Returns:
        The same ``entry`` object, unchanged.
    """
    return entry
38 changes: 16 additions & 22 deletions listenbrainz_spark/stats/incremental/listener/release_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,37 @@

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME, RELEASE_METADATA_CACHE_DATAFRAME, \
from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME, \
RELEASE_GROUP_METADATA_CACHE_DATAFRAME
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.stats.incremental.listener.entity import EntityListener
from listenbrainz_spark.stats.incremental.user.entity import UserEntity


class ReleaseGroupEntityListener(EntityListener):

def __init__(self):
super().__init__(entity="release_groups")
def __init__(self, stats_range, database):
    """ Configure listener stats for the "release_groups" entity.

    Forwards ``stats_range`` and ``database`` unchanged to the parent
    constructor, fixing the entity to "release_groups" and the message type
    to "entity_listener".
    """
    super().__init__(
        entity="release_groups",
        stats_range=stats_range,
        database=database,
        message_type="entity_listener",
    )

def get_cache_tables(self) -> List[str]:
    """ Return the cache dataframe paths this stat depends on.

    The release metadata cache comes first and the release group metadata
    cache second; the order is kept exactly as before.
    """
    cache_paths = [
        RELEASE_METADATA_CACHE_DATAFRAME,
        RELEASE_GROUP_METADATA_CACHE_DATAFRAME,
    ]
    return cache_paths

def get_partial_aggregate_schema(self):
return StructType([
StructField('release_group_mbid', StringType(), nullable=False),
StructField('release_group_name', StringType(), nullable=False),
StructField('release_group_artist_name', StringType(), nullable=False),
StructField('artist_credit_mbids', ArrayType(StringType()), nullable=False),
StructField('caa_id', IntegerType(), nullable=True),
StructField('caa_release_mbid', StringType(), nullable=True),
StructField('user_id', IntegerType(), nullable=False),
StructField('listen_count', IntegerType(), nullable=False),
StructField("release_group_mbid", StringType(), nullable=False),
StructField("release_group_name", StringType(), nullable=False),
StructField("release_group_artist_name", StringType(), nullable=False),
StructField("artist_credit_mbids", ArrayType(StringType()), nullable=False),
StructField("caa_id", IntegerType(), nullable=True),
StructField("caa_release_mbid", StringType(), nullable=True),
StructField("user_id", IntegerType(), nullable=False),
StructField("listen_count", IntegerType(), nullable=False),
])

def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
query = f"""
WITH incremental_release_groups AS (
SELECT DISTINCT release_group_mbid FROM {incremental_aggregate}
)
SELECT *
FROM {existing_aggregate} ea
WHERE EXISTS(SELECT 1 FROM incremental_release_groups iu WHERE iu.release_group_mbid = ea.release_group_mbid)
"""
return run_query(query)
def get_entity_id(self):
    """ Name of the column that identifies a release group in the aggregates. """
    entity_id_column = "release_group_mbid"
    return entity_id_column

def aggregate(self, table, cache_tables):
rel_cache_table = cache_tables[0]
Expand Down
19 changes: 12 additions & 7 deletions listenbrainz_spark/stats/incremental/user/entity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import json
import logging
from datetime import date, datetime
from typing import Optional, Iterator, Dict, Tuple
Expand All @@ -15,10 +14,8 @@
from listenbrainz_spark.path import LISTENBRAINZ_USER_STATS_DIRECTORY
from listenbrainz_spark.stats import run_query
from listenbrainz_spark.stats.incremental import IncrementalStats
from listenbrainz_spark.stats.user import USERS_PER_MESSAGE
from listenbrainz_spark.utils import read_files_from_HDFS


logger = logging.getLogger(__name__)

entity_model_map = {
Expand All @@ -31,7 +28,7 @@

class UserEntity(IncrementalStats, abc.ABC):

def __init__(self, entity: str, stats_range: str = None, database: str = None, message_type: str = None,
def __init__(self, entity: str, stats_range: str = None, database: str = None, message_type: str = None,
from_date: datetime = None, to_date: datetime = None):
super().__init__(entity, stats_range, from_date, to_date)
if database:
Expand All @@ -46,14 +43,22 @@ def get_base_path(self) -> str:
def get_table_prefix(self) -> str:
    """ Build the ``user_<entity>_<stats_range>`` prefix for this stat. """
    entity, stats_range = self.entity, self.stats_range
    return f"user_{entity}_{stats_range}"

def get_entity_id(self):
    """ Name of the column that identifies the item the stats are grouped by. """
    entity_id_column = "user_id"
    return entity_id_column

def items_per_message(self):
    """ Number of result items bundled together into a single message. """
    chunk_size = 25
    return chunk_size

def filter_existing_aggregate(self, existing_aggregate, incremental_aggregate):
entity_id = self.get_entity_id()
query = f"""
WITH incremental_users AS (
SELECT DISTINCT user_id FROM {incremental_aggregate}
SELECT DISTINCT {entity_id} FROM {incremental_aggregate}
)
SELECT *
FROM {existing_aggregate} ea
WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.user_id = ea.user_id)
WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.{entity_id} = ea.{entity_id})
"""
return run_query(query)

Expand Down Expand Up @@ -131,7 +136,7 @@ def create_messages(self, only_inc_users, results: DataFrame) -> Iterator[Dict]:
to_ts = int(self.to_date.timestamp())

data = results.toLocalIterator()
for entries in chunked(data, USERS_PER_MESSAGE):
for entries in chunked(data, self.items_per_message()):
multiple_user_stats = []
for entry in entries:
row = entry.asDict(recursive=True)
Expand Down
Loading

0 comments on commit f376920

Please sign in to comment.