Skip to content

Commit

Permalink
Merge branch 'task-9-get-tasks-api' into 'develop'
Browse files Browse the repository at this point in the history
Added /api/cvat-tasks/ list endpoint

See merge request epmc-mlcv/model_garden!16
  • Loading branch information
skudriashev committed May 7, 2020
2 parents a79a053 + 06c4d31 commit c7f7430
Show file tree
Hide file tree
Showing 14 changed files with 508 additions and 69 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ __pycache__/

# OSX
.DS_Store

.coverage
5 changes: 3 additions & 2 deletions backend/model_garden/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from .cvat_user import CvatUserSerializer
from .dataset import DatasetSerializer
from .media_asset import MediaAssetSerializer
from .task import TaskSerializer
from .task import CvatTaskCreateSerializer, CvatTaskSerializer

__all__ = (
"BucketSerializer",
"CvatTaskCreateSerializer",
"CvatTaskSerializer",
"CvatUserSerializer",
"DatasetSerializer",
"MediaAssetSerializer",
"TaskSerializer",
)
25 changes: 24 additions & 1 deletion backend/model_garden/serializers/task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
from rest_framework import serializers


class TaskSerializer(serializers.Serializer):
class CvatTaskCreateSerializer(serializers.Serializer):
task_name = serializers.CharField()
dataset_id = serializers.CharField()
assignee_id = serializers.IntegerField()
files_in_task = serializers.IntegerField()
count_of_tasks = serializers.IntegerField()


class CvatTaskSerializer(serializers.Serializer):
id = serializers.IntegerField()
url = serializers.URLField()
name = serializers.CharField()
mode = serializers.CharField()
size = serializers.IntegerField(required=False)
owner = serializers.IntegerField(required=False)
assignee = serializers.IntegerField(allow_null=True)
created_date = serializers.DateTimeField(required=False)
updated_date = serializers.DateTimeField(required=False)
overlap = serializers.IntegerField(required=False)
segment_size = serializers.IntegerField(required=False)
z_order = serializers.BooleanField(required=False)
status = serializers.CharField()
labels = serializers.JSONField(required=False)
segments = serializers.JSONField(required=False)
image_quality = serializers.IntegerField(required=False)
start_frame = serializers.IntegerField(required=False)
stop_frame = serializers.IntegerField(required=False)
frame_filter = serializers.CharField(required=False)
project = serializers.IntegerField(required=False, allow_null=True)
51 changes: 50 additions & 1 deletion backend/model_garden/services/cvat.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
from typing import Optional, List
import dataclasses
import logging
from typing import Optional, List, Dict, NamedTuple
from urllib.parse import urlencode, urlunsplit

import requests
from django.conf import settings
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

logger = logging.getLogger(__name__)


class CVATServiceException(Exception):
pass


@dataclasses.dataclass
class ListRequest:
page: int = 1
page_size: int = 100
ordering: str = ''
filters: Dict[str, str] = dataclasses.field(default_factory=dict)


class ListResponse(NamedTuple):
count: int
next_url: Optional[str]
prev_url: Optional[str]
results: List[dict]


class CvatService:
API_VERSION = 'v1'

Expand All @@ -31,6 +51,9 @@ def _get_url(self, path: str) -> str:

def _request(self, method: str, path: str, data: dict = None) -> requests.Response:
url = self._get_url(path)

logger.info('Cvat request: %s, data=%s', url, data)

response = getattr(self._session, method)(url=url, json=data)
try:
response.raise_for_status()
Expand Down Expand Up @@ -99,3 +122,29 @@ def create_task(
"project": None
})
return response.json()

def tasks(self, req: ListRequest) -> ListResponse:
"""Fetch tasks from the service that are created by
`settings.CVAT_ROOT_USER_NAME`.
"""
req.filters['owner'] = settings.CVAT_ROOT_USER_NAME

resp = self._get(_join_query('tasks', req))
data = resp.json()

return ListResponse(
count=data.get('count', 0),
next_url=data.get('next'),
prev_url=data.get('previous'),
results=data.get('results', []),
)


def _join_query(path: str, req: ListRequest) -> str:
"""Add query arguments from `req` to the `path`.
"""
req_query = dict(page=req.page, page_size=req.page_size, **req.filters)
if req.ordering:
req_query.update(ordering=req.ordering)

return urlunsplit(('', '', path, urlencode(req_query), ''))
14 changes: 14 additions & 0 deletions backend/model_garden/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@
},
]

LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'handlers': {
'console': {
'class': 'logging.StreamHandler',
},
},
'root': {
'handlers': ['console'],
'level': 'WARNING' if not DEBUG else 'INFO',
},
}

# Internationalization
# https://docs.djangoproject.com/en/3.0/topics/i18n/

Expand Down
4 changes: 2 additions & 2 deletions backend/model_garden/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
from model_garden.views import (
BucketViewSet,
CvatUserViewSet,
CvatTaskViewSet,
DatasetViewSet,
MediaAssetViewSet,
TaskViewSet,
)

