Skip to content

Commit

Permalink
Added /api/cvat-tasks/ list endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
artemmustafa committed May 7, 2020
1 parent a79a053 commit 28e1d0b
Show file tree
Hide file tree
Showing 10 changed files with 453 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ __pycache__/

# OSX
.DS_Store

.coverage
3 changes: 2 additions & 1 deletion backend/model_garden/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from .cvat_user import CvatUserSerializer
from .dataset import DatasetSerializer
from .media_asset import MediaAssetSerializer
from .task import TaskSerializer
from .task import TaskSerializer, CvatTaskSerializer

__all__ = (
"BucketSerializer",
"CvatTaskSerializer"
"CvatUserSerializer",
"DatasetSerializer",
"MediaAssetSerializer",
Expand Down
23 changes: 23 additions & 0 deletions backend/model_garden/serializers/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,26 @@ class TaskSerializer(serializers.Serializer):
assignee_id = serializers.IntegerField()
files_in_task = serializers.IntegerField()
count_of_tasks = serializers.IntegerField()


class CvatTaskSerializer(serializers.Serializer):
id = serializers.IntegerField()
url = serializers.URLField()
name = serializers.CharField()
mode = serializers.CharField()
size = serializers.IntegerField(required=False)
owner = serializers.IntegerField(required=False)
assignee = serializers.IntegerField(allow_null=True)
created_date = serializers.DateTimeField(required=False)
updated_date = serializers.DateTimeField(required=False)
overlap = serializers.IntegerField(required=False)
segment_size = serializers.IntegerField(required=False)
z_order = serializers.BooleanField(required=False)
status = serializers.CharField()
labels = serializers.JSONField(required=False)
segments = serializers.JSONField(required=False)
image_quality = serializers.IntegerField(required=False)
start_frame = serializers.IntegerField(required=False)
stop_frame = serializers.IntegerField(required=False)
frame_filter = serializers.CharField(required=False)
project = serializers.IntegerField(required=False, allow_null=True)
51 changes: 50 additions & 1 deletion backend/model_garden/services/cvat.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
from typing import Optional, List
import dataclasses
import logging
from typing import Optional, List, Dict, NamedTuple
from urllib.parse import urlencode, urlunsplit

import requests
from django.conf import settings
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

logger = logging.getLogger(__name__)


class CVATServiceException(Exception):
pass


@dataclasses.dataclass
class ListRequest:
page: int = 1
page_size: int = 100
ordering: str = ''
filters: Dict[str, str] = dataclasses.field(default_factory=dict)


class ListResponse(NamedTuple):
count: int
next_url: Optional[str]
prev_url: Optional[str]
results: List[dict]


class CvatService:
API_VERSION = 'v1'

Expand All @@ -31,6 +51,9 @@ def _get_url(self, path: str) -> str:

def _request(self, method: str, path: str, data: dict = None) -> requests.Response:
url = self._get_url(path)

logger.info('Cvat request: %s, data=%s', url, data)

response = getattr(self._session, method)(url=url, json=data)
try:
response.raise_for_status()
Expand Down Expand Up @@ -99,3 +122,29 @@ def create_task(
"project": None
})
return response.json()

def tasks(self, req: ListRequest) -> ListResponse:
"""Fetche tasks from the service that are created by
`settings.CVAT_ROOT_USER_NAME`.
"""
req.filters['owner'] = settings.CVAT_ROOT_USER_NAME

resp = self._get(_join_query('tasks', req))
data = resp.json()

return ListResponse(
count=data.get('count', 0),
next_url=data.get('next'),
prev_url=data.get('previous'),
results=data.get('results', []),
)


def _join_query(path: str, req: ListRequest) -> str:
"""Add query arguments from `req` to the `path`.
"""
req_query = dict(page=req.page, page_size=req.page_size, **req.filters)
if req.ordering:
req_query.update(ordering=req.ordering)

return urlunsplit(('', '', path, urlencode(req_query), ''))
14 changes: 14 additions & 0 deletions backend/model_garden/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@
},
]

LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'handlers': {
'console': {
'class': 'logging.StreamHandler',
},
},
'root': {
'handlers': ['console'],
'level': 'WARNING' if not DEBUG else 'INFO',
},
}

# Internationalization
# https://docs.djangoproject.com/en/3.0/topics/i18n/

