Skip to content

Commit 5ff52d1

Browse files
authored
RSDK-4977 - add ml training apis (#455)
1 parent 0b4c3ad commit 5ff52d1

File tree

7 files changed

+334
-48
lines changed

7 files changed

+334
-48
lines changed

Diff for: src/viam/app/data_client.py

+18-46
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from datetime import datetime
23
from pathlib import Path
34
from typing import Any, List, Mapping, Optional, Tuple
@@ -17,7 +18,6 @@
1718
BinaryMetadata,
1819
BoundingBoxLabelsByFilterRequest,
1920
BoundingBoxLabelsByFilterResponse,
20-
CaptureInterval,
2121
CaptureMetadata,
2222
DataRequest,
2323
DataServiceStub,
@@ -38,7 +38,6 @@
3838
TabularDataByFilterResponse,
3939
TagsByFilterRequest,
4040
TagsByFilterResponse,
41-
TagsFilter,
4241
)
4342
from viam.proto.app.datasync import (
4443
DataCaptureUploadRequest,
@@ -52,7 +51,7 @@
5251
SensorMetadata,
5352
UploadMetadata,
5453
)
55-
from viam.utils import datetime_to_timestamp, struct_to_dict
54+
from viam.utils import create_filter, datetime_to_timestamp, struct_to_dict
5655

5756
LOGGER = logging.getLogger(__name__)
5857

