From 027bcc331326c8b699392a40269ff63fbb3a11cd Mon Sep 17 00:00:00 2001 From: kyleclo Date: Fri, 2 Aug 2024 02:32:10 -0700 Subject: [PATCH 1/8] add test for flatten dict; extend flatten dict to handle lists; augment compare wandb config script to also flatten list dicts --- olmo/util.py | 28 ++++++++++++- scripts/compare_wandb_configs.py | 67 +++++++++++++++++++++----------- tests/util_test.py | 27 +++++++++++++ 3 files changed, 97 insertions(+), 25 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 3c67bb51f..6cea8457e 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -804,12 +804,36 @@ def get_bytes_range(self, index: int, length: int) -> bytes: return response["Body"].read() -def flatten_dict(dictionary, parent_key="", separator="."): +def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False, root=""): d: Dict[str, Any] = {} for key, value in dictionary.items(): new_key = parent_key + separator + key if parent_key else key if isinstance(value, MutableMapping): - d.update(**flatten_dict(value, new_key, separator=separator)) + d.update( + **flatten_dict( + value, + new_key, + separator=separator, + include_lists=include_lists, + root=root, + ) + ) + elif isinstance(value, list) and include_lists: + new_list = [] + for v in value: + if isinstance(v, MutableMapping): + new_list.append( + flatten_dict( + v, + parent_key=root, + separator=separator, + include_lists=include_lists, + root=root, + ) + ) + else: + new_list.append(v) + d[new_key] = new_list else: d[new_key] = value return d diff --git a/scripts/compare_wandb_configs.py b/scripts/compare_wandb_configs.py index 5c3dd6ead..30e2aa1ce 100644 --- a/scripts/compare_wandb_configs.py +++ b/scripts/compare_wandb_configs.py @@ -24,6 +24,47 @@ def parse_run_path(run_path: str) -> str: raise ValueError(f"Could not parse '{run_path}'") +def print_keys_with_differences(left_config, right_config, level=0): + prefix = "\t\t" * level + + s_left = "" + left_only_keys = left_config.keys() - right_config.keys() + if len(left_only_keys) > 0: + s_left += prefix + "Settings only in left:\n" + s_left += (prefix + "\n").join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys)) + "\n" + + s_right = "" + right_only_keys = right_config.keys() - left_config.keys() + if len(right_only_keys) > 0: + s_right += prefix + "Settings only in right:\n" + s_right += (prefix + "\n").join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys)) + "\n" + + s_shared = "" + keys_with_differences = { + k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k] + } + if len(keys_with_differences) > 0: + for k in sorted(keys_with_differences): + if isinstance(left_config[k], list) and isinstance(right_config[k], list): + s_list = prefix + f"{k}:\n" + for left, right in zip(left_config[k], right_config[k]): # assumes lists are same order + if isinstance(left, dict) and isinstance(right, dict): + print_keys_with_differences(left_config=left, right_config=right, level=level + 1) + else: + s_list += prefix + f"\t{left}\n" + prefix + f"\t{right}\n\n" + if s_list != prefix + f"{k}:\n": + s_shared += s_list + else: + s_shared += prefix + f"{k}\n\t{left_config[k]}\n" + prefix + f"\t{right_config[k]}\n\n" + + if (s_left or s_right) and not s_shared: + s = s_left + s_right + prefix + "No differences in shared settings.\n" + else: + s = s_left + s_right + s_shared + print(s.strip()) + return + + @click.command() @click.argument( "left_run_path", @@ -43,30 +84,10 @@ def main( left_run = api.run(parse_run_path(left_run_path)) right_run = api.run(parse_run_path(right_run_path)) - left_config = flatten_dict(left_run._attrs["rawconfig"]) - right_config = flatten_dict(right_run._attrs["rawconfig"]) - - left_only_keys = left_config.keys() - right_config.keys() - if len(left_only_keys) > 0: - print("Settings only in left:") - print("\n".join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys))) - print() - - right_only_keys = right_config.keys() - left_config.keys() - if len(right_only_keys) > 0: - print("Settings only in right:") - print("\n".join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys))) - print() + left_config = flatten_dict(left_run._attrs["rawconfig"], include_lists=True) + right_config = flatten_dict(right_run._attrs["rawconfig"], include_lists=True) - keys_with_differences = { - k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k] - } - if len(keys_with_differences) > 0: - if len(left_only_keys) > 0 or len(right_only_keys) > 0: - print("Settings with differences:") - print("\n".join(f"{k}\n\t{left_config[k]}\n\t{right_config[k]}\n" for k in sorted(keys_with_differences))) - else: - print("No differences in shared settings.") + print_keys_with_differences(left_config=left_config, right_config=right_config) if __name__ == "__main__": diff --git a/tests/util_test.py b/tests/util_test.py index 7aa031215..960ef383b 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -1,3 +1,5 @@ +from unittest import TestCase + from olmo import util @@ -12,3 +14,28 @@ def test_dir_is_empty(tmp_path): # Should return false if dir contains anything, even hidden files. (dir / ".foo").touch() assert not util.dir_is_empty(dir) + + +def test_flatten_dict(): + # basic flattening + test_dict = {"a": 0, "b": {"e": 5, "f": 1}, "c": 2} + assert util.flatten_dict(test_dict) == {"a": 0, "b.e": 5, "b.f": 1, "c": 2} + + # Should flatten nested dicts into a single dict with dotted keys. + test_dict_with_list_of_dicts = { + "a": 0, + "b": {"e": [{"x": {"z": [222, 333]}}, {"y": {"g": [99, 100]}}], "f": 1}, + "c": 2, + } + assert util.flatten_dict(test_dict_with_list_of_dicts) == { + "a": 0, + "b.e": [{"x": {"z": [222, 333]}}, {"y": {"g": [99, 100]}}], # doesnt get flattened + "b.f": 1, + "c": 2, + } + assert util.flatten_dict(test_dict_with_list_of_dicts, include_lists=True) == { + "a": 0, + "b.e": [{"x.z": [222, 333]}, {"y.g": [99, 100]}], # gets flattened + "b.f": 1, + "c": 2, + } From 792de79374f2f11a62084dcb15ffc09c5faae9eb Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 7 Aug 2024 11:14:34 -0700 Subject: [PATCH 2/8] new flatten --- olmo/util.py | 31 +++++-------------------------- tests/util_test.py | 5 ++++- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 6cea8457e..cc248a611 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -804,36 +804,15 @@ def get_bytes_range(self, index: int, length: int) -> bytes: return response["Body"].read() -def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False, root=""): +def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False): d: Dict[str, Any] = {} for key, value in dictionary.items(): new_key = parent_key + separator + key if parent_key else key + # convert lists to dict with key + if isinstance(value, list) and include_lists: + value = {f"{i}": v for i, v in enumerate(value)} if isinstance(value, MutableMapping): - d.update( - **flatten_dict( - value, - new_key, - separator=separator, - include_lists=include_lists, - root=root, - ) - ) - elif isinstance(value, list) and include_lists: - new_list = [] - for v in value: - if isinstance(v, MutableMapping): - new_list.append( - flatten_dict( - v, - parent_key=root, - separator=separator, - include_lists=include_lists, - root=root, - ) - ) - else: - new_list.append(v) - d[new_key] = new_list + d.update(**flatten_dict(value, new_key, separator=separator, include_lists=include_lists)) else: d[new_key] = value return d diff --git a/tests/util_test.py b/tests/util_test.py index 960ef383b..b81ba6d8a 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -35,7 +35,10 @@ def test_flatten_dict(): } assert util.flatten_dict(test_dict_with_list_of_dicts, include_lists=True) == { "a": 0, - "b.e": [{"x.z": [222, 333]}, {"y.g": [99, 100]}], # gets flattened + "b.e.0.x.z.0": 222, + "b.e.0.x.z.1": 333, + "b.e.1.y.g.0": 99, + "b.e.1.y.g.1": 100, "b.f": 1, "c": 2, } From 147013f3187b25fa1ac1596826e2b1418a5c7e50 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 10 Oct 2024 15:49:16 -0700 Subject: [PATCH 3/8] add richer diff functionality between configs --- olmo/util.py | 16 ++++- scripts/compare_wandb_configs.py | 113 +++++++++++++++++++++++++++++-- 2 files changed, 122 insertions(+), 7 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 6db6b7083..55fc17336 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -8,13 +8,14 @@ import sys import time import warnings +from collections import defaultdict from datetime import datetime from enum import Enum from itertools import cycle, islice from pathlib import Path from queue import Queue from threading import Thread -from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union import boto3 import botocore.exceptions as boto_exceptions @@ -842,6 +843,19 @@ def get_bytes_range(self, index: int, length: int) -> bytes: def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False): + """ + Flatten a nested dictionary into a single-level dictionary. + + Args: + dictionary (dict): The nested dictionary to be flattened. + parent_key (str, optional): The parent key to be prepended to the keys of the flattened dictionary. Defaults to "". + separator (str, optional): The separator to be used between the parent key and the keys of the flattened dictionary. Defaults to ".". + include_lists (bool, optional): Whether to convert lists to dictionaries with integer keys. Defaults to False. + + Returns: + dict: The flattened dictionary. + + """ d: Dict[str, Any] = {} for key, value in dictionary.items(): new_key = parent_key + separator + key if parent_key else key diff --git a/scripts/compare_wandb_configs.py b/scripts/compare_wandb_configs.py index 30e2aa1ce..ab02e4558 100644 --- a/scripts/compare_wandb_configs.py +++ b/scripts/compare_wandb_configs.py @@ -1,9 +1,28 @@ +""" + +Examples: + Comparing Peteish7 to OLMoE + - python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmoe/runs/rzsn9tlc + + Comparing Peteish7 to Amberish7 + - python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmo-medium/runs/ij4ls6v2 + + +""" + import logging +import os import re +from collections import Counter import click -from olmo.util import flatten_dict, prepare_cli_environment +from olmo.util import ( + build_file_tree, + file_tree_to_strings, + flatten_dict, + prepare_cli_environment, +) log = logging.getLogger(__name__) run_path_re = re.compile(r"^[^/]+/[^/]+/[^/]+$") @@ -58,13 +77,44 @@ def print_keys_with_differences(left_config, right_config, level=0): s_shared += prefix + f"{k}\n\t{left_config[k]}\n" + prefix + f"\t{right_config[k]}\n\n" if (s_left or s_right) and not s_shared: - s = s_left + s_right + prefix + "No differences in shared settings.\n" + s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + prefix + "No differences in shared settings.\n" else: - s = s_left + s_right + s_shared + s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + s_shared print(s.strip()) return +def simplify_path(path): + # Remove the variable prefix up to and including 'preprocessed/' + path = re.sub(r"^.*?preprocessed/", "", path) + + # Split the path into components + components = path.split("/") + + # Remove common suffixes like 'allenai' + components = [c for c in components if c != "allenai"] + + if components: + return "__".join(components) + else: + return "unknown_dataset" + + +def print_data_differences(left_data_paths: Counter, right_data_paths: Counter): + print(f"===== Data Paths for left config:\n") + left_data_paths = {simplify_path(path): count for path, count in left_data_paths.items()} + for path, num_files in left_data_paths.items(): + print(f"\t{path}: {num_files}") + print("\n\n") + + print(f"===== Data Paths for right config:\n") + right_data_paths = {simplify_path(path): count for path, count in right_data_paths.items()} + for path, num_files in right_data_paths.items(): + print(f"\t{path}: {num_files}") + + return + + @click.command() @click.argument( "left_run_path", @@ -84,10 +134,61 @@ def main( left_run = api.run(parse_run_path(left_run_path)) right_run = api.run(parse_run_path(right_run_path)) - left_config = flatten_dict(left_run._attrs["rawconfig"], include_lists=True) - right_config = flatten_dict(right_run._attrs["rawconfig"], include_lists=True) - + left_config_raw = left_run._attrs["rawconfig"] + right_config_raw = right_run._attrs["rawconfig"] + + # flattening the dict will make diffs easier + left_config = flatten_dict(left_config_raw) + right_config = flatten_dict(right_config_raw) + + # there are 2 specific fields in config that are difficult to diff: + # "evaluators" is List[Dict] + # "data.paths" is List[str] + # let's handle each of these directly. + + # first, data.paths can be grouped and counted. + left_data_paths = Counter([os.path.dirname(path) for path in left_config["data.paths"]]) + # for path, num_files in left_data_paths.items(): + # new_key = "data.paths" + "." + path + # left_config[new_key] = f"Num Files: {num_files}" + right_data_paths = Counter([os.path.dirname(path) for path in right_config["data.paths"]]) + # for path, num_files in right_data_paths.items(): + # new_key = "data.paths" + "." + path + # right_config[new_key] = f"Num Files: {num_files}" + del left_config["data.paths"] + del right_config["data.paths"] + + # next, evaluators can be added to the flat dict with unique key per evaluator + # also, each evaluator can also have a 'data.paths' field which needs collapsing + left_evaluators = {} + for evaluator in left_config["evaluators"]: + new_key = ".".join(["evaluators" + "." + evaluator["type"] + "." + evaluator["label"]]) + if evaluator["data"]["paths"]: + evaluator["data"]["paths"] = Counter([os.path.dirname(path) for path in evaluator["data"]["paths"]]) + left_evaluators[new_key] = evaluator + right_evaluators = {} + for evaluator in right_config["evaluators"]: + new_key = ".".join(["evaluators" + "." + evaluator["type"] + "." + evaluator["label"]]) + if evaluator["data"]["paths"]: + evaluator["data"]["paths"] = Counter([os.path.dirname(path) for path in evaluator["data"]["paths"]]) + right_evaluators[new_key] = evaluator + del left_config["evaluators"] + del right_config["evaluators"] + + # print config differences + print(f"===== Config differences between {left_run_path} and {right_run_path}:\n") print_keys_with_differences(left_config=left_config, right_config=right_config) + print("\n\n") + + # print data differences + print(f"===== Data differences between {left_run_path} and {right_run_path}:\n") + print_data_differences(left_data_paths, right_data_paths) + print("\n\n") + + # print eval differences + print(f"===== Evaluator differences between {left_run_path} and {right_run_path}:\n") + print_keys_with_differences(left_config=left_evaluators, right_config=right_evaluators) + print("\n\n") if __name__ == "__main__": From 85b8422ca18bba5aa73567e463129b2f9a22ac11 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 10 Oct 2024 15:50:43 -0700 Subject: [PATCH 4/8] broken imports --- scripts/compare_wandb_configs.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/scripts/compare_wandb_configs.py b/scripts/compare_wandb_configs.py index ab02e4558..f216a4f67 100644 --- a/scripts/compare_wandb_configs.py +++ b/scripts/compare_wandb_configs.py @@ -17,12 +17,7 @@ import click -from olmo.util import ( - build_file_tree, - file_tree_to_strings, - flatten_dict, - prepare_cli_environment, -) +from olmo.util import flatten_dict, prepare_cli_environment log = logging.getLogger(__name__) run_path_re = re.compile(r"^[^/]+/[^/]+/[^/]+$") From 9c89418a2d9df2b3737ff07cd59c02e45841a175 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 10 Oct 2024 17:09:20 -0700 Subject: [PATCH 5/8] linting; mypy --- olmo/util.py | 3 +-- scripts/compare_wandb_configs.py | 12 ++++++------ tests/util_test.py | 2 -- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 55fc17336..3697e86ce 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -8,14 +8,13 @@ import sys import time import warnings -from collections import defaultdict from datetime import datetime from enum import Enum from itertools import cycle, islice from pathlib import Path from queue import Queue from threading import Thread -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple, Union import boto3 import botocore.exceptions as boto_exceptions diff --git a/scripts/compare_wandb_configs.py b/scripts/compare_wandb_configs.py index f216a4f67..11e2907b9 100644 --- a/scripts/compare_wandb_configs.py +++ b/scripts/compare_wandb_configs.py @@ -96,15 +96,15 @@ def simplify_path(path): def print_data_differences(left_data_paths: Counter, right_data_paths: Counter): - print(f"===== Data Paths for left config:\n") - left_data_paths = {simplify_path(path): count for path, count in left_data_paths.items()} - for path, num_files in left_data_paths.items(): + print("===== Data Paths for left config:\n") + simplified_left_data_paths = {simplify_path(path): count for path, count in left_data_paths.items()} + for path, num_files in simplified_left_data_paths.items(): print(f"\t{path}: {num_files}") print("\n\n") - print(f"===== Data Paths for right config:\n") - right_data_paths = {simplify_path(path): count for path, count in right_data_paths.items()} - for path, num_files in right_data_paths.items(): + print("===== Data Paths for right config:\n") + simplified_right_data_paths = {simplify_path(path): count for path, count in right_data_paths.items()} + for path, num_files in simplified_right_data_paths.items(): print(f"\t{path}: {num_files}") return diff --git a/tests/util_test.py b/tests/util_test.py index b81ba6d8a..c3493458c 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -1,5 +1,3 @@ -from unittest import TestCase - from olmo import util From 500e943dc6cb4ce01c63180d49ae01f521720516 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 10 Oct 2024 17:13:09 -0700 Subject: [PATCH 6/8] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9752a733..4dd57de1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`. - Added support for flash attention and gradient checkpointing to `hf_olmo`. +- Added to `scripts.compare_wandb_configs.py` the ability to more easily compare differences in data mixes and evaluation tasks. ## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26 From 3b037a492995a850927a0177608ca53ddfe1fdc4 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Mon, 21 Oct 2024 16:55:29 -0700 Subject: [PATCH 7/8] get rid of simplify path --- scripts/compare_wandb_configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/compare_wandb_configs.py b/scripts/compare_wandb_configs.py index 11e2907b9..1597ccbec 100644 --- a/scripts/compare_wandb_configs.py +++ b/scripts/compare_wandb_configs.py @@ -97,13 +97,13 @@ def simplify_path(path): def print_data_differences(left_data_paths: Counter, right_data_paths: Counter): print("===== Data Paths for left config:\n") - simplified_left_data_paths = {simplify_path(path): count for path, count in left_data_paths.items()} + simplified_left_data_paths = {path: count for path, count in left_data_paths.items()} for path, num_files in simplified_left_data_paths.items(): print(f"\t{path}: {num_files}") print("\n\n") print("===== Data Paths for right config:\n") - simplified_right_data_paths = {simplify_path(path): count for path, count in right_data_paths.items()} + simplified_right_data_paths = {path: count for path, count in right_data_paths.items()} for path, num_files in simplified_right_data_paths.items(): print(f"\t{path}: {num_files}") From f2c2a1534f401f3b030e478fea6ae083bea1f3a6 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Tue, 29 Oct 2024 13:33:07 -0700 Subject: [PATCH 8/8] pylint --- scripts/compare_wandb_configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/compare_wandb_configs.py b/scripts/compare_wandb_configs.py index 312e8bafd..aed025e54 100644 --- a/scripts/compare_wandb_configs.py +++ b/scripts/compare_wandb_configs.py @@ -141,17 +141,17 @@ def _simplify_evaluators(evaluators): ) # print config differences - print(f"==================== Param differences ====================\n\n") + print("==================== Param differences ====================\n\n") print_keys_with_differences(left_config=left_config, right_config=right_config) print("============================================================= \n\n") # print data differences - print(f"==================== Data Differences ====================\n\n") + print("==================== Data Differences ====================\n\n") print_data_differences(left_data_paths, right_data_paths) print("============================================================= \n\n") # print eval differences - print(f"==================== Eval Differences ====================\n\n") + print("==================== Eval Differences ====================\n\n") print_keys_with_differences(left_config=left_evaluators, right_config=right_evaluators) print("============================================================= \n\n")