Expand Down
2 changes: 2 additions & 0 deletions backend/model_garden/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from model_garden.views import (
BucketViewSet,
CvatUserViewSet,
CvatTaskViewSet,
DatasetViewSet,
MediaAssetViewSet,
TaskViewSet,
Expand All @@ -29,6 +30,7 @@
router = routers.DefaultRouter()
router.register(r'buckets', BucketViewSet)
router.register(r'cvat-users', CvatUserViewSet, basename='cvatusers')
router.register(r'cvat-tasks', CvatTaskViewSet, basename='cvattasks')
router.register(r'datasets', DatasetViewSet)
router.register(r'media-assets', MediaAssetViewSet)
router.register(r'tasks', TaskViewSet, basename='tasks')
Expand Down
2 changes: 2 additions & 0 deletions backend/model_garden/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .bucket import BucketViewSet
from .cvat_tasks import CvatTaskViewSet
from .cvat_user import CvatUserViewSet
from .dataset import DatasetViewSet
from .media_asset import MediaAssetViewSet
from .task import TaskViewSet

__all__ = (
"BucketViewSet",
"CvatTaskViewSet"
"CvatUserViewSet",
"DatasetViewSet",
"MediaAssetViewSet",
Expand Down
120 changes: 120 additions & 0 deletions backend/model_garden/views/cvat_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Iterable, Iterator

from django import forms
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from rest_framework.pagination import PageNumberPagination
from rest_framework.exceptions import ValidationError

from model_garden.serializers import CvatTaskSerializer
from model_garden.services.cvat import (
CvatService, ListRequest
)


class CvatTaskViewSet(ViewSet):
serializer_class = CvatTaskSerializer

def list(self, request: Request):
queryset = CvatTasksQuerySet(CvatService())
# iterating gets single value as a string instead of a list
queryset.filter(**{k: v for k, v in request.query_params.items()})
queryset.order_by(request.query_params.get('ordering', 'id'))

paginator = CvatTaskPagination()
page = paginator.paginate_queryset(queryset, request)
if not page:
return Response(
data={'message': 'Tasks was not found'},
status=status.HTTP_404_NOT_FOUND,
)

serializer = CvatTaskSerializer(page, many=True)
return paginator.get_paginated_response(serializer.data)


class CvatTaskPagination(PageNumberPagination):
page_size = 100
page_size_query_param = 'page_size'
max_page_size = 1000


class CvatTaskFilter(forms.Form):
id = forms.IntegerField(required=False)
project = forms.CharField(required=False, max_length=256)
name = forms.CharField(required=False, max_length=256)
mode = forms.ChoiceField(
choices=(
('annotation', 'annotation'),
('interpolation', 'interpolation')
),
required=False
)
status = forms.ChoiceField(
choices=(
('annotation', 'annotation'),
('validation', 'validation'),
('completed', 'completed'),
),
required=False
)
assignee = forms.CharField(required=False, max_length=256)


class CvatTasksQuerySet:
ORDERING_FIELDS = ("id", "name", "status", "assignee",)

def __init__(self, cvat_service: CvatService):
self.service = cvat_service
self.service_request = ListRequest()
self._tasks = []

def filter(self, **kwargs) -> 'CvatTasksQuerySet':
filter_form = CvatTaskFilter(kwargs)
if not filter_form.is_valid():
raise ValidationError(
filter_form.errors, code=status.HTTP_400_BAD_REQUEST,
)

self.service_request.filters.update(
{k: v for k, v in filter_form.cleaned_data.items() if v}
)
return self

def order_by(self, field_name: str) -> 'CvatTasksQuerySet':
self.service_request.ordering = field_name
return self

def count(self) -> int:
self.service_request.page = 1
self.service_request.page_size = 1
tasks = self.service.tasks(self.service_request)
return tasks.count

def __len__(self) -> int:
return self.count()

def __getitem__(self, key) -> Iterable[dict]:
if not isinstance(key, (int, slice)):
raise TypeError

start = key.start if isinstance(key, slice) else key
stop = key.stop if isinstance(key, slice) else key + 1
page_size = stop - start

if page_size == 0:
self._tasks = []
return self._tasks

self.service_request.page = start // page_size + 1
self.service_request.page_size = page_size
self._tasks = self.service.tasks(self.service_request).results
return self._tasks

def __iter__(self) -> Iterator[dict]:
"""Django instatiate iterator only after slicing,
so iterator doesn't have to make request
"""
return iter(self._tasks)
Loading

0 comments on commit 28e1d0b

Please sign in to comment.