@@ -689,47 +688,20 @@ def create_filter(
689688
tags: Optional[List[str]] = None,
690689
bbox_labels: Optional[List[str]] = None,
691690
) -> Filter:
692-
"""Create a `Filter`.
693-
694-
Args:
695-
component_name (Optional[str]): Optional name of the component that captured the data being filtered (e.g., "left_motor").
696-
component_type (Optional[str]): Optional type of the componenet that captured the data being filtered (e.g., "motor").
697-
method (Optional[str]): Optional name of the method used to capture the data being filtered (e.g., "IsPowered").
698-
robot_name (Optional[str]): Optional name of the robot associated with the data being filtered (e.g., "viam_rover_1").
699-
robot_id (Optional[str]): Optional ID of the robot associated with the data being filtered.
700-
part_name (Optional[str]): Optional name of the system part associated with the data being filtered (e.g., "viam_rover_1-main").
701-
part_id (Optional[str]): Optional ID of the system part associated with the data being filtered.
702-
location_ids (Optional[List[str]]): Optional list of location IDs associated with the data being filtered.
703-
organization_ids (Optional[List[str]]): Optional list of organization IDs associated with the data being filtered.
704-
mime_type (Optional[List[str]]): Optional mime type of data being filtered (e.g., "image/png").
705-
start_time (Optional[datetime.datetime]): Optional start time of an interval to filter data by.
706-
end_time (Optional[datetime.datetime]): Optional end time of an interval to filter data by.
707-
tags (Optional[List[str]]): Optional list of tags attached to the data being filtered (e.g., ["test"]).
708-
bbox_labels (Optional[List[str]]): Optional list of bounding box labels attached to the data being filtered (e.g., ["square",
709-
"circle"]).
710-
711-
Returns:
712-
viam.proto.app.data.Filter: The `Filter` object.
713-
"""
714-
return Filter(
715-
component_name=component_name if component_name else "",
716-
component_type=component_type if component_type else "",
717-
method=method if method else "",
718-
robot_name=robot_name if robot_name else "",
719-
robot_id=robot_id if robot_id else "",
720-
part_name=part_name if part_name else "",
721-
part_id=part_id if part_id else "",
722-
location_ids=location_ids,
723-
organization_ids=organization_ids,
724-
mime_type=mime_type,
725-
interval=(
726-
CaptureInterval(
727-
start=datetime_to_timestamp(start_time),
728-
end=datetime_to_timestamp(end_time),
729-
)
730-
)
731-
if start_time or end_time
732-
else None,
733-
tags_filter=TagsFilter(tags=tags),
734-
bbox_labels=bbox_labels,
691+
warnings.warn("DataClient.create_filter is deprecated. Use AppClient.create_filter instead.", DeprecationWarning, stacklevel=2)
692+
return create_filter(
693+
component_name,
694+
component_type,
695+
method,
696+
robot_name,
697+
robot_id,
698+
part_name,
699+
part_id,
700+
location_ids,
701+
organization_ids,
702+
mime_type,
703+
start_time,
704+
end_time,
705+
tags,
706+
bbox_labels,
735707
)

Diff for: src/viam/app/ml_training_client.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import Mapping, List, Optional
2+
3+
from grpclib.client import Channel
4+
5+
from viam import logging
6+
from viam.proto.app.mltraining import (
7+
CancelTrainingJobRequest,
8+
GetTrainingJobRequest,
9+
GetTrainingJobResponse,
10+
ListTrainingJobsRequest,
11+
ListTrainingJobsResponse,
12+
MLTrainingServiceStub,
13+
ModelType,
14+
TrainingStatus,
15+
TrainingJobMetadata,
16+
)
17+
from viam.proto.app.data import Filter
18+
19+
LOGGER = logging.getLogger(__name__)
20+
21+
22+
class MLTrainingClient:
23+
"""gRPC client for working with ML training jobs.
24+
25+
Constructor is used by `ViamClient` to instantiate relevant service stubs.
26+
Calls to `MLTrainingClient` methods should be made through `ViamClient`.
27+
"""
28+
29+
def __init__(self, channel: Channel, metadata: Mapping[str, str]):
30+
"""Create a `MLTrainingClient` that maintains a connection to app.
31+
32+
Args:
33+
channel (grpclib.client.Channel): Connection to app.
34+
metadata (Mapping[str, str]): Required authorization token to send requests to app.
35+
"""
36+
self._metadata = metadata
37+
self._ml_training_client = MLTrainingServiceStub(channel)
38+
self._channel = channel
39+
40+
async def submit_training_job(
41+
self,
42+
org_id: str,
43+
model_name: str,
44+
model_version: str,
45+
model_type: ModelType,
46+
tags: List[str],
47+
filter: Optional[Filter] = None,
48+
) -> str:
49+
raise NotImplementedError()
50+
51+
async def get_training_job(self, id: str) -> TrainingJobMetadata:
52+
"""Gets training job data.
53+
54+
Args:
55+
id (str): id of the requested training job.
56+
57+
Returns:
58+
viam.proto.app.mltraining.TrainingJobMetadata: training job data.
59+
"""
60+
61+
request = GetTrainingJobRequest(id=id)
62+
response: GetTrainingJobResponse = await self._ml_training_client.GetTrainingJob(request, metadata=self._metadata)
63+
64+
return response.metadata
65+
66+
async def list_training_jobs(
67+
self,
68+
org_id: str,
69+
training_status: Optional[TrainingStatus.ValueType] = None,
70+
) -> List[TrainingJobMetadata]:
71+
"""Returns training job data for all jobs within an org.
72+
73+
Args:
74+
org_id (str): the id of the org to request training job data from.
75+
training_status (Optional[TrainingStatus]): status of training jobs to filter the list by.
76+
If unspecified, all training jobs will be returned.
77+
78+
Returns:
79+
List[viam.proto.app.mltraining.TrainingJobMetadata]: a list of training job data.
80+
"""
81+
82+
training_status = training_status if training_status else TrainingStatus.TRAINING_STATUS_UNSPECIFIED
83+
request = ListTrainingJobsRequest(organization_id=org_id, status=training_status)
84+
response: ListTrainingJobsResponse = await self._ml_training_client.ListTrainingJobs(request, metadata=self._metadata)
85+
86+
return list(response.jobs)
87+
88+
async def cancel_training_job(self, id: str) -> None:
89+
"""Cancels the specified training job.
90+
91+
Args:
92+
id (str): the id of the job to be canceled.
93+
94+
Raises:
95+
GRPCError: if no training job exists with the given id.
96+
"""
97+
98+
request = CancelTrainingJobRequest(id=id)
99+
await self._ml_training_client.CancelTrainingJob(request, metadata=self._metadata)

Diff for: src/viam/app/viam_client.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from viam import logging
77
from viam.app.app_client import AppClient
88
from viam.app.data_client import DataClient
9+
from viam.app.ml_training_client import MLTrainingClient
910
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token
1011

1112
LOGGER = logging.getLogger(__name__)
@@ -68,10 +69,15 @@ def app_client(self) -> AppClient:
6869
"""Insantiate and return an `AppClient` used to make `app` method calls."""
6970
return AppClient(self._channel, self._metadata, self._location_id)
7071

72+
@property
73+
def ml_training_client(self) -> MLTrainingClient:
74+
"""Instantiate and return a `MLTrainingClient` used to make `ml_training` method calls."""
75+
return MLTrainingClient(self._channel, self._metadata)
76+
7177
def close(self):
7278
"""Close opened channels used for the various service stubs initialized."""
7379
if self._closed:
74-
LOGGER.debug("AppClient is already closed.")
80+
LOGGER.debug("ViamClient is already closed.")
7581
return
7682
LOGGER.debug("Closing gRPC channel to app.")
7783
self._channel.close()

Diff for: src/viam/utils.py

+67
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from google.protobuf.timestamp_pb2 import Timestamp
1313

1414
from viam.proto.common import Geometry, GeoPoint, GetGeometriesRequest, GetGeometriesResponse, Orientation, ResourceName, Vector3
15+
from viam.proto.app.data import (
16+
CaptureInterval,
17+
Filter,
18+
TagsFilter,
19+
)
1520
from viam.resource.base import ResourceBase
1621
from viam.resource.registry import Registry
1722
from viam.resource.types import Subtype, SupportsGetGeometries
@@ -263,3 +268,65 @@ def from_dm_from_extra(extra: Optional[Dict[str, Any]]) -> bool:
263268
return False
264269

265270
return bool(extra.get("fromDataManagement", False))
271+
272+
273+
def create_filter(
274+
component_name: Optional[str] = None,
275+
component_type: Optional[str] = None,
276+
method: Optional[str] = None,
277+
robot_name: Optional[str] = None,
278+
robot_id: Optional[str] = None,
279+
part_name: Optional[str] = None,
280+
part_id: Optional[str] = None,
281+
location_ids: Optional[List[str]] = None,
282+
organization_ids: Optional[List[str]] = None,
283+
mime_type: Optional[List[str]] = None,
284+
start_time: Optional[datetime] = None,
285+
end_time: Optional[datetime] = None,
286+
tags: Optional[List[str]] = None,
287+
bbox_labels: Optional[List[str]] = None,
288+
) -> Filter:
289+
"""Create a `Filter`.
290+
291+
Args:
292+
component_name (Optional[str]): Optional name of the component that captured the data being filtered (e.g., "left_motor").
293+
component_type (Optional[str]): Optional type of the componenet that captured the data being filtered (e.g., "motor").
294+
method (Optional[str]): Optional name of the method used to capture the data being filtered (e.g., "IsPowered").
295+
robot_name (Optional[str]): Optional name of the robot associated with the data being filtered (e.g., "viam_rover_1").
296+
robot_id (Optional[str]): Optional ID of the robot associated with the data being filtered.
297+
part_name (Optional[str]): Optional name of the system part associated with the data being filtered (e.g., "viam_rover_1-main").
298+
part_id (Optional[str]): Optional ID of the system part associated with the data being filtered.
299+
location_ids (Optional[List[str]]): Optional list of location IDs associated with the data being filtered.
300+
organization_ids (Optional[List[str]]): Optional list of organization IDs associated with the data being filtered.
301+
mime_type (Optional[List[str]]): Optional mime type of data being filtered (e.g., "image/png").
302+
start_time (Optional[datetime.datetime]): Optional start time of an interval to filter data by.
303+
end_time (Optional[datetime.datetime]): Optional end time of an interval to filter data by.
304+
tags (Optional[List[str]]): Optional list of tags attached to the data being filtered (e.g., ["test"]).
305+
bbox_labels (Optional[List[str]]): Optional list of bounding box labels attached to the data being filtered (e.g., ["square",
306+
"circle"]).
307+
308+
Returns:
309+
viam.proto.app.data.Filter: The `Filter` object.
310+
"""
311+
return Filter(
312+
component_name=component_name if component_name else "",
313+
component_type=component_type if component_type else "",
314+
method=method if method else "",
315+
robot_name=robot_name if robot_name else "",
316+
robot_id=robot_id if robot_id else "",
317+
part_name=part_name if part_name else "",
318+
part_id=part_id if part_id else "",
319+
location_ids=location_ids,
320+
organization_ids=organization_ids,
321+
mime_type=mime_type,
322+
interval=(
323+
CaptureInterval(
324+
start=datetime_to_timestamp(start_time),
325+
end=datetime_to_timestamp(end_time),
326+
)
327+
)
328+
if start_time or end_time
329+
else None,
330+
tags_filter=TagsFilter(tags=tags),
331+
bbox_labels=bbox_labels,
332+
)

Diff for: tests/mocks/services.py

+48
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,18 @@
199199
StreamingDataCaptureUploadRequest,
200200
StreamingDataCaptureUploadResponse,
201201
)
202+
from viam.proto.app.mltraining import (
203+
CancelTrainingJobRequest,
204+
CancelTrainingJobResponse,
205+
GetTrainingJobRequest,
206+
GetTrainingJobResponse,
207+
ListTrainingJobsRequest,
208+
ListTrainingJobsResponse,
209+
MLTrainingServiceBase,
210+
SubmitTrainingJobRequest,
211+
SubmitTrainingJobResponse,
212+
TrainingJobMetadata,
213+
)
202214
from viam.proto.common import DoCommandRequest, DoCommandResponse, GeoObstacle, GeoPoint, PointCloudObject, Pose, PoseInFrame, ResourceName
203215
from viam.proto.service.mlmodel import (
204216
FlatTensor,
@@ -765,6 +777,42 @@ async def StreamingDataCaptureUpload(
765777
raise NotImplementedError()
766778

767779

780+
class MockMLTraining(MLTrainingServiceBase):
781+
def __init__(self, job_id: str, training_metadata: TrainingJobMetadata):
782+
self.job_id = job_id
783+
self.training_metadata = training_metadata
784+
785+
async def SubmitTrainingJob(self, stream: Stream[SubmitTrainingJobRequest, SubmitTrainingJobResponse]) -> None:
786+
request = await stream.recv_message()
787+
assert request is not None
788+
self.filter = request.filter
789+
self.org_id = request.organization_id
790+
self.model_name = request.model_name
791+
self.model_version = request.model_version
792+
self.model_type = request.model_type
793+
self.tags = request.tags
794+
await stream.send_message(SubmitTrainingJobResponse(id=self.job_id))
795+
796+
async def GetTrainingJob(self, stream: Stream[GetTrainingJobRequest, GetTrainingJobResponse]) -> None:
797+
request = await stream.recv_message()
798+
assert request is not None
799+
self.training_job_id = request.id
800+
await stream.send_message(GetTrainingJobResponse(metadata=self.training_metadata))
801+
802+
async def ListTrainingJobs(self, stream: Stream[ListTrainingJobsRequest, ListTrainingJobsResponse]) -> None:
803+
request = await stream.recv_message()
804+
assert request is not None
805+
self.training_status = request.status
806+
self.org_id = request.organization_id
807+
await stream.send_message(ListTrainingJobsResponse(jobs=[self.training_metadata]))
808+
809+
async def CancelTrainingJob(self, stream: Stream[CancelTrainingJobRequest, CancelTrainingJobResponse]) -> None:
810+
request = await stream.recv_message()
811+
assert request is not None
812+
self.cancel_job_id = request.id
813+
await stream.send_message(CancelTrainingJobResponse())
814+
815+
768816
class MockApp(AppServiceBase):
769817
def __init__(
770818
self,

Diff for: tests/test_data_client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from .mocks.services import MockData
18+
from viam.utils import create_filter
1819

1920
INCLUDE_BINARY = True
2021
COMPONENT_NAME = "component_name"
@@ -41,7 +42,7 @@
4142
END_DATETIME = END_TS.ToDatetime()
4243
TAGS = ["tag"]
4344
BBOX_LABELS = ["bbox_label"]
44-
FILTER = DataClient.create_filter(
45+
FILTER = create_filter(
4546
component_name=COMPONENT_NAME,
4647
component_type=COMPONENT_TYPE,
4748
method=METHOD,

0 commit comments

Comments
 (0)