Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr authored May 30, 2024
2 parents 5ce977a + 6aaa8ff commit cd8f71f
Show file tree
Hide file tree
Showing 14 changed files with 56 additions and 23 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_WORKSPACE: ai2/tango-testing
BEAKER_DEFAULT_CLUSTER: ai2/allennlp-cirrascale
BEAKER_IMAGE: petew/tango-testing
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Expand Down Expand Up @@ -211,7 +210,7 @@ jobs:
if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'flax') || contains(matrix.task.extras, 'all'))
run: |
. .venv/bin/activate
pip install flax==0.6.1 jax==0.4.1 jaxlib==0.4.1 tensorflow-cpu==2.9.1 optax==0.1.3
pip install flax jax jaxlib "tensorflow-cpu>=2.9.1" optax
- name: Install editable (no cache hit)
if: steps.virtualenv-cache.outputs.cache-hit != 'true'
Expand Down Expand Up @@ -282,12 +281,13 @@ jobs:
spec: |
version: v2
description: GPU Tests
budget: ai2/oe-training
tasks:
- name: tests
image:
beaker: ${{ env.BEAKER_IMAGE }}
context:
cluster: ${{ env.BEAKER_DEFAULT_CLUSTER }}
preemptible: true
resources:
gpuCount: 2
envVars:
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

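- Fixed a bunch of dependencies: relaxed the version pins on `more-itertools`, `jax`, `jaxlib`, `flax`, and `optax`
- Upgraded to a new version of wandb (now requires `wandb>=0.16`)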

## [v1.3.2](https://github.com/allenai/tango/releases/tag/v1.3.2) - 2023-10-27

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion docs/source/first_steps.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ Computing...: 100%|##########| 100/100 [00:05<00:00, 18.99it/s]
✓ The output for "add_numbers" is in workspace/runs/live-tarpon/add_numbers
```

The last line in the output tells us where we can find the result of our "add_numbers" step. `live-parpon` is
The last line in the output tells us where we can find the result of our "add_numbers" step. `live-tarpon` is
the name of the run. Run names are randomly generated and may be different on your machine. `add_numbers` is the
name of the step in your config. The whole path is a symlink to a directory, which contains (among other things)
a file `data.json`:
Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"click-help-colors>=0.9.1,<0.10",
"rich>=12.3,<14.0",
"tqdm>=4.62,<5.0",
"more-itertools>=8.0,<10.0",
"more-itertools>=8.0,<11.0",
"sqlitedict",
"glob2>=0.7",
"petname>=2.6,<3.0",
Expand Down Expand Up @@ -89,14 +89,14 @@ fairscale = [
]
flax = [
"datasets>=1.12,<3.0",
"jax>=0.4.1,<=0.4.13",
"jaxlib>=0.4.1,<=0.4.13",
"flax>=0.6.1,<=0.7.0",
"optax>=0.1.2",
"jax",
"jaxlib",
"flax",
"optax",
"tensorflow-cpu>=2.9.1"
]
wandb = [
"wandb>=0.12,<0.14.3",
"wandb>=0.16",
"retry"
]
beaker = [
Expand Down
2 changes: 1 addition & 1 deletion tango/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SettingsObject(NamedTuple):
called_by_executor: bool


@click.group(**_CLICK_GROUP_DEFAULTS)
@click.group(name=None, **_CLICK_GROUP_DEFAULTS)
@click.version_option(version=VERSION)
@click.option(
"--settings",
Expand Down
10 changes: 9 additions & 1 deletion tango/integrations/beaker/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def __init__(
priority: Optional[Union[str, Priority]] = None,
allow_dirty: bool = False,
scheduler: Optional[BeakerScheduler] = None,
budget: Optional[str] = None,
**kwargs,
):
# Pre-validate arguments.
Expand All @@ -365,6 +366,11 @@ def __init__(
"Either 'beaker_image' or 'docker_image' must be specified for BeakerExecutor, but not both."
)

if budget is None:
raise ConfigurationError("You must specify a budget to use the beaker executor.")
else:
self._budget = budget

from tango.workspaces import LocalWorkspace, MemoryWorkspace

if isinstance(workspace, MemoryWorkspace):
Expand Down Expand Up @@ -1029,7 +1035,9 @@ def _build_experiment_spec(
return (
experiment_name,
ExperimentSpec(
tasks=[task_spec], description=f'Tango step "{step_name}" ({step.unique_id})'
tasks=[task_spec],
description=f'Tango step "{step_name}" ({step.unique_id})',
budget=self._budget,
),
[step_graph_dataset],
)
2 changes: 1 addition & 1 deletion tango/integrations/flax/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(

self.logger = logging.getLogger(FlaxDataLoader.__name__)

def __call__(self, rng: jax.random.PRNGKeyArray, do_distributed: bool):
def __call__(self, rng: jax._src.random.KeyArrayLike, do_distributed: bool):
steps_per_epoch = self.dataset_size // self.batch_size

if self.shuffle:
Expand Down
3 changes: 2 additions & 1 deletion tango/integrations/flax/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Optimizer(Registrable):
:options: +ELLIPSIS
optax::adabelief
optax::adadelta
optax::adafactor
optax::adagrad
optax::adam
Expand Down Expand Up @@ -100,7 +101,7 @@ def factory_func():
Optimizer.register("optax::" + name)(factory_func)

# Register all learning rate schedulers.
for name, cls in optax._src.schedule.__dict__.items():
for name, cls in optax.schedules.__dict__.items():
if isfunction(cls) and not name.startswith("_") and cls.__annotations__:
factory_func = scheduler_factory(cls)
LRScheduler.register("optax::" + name)(factory_func)
Expand Down
4 changes: 2 additions & 2 deletions tango/integrations/flax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import jax


def get_PRNGkey(seed: int = 42) -> Union[Any, jax.random.PRNGKeyArray]:
def get_PRNGkey(seed: int = 42) -> Union[Any, jax._src.random.KeyArray]:
"""
Utility function to create a pseudo-random number generator key
given a seed.
"""
return jax.random.PRNGKey(seed)


def get_multiple_keys(key, multiple: int = 1) -> Union[Any, jax.random.PRNGKeyArray]:
def get_multiple_keys(key, multiple: int = 1) -> Union[Any, jax._src.random.KeyArray]:
"""
Utility function to split a PRNG key into multiple new keys.
Used in distributed training.
Expand Down
2 changes: 2 additions & 0 deletions tango/integrations/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
transformers::Adafactor
transformers::AdamW
transformers::LayerWiseDummyOptimizer
- :class:`~tango.integrations.torch.LRScheduler`: All learning rate scheduler functions from transformers
are registered according to their type name (e.g. "transformers::linear").
Expand All @@ -92,6 +93,7 @@
transformers::constant
transformers::constant_with_warmup
transformers::cosine
transformers::cosine_with_min_lr
transformers::cosine_with_restarts
transformers::inverse_sqrt
transformers::linear
Expand Down
8 changes: 3 additions & 5 deletions tango/integrations/wandb/step_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _step_artifact_name(self, step: Union[Step, StepInfo]) -> str:

def _step_result_remote( # type: ignore
self, step: Union[Step, StepInfo]
) -> Optional[wandb.apis.public.Artifact]:
) -> Optional[wandb.Artifact]:
artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value)
try:
return self.wandb_client.artifact(
Expand All @@ -88,9 +88,7 @@ def _step_result_remote( # type: ignore
def create_step_result_artifact(self, step: Step, objects_dir: Optional[PathOrStr] = None):
self._upload_step_remote(step, objects_dir)

def get_step_result_artifact(
self, step: Union[Step, StepInfo]
) -> Optional[wandb.apis.public.Artifact]:
def get_step_result_artifact(self, step: Union[Step, StepInfo]) -> Optional[wandb.Artifact]:
artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value)
try:
return self.wandb_client.artifact(
Expand Down Expand Up @@ -144,7 +142,7 @@ def use_step_result_artifact(self, step: Union[Step, StepInfo]) -> None:

def _download_step_remote(self, step_result, target_dir: PathOrStr):
try:
step_result.download(root=target_dir, recursive=True)
step_result.download(root=target_dir)
except (WandbError, ValueError):
raise RemoteNotFoundError()

Expand Down
13 changes: 12 additions & 1 deletion tango/integrations/wandb/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import warnings
from enum import Enum

Expand All @@ -13,7 +14,17 @@ def is_missing_artifact_error(err: WandbError):
Check if a specific W&B error is caused by a 404 on the artifact we're looking for.
"""
# This is brittle, but at least we have a test for it.
return "does not contain artifact" in err.message

# This is a workaround for a bug in the wandb API
if err.message == "'NoneType' object has no attribute 'get'":
return True

if re.search(r"^artifact '.*' not found in '.*'$", err.message):
return True

return ("does not contain artifact" in err.message) or (
"Unable to fetch artifact with name" in err.message
)


def check_environment():
Expand Down
3 changes: 3 additions & 0 deletions tests/integrations/beaker/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_from_params(beaker_workspace_name: str):
beaker_image="ai2/conda",
github_token="FAKE_TOKEN",
datasets=[{"source": {"beaker": "some-dataset"}, "mount_path": "/input"}],
budget="ai2/allennlp",
),
workspace=BeakerWorkspace(workspace=beaker_workspace_name),
clusters=["fake-cluster"],
Expand All @@ -38,6 +39,7 @@ def test_init_with_mem_workspace(beaker_workspace_name: str):
beaker_image="ai2/conda",
github_token="FAKE_TOKEN",
clusters=["fake-cluster"],
budget="ai2/allennlp",
)


Expand All @@ -50,6 +52,7 @@ def settings(beaker_workspace_name: str) -> TangoGlobalSettings:
"beaker_workspace": beaker_workspace_name,
"install_cmd": "pip install .[beaker]",
"clusters": ["ai2/allennlp-cirrascale", "ai2/general-cirrascale"],
"budget": "ai2/allennlp",
},
)

Expand Down
7 changes: 6 additions & 1 deletion tests/integrations/flax/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,10 @@ def test_trainer(self):
],
)
assert (
result_dir / "train" / "work" / "checkpoint_state_latest" / "checkpoint_0"
result_dir
/ "train"
/ "work"
/ "checkpoint_state_latest"
/ "checkpoint_0"
/ "checkpoint"
).is_file()

0 comments on commit cd8f71f

Please sign in to comment.