Commit
Fix tpu vm autoshutdown (#708)
* fix autodeletion of TPU nodes?

* actually fix autodeletion of TPU nodes?

* wip

* fork fixes it

* sigh
dlwh authored Aug 27, 2024
1 parent 277e728 commit 0c628d5
Showing 4 changed files with 83 additions and 12 deletions.
9 changes: 6 additions & 3 deletions src/levanter/models/backpack.py
@@ -116,8 +116,8 @@ def init(
use_bias: bool = True,
) -> "BackpackMlp":
k_fc, k_proj = jrandom.split(key, 2)
c_fc = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias)
c_proj = hnn.Linear.init(Out=Out, In=Mlp, key=k_proj, use_bias=use_bias)
c_fc = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=False)
c_proj = hnn.Linear.init(Out=Out, In=Mlp, key=k_proj, use_bias=use_bias, out_first=False)
if isinstance(activation_fn, str):
activation_fn = ACT2FN[activation_fn]
act = activation_fn # type: ignore
@@ -176,7 +176,10 @@ def init(config: Gpt2Config, *, key) -> "WeightsOnlyAttention":
Embed = config.Embed

k_c, _ = jrandom.split(key, 2)
c_attn = hnn.Linear.init(In=Embed, Out=(Qk, config.Senses, config.SenseHeadDim), key=k_c, use_bias=use_bias)
# NB: out_first=True b/c the torch implementation uses Linear
c_attn = hnn.Linear.init(
In=Embed, Out=(Qk, config.Senses, config.SenseHeadDim), key=k_c, use_bias=use_bias, out_first=True
)
dropout = hnn.Dropout(config.attn_pdrop)

return WeightsOnlyAttention(config, c_attn, dropout)
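Note: the out_first flag added in these calls controls the axis order of the Linear weight, output axes first (the torch.nn.Linear convention, per the "NB" comment above) versus the older input-first layout. A minimal sketch of the two layouts, assuming only the Linear.init signature already used in this diff:

import haliax as hax
import haliax.nn as hnn
import jax.random as jrandom

Embed = hax.Axis("embed", 64)
Mlp = hax.Axis("mlp", 256)
k1, k2 = jrandom.split(jrandom.PRNGKey(0))

# out_first=True stores the weight with the Out axes leading, matching torch.nn.Linear.
torch_layout = hnn.Linear.init(In=Embed, Out=Mlp, key=k1, use_bias=True, out_first=True)
# out_first=False keeps the input-first layout these modules used before this change.
legacy_layout = hnn.Linear.init(In=Embed, Out=Mlp, key=k2, use_bias=True, out_first=False)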
4 changes: 2 additions & 2 deletions src/levanter/models/whisper.py
@@ -136,8 +136,8 @@ class WhisperMlp(eqx.Module, StateDictSerializationMixin):
@staticmethod
def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = True) -> "WhisperMlp":
k_fc, k_proj = haliax.jax_utils.maybe_rng_split(key, 2)
fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias)
fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias)
fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=False)
fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias, out_first=False)
if isinstance(activation_fn, str):
activation_fn = ACT2FN[activation_fn]
act = activation_fn # type: ignore
81 changes: 74 additions & 7 deletions src/levanter/utils/cloud_utils.py
@@ -1,4 +1,5 @@
import contextlib
import json
import logging
import os
import shutil
@@ -29,33 +30,99 @@ def _checked_request(url):
raise


def _checked_delete(url):
# first get the token
token = _checked_request(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"
)
token = json.loads(token)["access_token"]
headers = {"Authorization": f"Bearer {token}", "Metadata-Flavor": "Google"}
try:
response = requests.delete(url, headers=headers)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException:
logger.exception(f"Could not delete {url} from metadata server. Is this a TPU VM?", exc_info=True)
raise


def _shutdown_tpu_with_queued_resource():
queued_resource = _checked_request(
"http://metadata.google.internal/computeMetadata/v1/instance/attributes/queued-resource-name"
)
# queued resource looks like:
# projects/999999/locations/us-central2-b/queuedResources/NAME
# to delete we need to use delete against
# https://tpu.googleapis.com/v2/projects/9999/locations/us-central2-b/queuedResources/NAME?force=true
if queued_resource:
queued_resource_name = queued_resource.split("/")[-1]
# quiet really works like -y
if jax.process_index() == 0:
logger.critical(f"Found queued resource {queued_resource_name}. Attempting to delete it.")
# We need to use curl
# curl -X DELETE -H "Authorization: Bearer $(gcloud auth print-access-token)" \
# -H "Content-Type: application/json" \
# https://tpu.googleapis.com/v2/projects/my-project/locations/us-central2-b/queuedResources/my-queued-resource?force=true
# os.system(f"gcloud compute tpus queued-resources delete {queued_resource} --zone {zone} --force --quiet")
url = f"https://tpu.googleapis.com/v2/{queued_resource}?force=true"
_checked_delete(url)
return True
else:
logger.info("No queued resource found.")
return False


def shutdown_tpu_vm(sleep_seconds=60 * 5):
"""You should probably call this from atexit or something like that."""
# fork a process to do the delete so the main process can exit before the delete is done
logger.info("Forking a process to delete...")
logger.critical(f"Create a file {SENTINEL_FILE} to cancel the shutdown")
logger.critical(f"$ touch {SENTINEL_FILE}")

# fork works better for our use case
pid = os.fork()
if pid == 0:
_do_shutdown_tpu_vm(sleep_seconds)
os._exit(0)
else:
logger.info(f"Forked process {pid} to delete TPU VM")


def _do_shutdown_tpu_vm(sleep_seconds):
# the gcloud command we would run is something like:
# gcloud compute tpus tpu-vm delete tpu-vm-1 --zone us-central1-a --quiet
try:
zone = _checked_request("http://metadata.google.internal/computeMetadata/v1/instance/zone")
zone = zone.split("/")[-1]
name = _checked_request("http://metadata.google.internal/computeMetadata/v1/attributes/instance-id")
name = _checked_request("http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance-id")
project = _checked_request("http://metadata.google.internal/computeMetadata/v1/project/project-id")
except requests.exceptions.RequestException:
logger.warning("Could not get zone or instance-id from metadata server. Is this a TPU VM? Not shutting down.")
return

# the gcloud command we would run is something like:
# gcloud compute tpus tpu-vm delete tpu-vm-1 --zone us-central1-a --quiet
logger.critical(f"Shutting down TPU VM {name} in zone {zone} in {sleep_seconds} seconds")
logger.critical(f"Create a file {SENTINEL_FILE} to cancel the shutdown")
logger.critical(f"$ touch {SENTINEL_FILE}")

time.sleep(sleep_seconds)
if os.path.exists(SENTINEL_FILE):
logger.critical(f"Found sentinel file {SENTINEL_FILE}, not shutting down TPU VM")
return
logger.critical(f"Shutting down TPU VM {name} in zone {zone}")

try:
success = _shutdown_tpu_with_queued_resource()
if success:
return
except requests.exceptions.RequestException:
logger.info("This is not a queued resource, deleting the old fashioned way.")

logger.critical(f"Shutting down TPU VM {name} in zone {zone}")
if jax.process_index() != 0:
logger.info(f"Letting process 0 handle the shutdown. We are process {jax.process_index()}")
return

os.system(f"gcloud compute tpus tpu-vm delete {name} --zone {zone} --quiet")
# os.system(f"gcloud compute tpus tpu-vm delete {name} --zone {zone} --quiet")
# https://tpu.googleapis.com/v2/projects/PROJECT/locations/us-central2-b/nodes/NAME
url = f"http://tpu.googleapis.com/v2/projects/{project}/locations/{zone}/nodes/{name}"
_checked_delete(url)


_sync_count = 0
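As the docstring notes, shutdown_tpu_vm is meant to be registered at process exit; creating the sentinel file before the sleep elapses cancels the deletion. A minimal usage sketch (the levanter.utils.cloud_utils import path matches the file above; the 10-minute delay is just an example value):

import atexit

from levanter.utils.cloud_utils import shutdown_tpu_vm

# Register the auto-shutdown at interpreter exit. The forked child sleeps,
# checks for the sentinel file, and then deletes the TPU VM (or its queued
# resource) via the TPU API.
atexit.register(shutdown_tpu_vm, 60 * 10)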
1 change: 1 addition & 0 deletions tests/test_doremi.py
@@ -113,6 +113,7 @@ def init_model():
(),
use_bias=True,
key=model_key,
out_first=True,
)

m1, loss1 = fit_to_dataset(ds1)
