From 0e53aa3a58260c14be3fda9cfbdea0096a37f8d0 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 19 Dec 2023 15:01:20 -0500 Subject: [PATCH] [FEAT][SkyInterface.create_task] [CODE QUALITY] --- swarms_cloud/sky_api.py | 49 ++++++++++++++++++++++++++++++++----- tests/test_api_generator.py | 2 -- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/swarms_cloud/sky_api.py b/swarms_cloud/sky_api.py index 9abd303..4d041c1 100644 --- a/swarms_cloud/sky_api.py +++ b/swarms_cloud/sky_api.py @@ -1,4 +1,5 @@ import sky +from sky import Task class SkyInterface: @@ -38,7 +39,7 @@ class SkyInterface: def __init__(self): self.clusters = {} - def launch(self, task, cluster_name=None, **kwargs): + def launch(self, task: Task = None, cluster_name: str = None, **kwargs): """Launch a task on a cluster Args: @@ -56,7 +57,7 @@ def launch(self, task, cluster_name=None, **kwargs): except Exception as error: print("Error launching cluster:", error) - def execute(self, task, cluster_name, **kwargs): + def execute(self, task: Task = None, cluster_name: str = None, **kwargs): """Execute a task on a cluster Args: @@ -76,7 +77,7 @@ def execute(self, task, cluster_name, **kwargs): except Exception as e: print("Error executing on cluster:", e) - def stop(self, cluster_name, **kwargs): + def stop(self, cluster_name: str = None, **kwargs): """Stop a cluster Args: @@ -87,7 +88,7 @@ def stop(self, cluster_name, **kwargs): except (ValueError, RuntimeError) as e: print("Error stopping cluster:", e) - def start(self, cluster_name, **kwargs): + def start(self, cluster_name: str = None, **kwargs): """start a cluster Args: @@ -98,7 +99,7 @@ def start(self, cluster_name, **kwargs): except Exception as e: print("Error starting cluster:", e) - def down(self, cluster_name, **kwargs): + def down(self, cluster_name: str = None, **kwargs): """Down a cluster Args: @@ -122,7 +123,7 @@ def status(self, **kwargs): except Exception as e: print("Error getting status:", e) - def autostop(self, cluster_name, **kwargs): + def autostop(self, cluster_name: str = None, **kwargs): """Autostop a cluster Args: @@ -132,3 +133,39 @@ def autostop(self, cluster_name, **kwargs): sky.autostop(cluster_name, **kwargs) except Exception as e: print("Error setting autostop:", e) + + def create_task( + self, + name: str = None, + setup: str = None, + run: str = None, + workdir: str = None, + task: str = None, + *args, + **kwargs + ): + """_summary_ + + Args: + name (str, optional): _description_. Defaults to None. + setup (str, optional): _description_. Defaults to None. + run (str, optional): _description_. Defaults to None. + workdir (str, optional): _description_. Defaults to None. + task (str, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + + # A Task that will sync up local workdir '.', containing + # requirements.txt and train.py. + sky.Task(setup='pip install requirements.txt', + run='python train.py', + workdir='.') + + # An empty Task for provisioning a cluster. + task = sky.Task(num_nodes=n).set_resources(...) + + # Chaining setters. + sky.Task().set_resources(...).set_file_mounts(...) + """ + return Task(name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs) diff --git a/tests/test_api_generator.py b/tests/test_api_generator.py index 00cb94d..7ef76e6 100644 --- a/tests/test_api_generator.py +++ b/tests/test_api_generator.py @@ -1,7 +1,5 @@ -from unittest.mock import MagicMock import pytest -from fastapi import FastAPI, HTTPException from swarms.structs.agent import Agent from swarms_cloud.api_key_generator import generate_api_key