diff --git a/swarms_cloud/sky_api.py b/swarms_cloud/sky_api.py index 4d041c1..6fd1f77 100644 --- a/swarms_cloud/sky_api.py +++ b/swarms_cloud/sky_api.py @@ -1,3 +1,5 @@ +from typing import List + import sky from sky import Task @@ -36,7 +38,19 @@ class SkyInterface: """ - def __init__(self): + def __init__( + self, + task_name: str = None, + cluster_name: str = None, + gpus: str = "T4:1", + stream_logs_enabled: bool = False, + *args, + **kwargs, + ): + self.task_name = task_name + self.cluster_name = cluster_name + self.gpus = gpus + self.stream_logs_enabled = stream_logs_enabled self.clusters = {} def launch(self, task: Task = None, cluster_name: str = None, **kwargs): @@ -49,13 +63,23 @@ def launch(self, task: Task = None, cluster_name: str = None, **kwargs): Returns: _type_: _description_ """ + cluster = None try: - job_id, handle = sky.launch(task, cluster_name=cluster_name, **kwargs) - if handle: - self.clusters[cluster_name] = handle - return job_id + cluster = sky.launch( + task=task, + cluster_name=cluster_name, + stream_logs=self.stream_logs_enabled, + **kwargs, + ) + print(f"Launched job {cluster} on cluster {cluster_name}") + return cluster except Exception as error: - print("Error launching cluster:", error) + # Deep error logging + print( + f"Error launching job {cluster} on cluster {cluster_name} with" + f" error {error}" + ) + raise error def execute(self, task: Task = None, cluster_name: str = None, **kwargs): """Execute a task on a cluster @@ -73,7 +97,12 @@ def execute(self, task: Task = None, cluster_name: str = None, **kwargs): if cluster_name not in self.clusters: raise ValueError("Cluster {} does not exist".format(cluster_name)) try: - return sky.exec(task, cluster_name, **kwargs) + return sky.exec( + task=task, + cluster_name=cluster_name, + stream_logs=self.stream_logs_enabled, + **kwargs, + ) except Exception as e: print("Error executing on cluster:", e) @@ -112,14 +141,14 @@ def down(self, cluster_name: str = None, **kwargs): except (ValueError, RuntimeError) as e: print("Error tearing down cluster:", e) - def status(self, **kwargs): + def status(self, cluster_names: List[str] = None, **kwargs): """Save a cluster Returns: r: the status of the cluster """ try: - return sky.status(**kwargs) + return sky.status(cluster_names, **kwargs) except Exception as e: print("Error getting status:", e) @@ -142,7 +171,7 @@ def create_task( workdir: str = None, task: str = None, *args, - **kwargs + **kwargs, ): """_summary_ @@ -155,7 +184,7 @@ def create_task( Returns: _type_: _description_ - + # A Task that will sync up local workdir '.', containing # requirements.txt and train.py. sky.Task(setup='pip install requirements.txt', @@ -168,4 +197,6 @@ def create_task( # Chaining setters. sky.Task().set_resources(...).set_file_mounts(...) """ - return Task(name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs) + return Task( + name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs + )