Skip to content

Commit

Permalink
feat: Assign strike runs to groups (ENG-981) (#35)
Browse files Browse the repository at this point in the history
* Add --group arg to deploy. Add run-groups subcommand. Structure some formatting. Small bug fix for  regarding links.

* Fix tests
  • Loading branch information
monoxgas authored Feb 11, 2025
1 parent 6753700 commit 5b58112
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 9 deletions.
23 changes: 21 additions & 2 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
format_agent,
format_agent_versions,
format_run,
format_run_groups,
format_runs,
format_strike_models,
format_strikes,
Expand Down Expand Up @@ -68,7 +69,9 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
):
print()
raise Exception(
"Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
)

switch_profile(agent_config.active_link.profile)
Expand Down Expand Up @@ -248,7 +251,14 @@ def push(

if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name):
print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]")
print()
new = True
elif agent_config.active and agent_config.active_link.profile != user_config.active_profile_name:
raise Exception(
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
)

server_config = user_config.get_server_config()

Expand Down Expand Up @@ -372,6 +382,7 @@ def deploy(
] = None,
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None,
watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True,
group: t.Annotated[str | None, typer.Option("--group", "-g", help="Group to associate this run with")] = None,
) -> None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
Expand Down Expand Up @@ -421,7 +432,7 @@ def deploy(
)

run = client.start_strike_run(
agent.latest_version.id, strike=strike, model=model, user_model=user_model, context=context
agent.latest_version.id, strike=strike, model=model, user_model=user_model, group=group, context=context
)
agent_config.add_run(run.id).write(directory)
formatted = format_run(run, server_url=server_config.url)
Expand Down Expand Up @@ -615,6 +626,14 @@ def switch(
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")


@cli.command(help="List strike run groups")
@pretty_cli
def run_groups() -> None:
client = api.create_client()
groups = client.list_strike_run_groups()
print(format_run_groups(groups))


@cli.command(help="Clone a github repository", no_args_is_help=True)
@pretty_cli
def clone(
Expand Down
36 changes: 30 additions & 6 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,17 @@ def format_run(
table.add_column("Property", style="dim")
table.add_column("Value")

table.add_row("key", run.key)
table.add_row("status", Text(run.status, style=get_status_style(run.status)))
table.add_row("strike", f"[magenta]{run.strike_name}[/] ([dim]{run.strike_key}[/])")
table.add_row("type", run.strike_type)
table.add_row("group", Text(run.group_key or "-", style="blue" if run.group_key else ""))

if server_url != "":
table.add_row("", "")
table.add_row(
"url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
)

if run.agent_name:
agent_name = f"[bold magenta]{run.agent_name}[/] [[dim]{run.agent_key}[/]]"
Expand All @@ -280,10 +288,6 @@ def format_run(
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
if server_url != "":
table.add_row(
"run url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
)
table.add_row("notes", run.agent_version.notes or "-")

table.add_row("", "")
Expand Down Expand Up @@ -314,21 +318,41 @@ def format_run(

def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("id", style="dim")
table.add_column("key", style="dim")
table.add_column("agent")
table.add_column("status")
table.add_column("model")
table.add_column("group")
table.add_column("started")
table.add_column("duration")

for run in runs:
table.add_row(
str(run.id),
run.key,
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
Text(run.status, style="bold " + get_status_style(run.status)),
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
Text(run.group_key or "-", style="blue" if run.group_key else "dim"),
format_time(run.start),
Text(format_duration(run.start, run.end), style="bold cyan"),
)

return table


def format_run_groups(groups: list[api.Client.StrikeRunGroupResponse]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("Name", style="bold cyan")
table.add_column("description")
table.add_column("runs", style="yellow")
table.add_column("created", style="dim")

for group in groups:
table.add_row(
group.key,
group.description or "-",
str(group.run_count),
group.created_at.astimezone().strftime("%c"),
)

return table
2 changes: 1 addition & 1 deletion dreadnode_cli/agent/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_ensure_profile() -> None:
agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main")
agent_config.active = "test-other"
with patch("rich.prompt.Prompt.ask", return_value="n"):
with pytest.raises(Exception, match="Agent link does not match the current server profile"):
with pytest.raises(Exception, match="Current agent link"):
ensure_profile(agent_config, user_config=user_config)

# We should switch if the user agrees
Expand Down
30 changes: 30 additions & 0 deletions dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ class Container(BaseModel):
env: dict[str, str]
name: str | None

class StrikeMetricPoint(BaseModel):
timestamp: datetime
value: float
metadata: dict[str, t.Any]

class StrikeMetric(BaseModel):
type: str
description: str | None
points: "list[Client.StrikeMetricPoint]"

class StrikeAgentVersion(BaseModel):
id: UUID
created_at: datetime
Expand Down Expand Up @@ -350,6 +360,7 @@ class StrikeRunZone(_StrikeRunZone):
container_logs: dict[str, str]
outputs: list["Client.StrikeRunOutput"]
inferences: list[dict[str, t.Any]]
metrics: dict[str, "Client.StrikeMetric"]

class StrikeRunContext(BaseModel):
environment: dict[str, str] | None = None
Expand All @@ -358,6 +369,7 @@ class StrikeRunContext(BaseModel):

class _StrikeRun(BaseModel):
id: UUID
key: str
strike_id: UUID
strike_key: str
strike_name: str
Expand All @@ -373,6 +385,9 @@ class _StrikeRun(BaseModel):
status: "Client.StrikeRunStatus"
start: datetime | None
end: datetime | None
group_id: UUID | None
group_key: str | None
group_name: str | None

def is_running(self) -> bool:
return self.status in ["pending", "deploying", "running"]
Expand All @@ -388,6 +403,15 @@ class UserModel(BaseModel):
generator_id: str
api_key: str

class StrikeRunGroupResponse(BaseModel):
id: UUID
key: str
name: str
description: str | None
created_at: datetime
updated_at: datetime
run_count: int

def get_strike(self, strike: str) -> StrikeResponse:
response = self.request("GET", f"/api/strikes/{strike}")
return self.StrikeResponse(**response.json())
Expand Down Expand Up @@ -448,6 +472,7 @@ def start_strike_run(
user_model: UserModel | None = None,
context: StrikeRunContext | None = None,
strike: UUID | str | None = None,
group: UUID | str | None = None,
) -> StrikeRunResponse:
response = self.request(
"POST",
Expand All @@ -457,6 +482,7 @@ def start_strike_run(
"model": model,
"user_model": user_model.model_dump(mode="json") if user_model else None,
"strike": str(strike) if strike else None,
"group": str(group) if group else None,
"context": context.model_dump(mode="json") if context else None,
},
)
Expand All @@ -472,6 +498,10 @@ def list_strike_runs(self, *, strike_id: UUID | str | None = None) -> list[Strik
)
return [self.StrikeRunSummaryResponse(**run) for run in response.json()]

def list_strike_run_groups(self) -> list[StrikeRunGroupResponse]:
response = self.request("GET", "/api/strikes/groups")
return [self.StrikeRunGroupResponse(**group) for group in response.json()]


def create_client(*, profile: str | None = None) -> Client:
"""Create an authenticated API client using stored configuration data."""
Expand Down

0 comments on commit 5b58112

Please sign in to comment.