router = routers.DefaultRouter()
router.register(r'buckets', BucketViewSet)
router.register(r'cvat-users', CvatUserViewSet, basename='cvatusers')
router.register(r'cvat-tasks', CvatTaskViewSet, basename='cvattasks')
router.register(r'datasets', DatasetViewSet)
router.register(r'media-assets', MediaAssetViewSet)
router.register(r'tasks', TaskViewSet, basename='tasks')

urlpatterns = [
path('admin/', admin.site.urls),
Expand Down
4 changes: 2 additions & 2 deletions backend/model_garden/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .bucket import BucketViewSet
from .cvat_tasks import CvatTaskViewSet
from .cvat_user import CvatUserViewSet
from .dataset import DatasetViewSet
from .media_asset import MediaAssetViewSet
from .task import TaskViewSet

__all__ = (
"BucketViewSet",
"CvatTaskViewSet",
"CvatUserViewSet",
"DatasetViewSet",
"MediaAssetViewSet",
"TaskViewSet",
)
138 changes: 138 additions & 0 deletions backend/model_garden/views/cvat_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import Iterable, Iterator

from django import forms
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from rest_framework.pagination import PageNumberPagination
from rest_framework.exceptions import ValidationError

from model_garden.serializers import (
CvatTaskSerializer, CvatTaskCreateSerializer
)
from model_garden.services.cvat import (
CvatService, CVATServiceException, ListRequest
)


class CvatTaskViewSet(ViewSet):
serializer_class = CvatTaskCreateSerializer

def create(self, request):
cvat_service = CvatService()
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = serializer.data
try:
cvat_service.create_task(
name=data['task_name'],
assignee_id=data['assignee_id'],
owner_id=cvat_service.get_root_user()['id'],
)
except CVATServiceException as e:
return Response(data={'message': str(e)}, status=400)

return Response()

def list(self, request: Request):
queryset = CvatTasksQuerySet(CvatService())
# iterating gets single value as a string instead of a list
queryset.filter(**{k: v for k, v in request.query_params.items()})
queryset.order_by(request.query_params.get('ordering', 'id'))

paginator = CvatTaskPagination()
page = paginator.paginate_queryset(queryset, request)
if not page:
return Response(
data={'message': 'Tasks was not found'},
status=status.HTTP_404_NOT_FOUND,
)

serializer = CvatTaskSerializer(page, many=True)
return paginator.get_paginated_response(serializer.data)


class CvatTaskPagination(PageNumberPagination):
page_size = 100
page_size_query_param = 'page_size'
max_page_size = 1000


class CvatTaskFilter(forms.Form):
id = forms.IntegerField(required=False)
project = forms.CharField(required=False, max_length=256)
name = forms.CharField(required=False, max_length=256)
mode = forms.ChoiceField(
choices=(
('annotation', 'annotation'),
('interpolation', 'interpolation')
),
required=False
)
status = forms.ChoiceField(
choices=(
('annotation', 'annotation'),
('validation', 'validation'),
('completed', 'completed'),
),
required=False
)
assignee = forms.CharField(required=False, max_length=256)


class CvatTasksQuerySet:
ORDERING_FIELDS = ("id", "name", "status", "assignee",)

def __init__(self, cvat_service: CvatService):
self.service = cvat_service
self.service_request = ListRequest()
self._tasks = []

def filter(self, **kwargs) -> 'CvatTasksQuerySet':
filter_form = CvatTaskFilter(kwargs)
if not filter_form.is_valid():
raise ValidationError(
filter_form.errors, code=status.HTTP_400_BAD_REQUEST,
)

self.service_request.filters.update(
{k: v for k, v in filter_form.cleaned_data.items() if v}
)
return self

def order_by(self, field_name: str) -> 'CvatTasksQuerySet':
self.service_request.ordering = field_name
return self

def count(self) -> int:
self.service_request.page = 1
self.service_request.page_size = 1
tasks = self.service.tasks(self.service_request)
return tasks.count

def __len__(self) -> int:
return self.count()

def __getitem__(self, key) -> Iterable[dict]:
if not isinstance(key, (int, slice)):
raise TypeError

start = key.start if isinstance(key, slice) else key
stop = key.stop if isinstance(key, slice) else key + 1
page_size = stop - start

if page_size == 0:
self._tasks = []
return self._tasks

self.service_request.page = start // page_size + 1
self.service_request.page_size = page_size
self._tasks = self.service.tasks(self.service_request).results
return self._tasks

def __iter__(self) -> Iterator[dict]:
"""Django instatiate iterator only after slicing,
so iterator doesn't have to make request
"""
return iter(self._tasks)
25 changes: 0 additions & 25 deletions backend/model_garden/views/task.py

This file was deleted.

Loading

0 comments on commit c7f7430

Please sign in to comment.