Skip to content

Commit

Permalink
[SKY]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 22, 2023
1 parent 18ab733 commit 88a21d8
Showing 1 changed file with 43 additions and 12 deletions.
55 changes: 43 additions & 12 deletions swarms_cloud/sky_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import sky
from sky import Task

Expand Down Expand Up @@ -36,7 +38,19 @@ class SkyInterface:
"""

def __init__(self):
def __init__(
self,
task_name: str = None,
cluster_name: str = None,
gpus: str = "T4:1",
stream_logs_enabled: bool = False,
*args,
**kwargs,
):
self.task_name = task_name
self.cluster_name = cluster_name
self.gpus = gpus
self.stream_logs_enabled = stream_logs_enabled
self.clusters = {}

def launch(self, task: Task = None, cluster_name: str = None, **kwargs):
Expand All @@ -49,13 +63,23 @@ def launch(self, task: Task = None, cluster_name: str = None, **kwargs):
Returns:
_type_: _description_
"""
cluster = None
try:
job_id, handle = sky.launch(task, cluster_name=cluster_name, **kwargs)
if handle:
self.clusters[cluster_name] = handle
return job_id
cluster = sky.launch(
task=task,
cluster_name=cluster_name,
stream_logs=self.stream_logs_enabled,
**kwargs,
)
print(f"Launched job {cluster} on cluster {cluster_name}")
return cluster
except Exception as error:
print("Error launching cluster:", error)
# Deep error logging
print(
f"Error launching job {cluster} on cluster {cluster_name} with"
f" error {error}"
)
raise error

def execute(self, task: Task = None, cluster_name: str = None, **kwargs):
"""Execute a task on a cluster
Expand All @@ -73,7 +97,12 @@ def execute(self, task: Task = None, cluster_name: str = None, **kwargs):
if cluster_name not in self.clusters:
raise ValueError("Cluster {} does not exist".format(cluster_name))
try:
return sky.exec(task, cluster_name, **kwargs)
return sky.exec(
task=task,
cluster_name=cluster_name,
stream_logs=self.stream_logs_enabled,
**kwargs,
)
except Exception as e:
print("Error executing on cluster:", e)

Expand Down Expand Up @@ -112,14 +141,14 @@ def down(self, cluster_name: str = None, **kwargs):
except (ValueError, RuntimeError) as e:
print("Error tearing down cluster:", e)

def status(self, **kwargs):
def status(self, cluster_names: List[str] = None, **kwargs):
"""Save a cluster
Returns:
r: the status of the cluster
"""
try:
return sky.status(**kwargs)
return sky.status(cluster_names, **kwargs)
except Exception as e:
print("Error getting status:", e)

Expand All @@ -142,7 +171,7 @@ def create_task(
workdir: str = None,
task: str = None,
*args,
**kwargs
**kwargs,
):
"""_summary_
Expand All @@ -155,7 +184,7 @@ def create_task(
Returns:
_type_: _description_
# A Task that will sync up local workdir '.', containing
# requirements.txt and train.py.
sky.Task(setup='pip install requirements.txt',
Expand All @@ -168,4 +197,6 @@ def create_task(
# Chaining setters.
sky.Task().set_resources(...).set_file_mounts(...)
"""
return Task(name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs)
return Task(
name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs
)

0 comments on commit 88a21d8

Please sign in to comment.