Skip to content

Commit

Permalink
[FEAT][SkyInterface.create_task] [CODE QUALITY]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 19, 2023
1 parent 298a536 commit 0e53aa3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
49 changes: 43 additions & 6 deletions swarms_cloud/sky_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sky
from sky import Task


class SkyInterface:
Expand Down Expand Up @@ -38,7 +39,7 @@ class SkyInterface:
def __init__(self):
self.clusters = {}

def launch(self, task, cluster_name=None, **kwargs):
def launch(self, task: Task = None, cluster_name: str = None, **kwargs):
"""Launch a task on a cluster
Args:
Expand All @@ -56,7 +57,7 @@ def launch(self, task, cluster_name=None, **kwargs):
except Exception as error:
print("Error launching cluster:", error)

def execute(self, task, cluster_name, **kwargs):
def execute(self, task: Task = None, cluster_name: str = None, **kwargs):
"""Execute a task on a cluster
Args:
Expand All @@ -76,7 +77,7 @@ def execute(self, task, cluster_name, **kwargs):
except Exception as e:
print("Error executing on cluster:", e)

def stop(self, cluster_name, **kwargs):
def stop(self, cluster_name: str = None, **kwargs):
"""Stop a cluster
Args:
Expand All @@ -87,7 +88,7 @@ def stop(self, cluster_name, **kwargs):
except (ValueError, RuntimeError) as e:
print("Error stopping cluster:", e)

def start(self, cluster_name, **kwargs):
def start(self, cluster_name: str = None, **kwargs):
"""start a cluster
Args:
Expand All @@ -98,7 +99,7 @@ def start(self, cluster_name, **kwargs):
except Exception as e:
print("Error starting cluster:", e)

def down(self, cluster_name, **kwargs):
def down(self, cluster_name: str = None, **kwargs):
"""Down a cluster
Args:
Expand All @@ -122,7 +123,7 @@ def status(self, **kwargs):
except Exception as e:
print("Error getting status:", e)

def autostop(self, cluster_name, **kwargs):
def autostop(self, cluster_name: str = None, **kwargs):
"""Autostop a cluster
Args:
Expand All @@ -132,3 +133,39 @@ def autostop(self, cluster_name, **kwargs):
sky.autostop(cluster_name, **kwargs)
except Exception as e:
print("Error setting autostop:", e)

def create_task(
self,
name: str = None,
setup: str = None,
run: str = None,
workdir: str = None,
task: str = None,
*args,
**kwargs
):
"""_summary_
Args:
name (str, optional): _description_. Defaults to None.
setup (str, optional): _description_. Defaults to None.
run (str, optional): _description_. Defaults to None.
workdir (str, optional): _description_. Defaults to None.
task (str, optional): _description_. Defaults to None.
Returns:
_type_: _description_
# A Task that will sync up local workdir '.', containing
# requirements.txt and train.py.
sky.Task(setup='pip install requirements.txt',
run='python train.py',
workdir='.')
# An empty Task for provisioning a cluster.
task = sky.Task(num_nodes=n).set_resources(...)
# Chaining setters.
sky.Task().set_resources(...).set_file_mounts(...)
"""
return Task(name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs)
2 changes: 0 additions & 2 deletions tests/test_api_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from unittest.mock import MagicMock

import pytest
from fastapi import FastAPI, HTTPException
from swarms.structs.agent import Agent

from swarms_cloud.api_key_generator import generate_api_key
Expand Down

0 comments on commit 0e53aa3

Please sign in to comment.