From e8c9e3b40024886e564579d147d84819c3b973f4 Mon Sep 17 00:00:00 2001 From: Joseph Date: Sat, 4 May 2024 10:22:39 +0100 Subject: [PATCH] add arg parsing, refactor main --- dita/discogs/collection.py | 277 +++++++++++++++++++++---------------- 1 file changed, 154 insertions(+), 123 deletions(-) diff --git a/dita/discogs/collection.py b/dita/discogs/collection.py index e39ec7f..bca5267 100755 --- a/dita/discogs/collection.py +++ b/dita/discogs/collection.py @@ -3,9 +3,8 @@ fetch a csv/df, or read one. """ -# from pprint import pprint -import os -import sys +import argparse +import subprocess from datetime import datetime from datetime import timedelta from datetime import timezone @@ -32,7 +31,6 @@ from dita.tag.core import open_url from dita.tag.core import select_from_list - VAL_DELIM = ":" # prop:val FILT_DELIM = "," # prop1:val,prop2:val @@ -63,9 +61,9 @@ def __init__( # feasible since 'genre' can be specified more than once. i am not # willing to allow 'genre:a+b' or whatever. # TODO: dict[str, set[str]] - self.filter_list = () # tuple[tuple[str,str]] + self.filter_list: tuple[tuple[str, str]] = () # tuple[tuple[str,str]] - self.filtered = self.df.copy() + self.filtered: pd.DataFrame = self.df.copy() def __len__(self) -> int: return len(self.df) @@ -136,7 +134,7 @@ def filter_text(df): # print(df[matches]) if not matches.any(): - print(val, "not found in field", prop) + eprint(val, "not found in field", prop) raise ValueError df = df[matches] @@ -203,23 +201,22 @@ def filter_text(df): def filter( self, filters: str = "", - # group_artist: bool = False, - # unique_albums: bool = False, sort: bool = True, ): """Parses filters to be sequentially applied to an offline collection. - Filters are to be passed as strings in the form ':', which - will be stored. If is left blank, user input will be required. + Filters are to be passed as strings in the form `:`, which + will be stored. If `` is left blank, user input will be required. Special prefixes are allowed. - The actual applying of filters is done by apply_filter(). + The actual applying of filters is done by `apply_filter`. If a filter clears the selection, the initial state can be restored - with reset_filters(). + with `reset_filters`. - Sorting is done by default. + Sorting is done by default. `title` is parsed as regex. Examples: + ``` artist:[blank] genre:black metal (spaces allowed) genre:black metal,r:3 (r => 3) @@ -228,17 +225,7 @@ def filter( genre:thrash! (groups by artist and calculates mean rating -- warning: discards release information!) title:Goldberg Variations (groups releases by id) - - Args: - filters: [TODO:description] - group_artist: [TODO:description] - - Returns: - [TODO:description] - - Raises: - [TODO:name]: [TODO:description] - [TODO:name]: [TODO:description] + ``` """ for filt in filters.split(FILT_DELIM): @@ -275,7 +262,8 @@ def filter( eprint("Property must be one of:", "/".join(self.filtered.columns)) raise ValueError - print(key + VAL_DELIM + val) + # messes up stdout when called externally + # eprint(key + VAL_DELIM + val) self.filtered = self.apply_filter( # self.filtered, @@ -301,11 +289,14 @@ def filter( # print(self.filtered.to_dict()) - if sort or any(filt[1] == "@" for filt in self.filter_list): + if sort or any(filt[1][-1] == "@" for filt in self.filter_list): self.sort() - if any(filt[1] == "!" for filt in self.filter_list): - self.filtered = group_collection_by_artist(self.filtered) + if any(filt[1][-1] == "!" for filt in self.filter_list): + self.filtered = group_collection_by_artist( + self.filtered, + metric=mean_plus, + ) # return self.df @@ -314,7 +305,7 @@ def sort(self): (e.g. 'r:3@'), it will take precedence over everything else. """ sortkey = { - # "date_added": True, + "date_added": True, "r": False, "artist": True, "year": True, @@ -355,14 +346,12 @@ def sort(self): def dump_collection_to_csv(): """Fetch all pages of a user's collection and write to csv""" - df = ( - # note: tqdm is kinda goofy with generators + # note: tqdm is kinda goofy with generators + ( pd.DataFrame(tqdm(get_collection_releases())) - # pd.DataFrame(get_collection_releases()) - # .drop_duplicates(["id"]) # allow composers and performers to be listed .sort_values("date_added") + .to_csv(DISCOGS_CSV) ) - df.to_csv(DISCOGS_CSV) def get_wantlist_releases() -> pd.DataFrame: @@ -474,6 +463,14 @@ def get_collection_releases( # }}} +def mean_plus(ints: list[int]) -> float: + mean = np.mean(ints) + for i in ints: + mult = 1 + ((i - 1) / 100) + mean *= mult + return round(mean, 2) + + def top_n_sum( artist_ratings: pd.Series, num: float = 3, @@ -516,30 +513,36 @@ def top_n_sum( # return int(np.mean(artist_ratings.sort_values(ascending=False)[:3]) * 3) +PERCENTILE = 2 # float also allowed, e.g. 2.5 +# TODO: +# lambda x: np.mean(x) * 3, +# lambda x: np.median(x) * 3, +METRICS = { + "top_n_sum": lambda x: top_n_sum(x, 10 // PERCENTILE), + "mean": np.mean, + "mean_plus": mean_plus, + "len": len, +} + + def group_collection_by_artist( df: pd.DataFrame, groupby: str = "artist", min_releases: int = 2, - metric: Callable = top_n_sum, + # https://docs.python.org/3/library/typing.html#annotating-callable-objects + metric: Callable[[pd.Series], int] = top_n_sum, ) -> pd.DataFrame: - """Wrapper for df.groupby, with metric to be applied - - Args: - df: df to be grouped - groupby: primarily supports 'artist', may add support for 'label' in future - metric: default is top_three_sum. Alternative metrics include: - mean: lambda x: np.mean(x) * 3 - count: len - median: lambda x: np.median(x) * 3 + """Wrapper for `df.groupby`, with `metric` to be applied Returns: - df, with columns [, 'r'] + df, with columns `[groupby, 'r']` added """ # clean first, otherwise groupby will be performed incorrectly - df["artist"] = df.artist.apply(clean_artist) + df.artist = df.artist.apply(clean_artist) if groupby == "label": - df = df[df.label != "Not On Label"] + df.label = df.label.replace("Not On Label", None) + df.dropna(subset=["label"], inplace=True) # Series[bool], unique items in column that fulfill the condition cond: pd.Series = df[groupby].value_counts() >= min_releases @@ -553,17 +556,15 @@ def group_collection_by_artist( .apply(metric) # keep cols (better to keep than drop) [[groupby, "r"]] - .sort_values( - ["r", groupby], - ascending=[False, True], - ) + .sort_values(["r", groupby], ascending=[False, True]) + .set_index("artist") ) def cprint_df(df): - """Print df with color-coded (ANSI-escaped) 'r' column. ANSI escapes lead + """Print `df` with color-coded (ANSI-escaped) `r` column. ANSI escapes lead to column misalignment with the default printer, possibly because the - length of the escaped string exceeds 'r', but .to_markdown() can be used to + length of the escaped string exceeds `r`, but `df.to_markdown` can be used to circumvent this. """ @@ -573,21 +574,24 @@ def cprint_df(df): df_str = df.reset_index(drop=True).to_markdown() - # truncate title - excess = len(df_str.split("\n")[0]) - os.get_terminal_size().columns + _, width = subprocess.check_output(["stty", "size"]).split() + + excess = len(df_str.split("\n")[0]) - int(width) if excess > 0: - df.title = df.title.apply(lambda x: x[: max(df.title.str.len()) - excess]) + limit = int(max(df.title.str.len()) - excess) + df.title = df.title.apply(lambda x: x[:limit]) df_str = df.reset_index(drop=True).to_markdown() print(df_str) cprint(mean_rating) -def get_percentiles( +def filter_by_percentile( df: pd.DataFrame, col: str = "r", + thresh: float = 5, ) -> pd.DataFrame: - """Adds 'perc' column to df.""" + """Adds `perc` column to df.""" master = df[col].to_list() percentiles = {} perc = 1 @@ -602,70 +606,97 @@ def get_percentiles( percentiles[0] = 100 df["perc"] = df.r.apply(lambda x: percentiles[x]) - return df + return df[df.perc <= thresh] -if __name__ == "__main__": - if ":" in sys.argv[1]: # and "http" not in sys.argv[1]: - DISCOGS_DF: pd.DataFrame = pd.read_csv( - DISCOGS_CSV, - index_col=0, - parse_dates=["date_added"], # allow calculation of date differences - na_filter=False, # corner case: None (Meshuggah) is not nan +def parse_args(): + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(dest="subcommand") + + filt = subparsers.add_parser("filter") + + filt.add_argument("filters") + filt.add_argument( + "--format", + choices=["json", "csv", "pretty"], + default="pretty", + required=False, + ) + + top = subparsers.add_parser("top") + top.add_argument( + "--metric", + choices=METRICS, + default="top_n_sum", + required=False, + ) + + return parser.parse_args() + + +def filter_collection(args: argparse.Namespace): + df: pd.DataFrame = pd.read_csv( + DISCOGS_CSV, + index_col=0, + parse_dates=["date_added"], # allow calculation of date differences + na_filter=False, # corner case: None (Meshuggah) is not nan + ) + + coll = Collection(df) + coll.filter(args.filters) + + if args.format in ["csv", "json"]: + print(getattr(coll.filtered, f"to_{args.format}")()) + return + + if "id" in coll.filtered.columns: + coll.filtered.drop_duplicates("id", inplace=True) + + cprint_df(coll.filtered) + + sel = input() + if sel: + print(d_get(coll.filtered.iloc[int(sel)].id)["uri"]) + else: + # assumes filter was artist:XXX + if coll.filter_list[0][0] == "artist": + Artist(get_artist_id(coll.filter_list[0][1])).rate_all() + + +def main(): + args = parse_args() + + # print(args) + # print(args.subcommand) + + if args.subcommand == "filter": + filter_collection(args) + + if args.subcommand == "top": + top_df = group_collection_by_artist( + pd.read_csv(DISCOGS_CSV), + metric=METRICS.get(args.metric), ) - coll = Collection(DISCOGS_DF) - coll.filter(" ".join(sys.argv[1:])) - # print(col) - cprint_df(coll.filtered.drop_duplicates("id")) - sel = input() - if sel: - print(d_get(coll.filtered.iloc[int(sel)].id)["uri"]) - else: - # assumes filter was artist:XXX - if coll.filter_list[0][0] == "artist": - Artist(get_artist_id(coll.filter_list[0][1])).rate_all() - - elif len(sys.argv) == 2: - if sys.argv[1] == "--dump": - get_collection_releases_verbose() - dump_collection_to_csv() - - elif sys.argv[1] == "--want": - print(get_wantlist_releases()) - - elif sys.argv[1] == "--top": - PERC = 2 # float also allowed, e.g. 2.5 - top_df = group_collection_by_artist( - pd.read_csv(DISCOGS_CSV), - metric=lambda x: top_n_sum(x, 10 // PERC), - # groupby="label", - # metric=lambda x: np.mean(x) * 3, - # metric=len, - # metric=lambda x: np.median(x) * 3, - ) - # print(top_df, len(top_df)) - # raise ValueError - top_df = get_percentiles(top_df) - # print(top_df[top_df.perc == PERC + 1]) - top_df = top_df[top_df.perc <= PERC] - print(top_df, len(top_df)) - - # from etc.rym_artists import print_rym_artists - - # print_rym_artists(top_df) - - # print_rym_artists(top_df[top_df.r > 22]) - # print() - # print_rym_artists(top_df[top_df.r <= 22]) - - elif sys.argv[1] == "--random": - coll = Collection(pd.read_csv(DISCOGS_CSV)) - coll.filter("r:4") - ran = coll.filtered.sample(n=1) - open_url( - "https://open.spotify.com/search/", - " ".join( - [ran.artist.iloc[0], ran.title.iloc[0]], - ).split(), - suffix="albums", - ) + top_df = filter_by_percentile(top_df, thresh=PERCENTILE) + print(top_df, len(top_df)) + + # elif len(sys.argv) == 2: + # if sys.argv[1] == "--want": + # print(get_wantlist_releases()) + # + # elif sys.argv[1] == "--random": + # coll = Collection(pd.read_csv(DISCOGS_CSV)) + # coll.filter("r:4") + # ran = coll.filtered.sample(n=1) + # open_url( + # "https://open.spotify.com/search/", + # " ".join( + # [ran.artist.iloc[0], ran.title.iloc[0]], + # ).split(), + # suffix="albums", + # ) + + +if __name__ == "__main__": + main()