Skip to content

Commit

Permalink
rename output, change default mix size (#6)
Browse files Browse the repository at this point in the history
* rename output, change default mix size

* fix tests
  • Loading branch information
l-moamen authored Sep 27, 2023
1 parent c86ad79 commit b44d1a7
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ classifiers = [

[tool.poetry]
name = "tidal_algorithmic_mixes"
version = "0.0.4"
version = "0.0.5"
description = "common transformers used by the tidal personalization team."
authors = [
"Loay <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion test/discovery_mix/test_observed_tracks_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_transform(self):
runner._data = ObservedDiscoveryMixTracksAggregatorTransformationData(observed_mixes=observed_mixes,
mixes=mixes)
runner.transform()
res = runner.output.output
res = runner.output.df

self.assertEqual(res.columns, [c.USER, c.TRACK_GROUP])
self.assertEqual(res.count(), len(tracks_user_1) + len(tracks_user_2))
Expand Down
2 changes: 1 addition & 1 deletion test/discovery_mix/test_post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,6 @@ def test_transform(self):
min_mix_size=0)
post_processor._data = self.data
post_processor.transform()
res = post_processor.output.output.collect()[0]
res = post_processor.output.df.collect()[0]
self.assertEqual(Row(user=1, tracks=['xxx'], mixId='1f1451b3b417516e9e4b4423958', atDate=res.atDate),
res)
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class DiscoveryMixDailyUpdateTransformationData:

@dataclass
class DiscoveryMixDailyUpdateTransformationOutput:
output: DataFrame
df: DataFrame


class DiscoveryMixDailyUpdateTransformationConfig(Config):
def __init__(self, **kwargs):
self.current_date = kwargs.get('current_date')
self.mix_size = int(kwargs.get('mix_size', 70))
self.mix_size = int(kwargs.get('mix_size', 10))
Config.__init__(self, **kwargs)


Expand Down Expand Up @@ -60,7 +60,7 @@ def transform(self, *args, **kwargs):
self.config.mix_size)
.withColumn(c.UPDATED, F.lit(mix_utils.updated(time.time())))
.where(F.size(c.TRACKS) >= self.config.mix_size - 2))
self._output = DiscoveryMixDailyUpdateTransformationOutput(output=discovery_mix)
self._output = DiscoveryMixDailyUpdateTransformationOutput(discovery_mix)

def slicer(self, mixes: DataFrame, current_date: date, mix_size: int) -> DataFrame:
""" Extract the tracks of the day from the weekly computed list """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ObservedDiscoveryMixTracksAggregatorTransformationData:

@dataclass
class ObservedDiscoveryMixTracksAggregatorTransformationOutput:
output: DataFrame
df: DataFrame


class ObservedDiscoveryMixTracksAggregatorTransformation(ETLModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class DiscoveryMixPostProcessorTransformationData:

@dataclass
class DiscoveryMixPostProcessorTransformationOutput:
output: DataFrame
df: DataFrame


class DiscoveryMixPostProcessorTransformationConfig(Config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass

from mlflow.pyfunc.spark_model_cache import SparkModelCache
# noinspection PyProtectedMember
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.file_utils import TempDir
import numpy as np
Expand All @@ -25,7 +26,7 @@ class DiscoveryMixSasRecModelTransformationData:

@dataclass
class DiscoveryMixSasRecModelTransformationOutput:
output: DataFrame
df: DataFrame


class DiscoveryMixSasRecModelTransformationConfig(Config):
Expand Down

0 comments on commit b44d1a7

Please sign in to comment.