-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
50 lines (40 loc) · 1.48 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pprint
from typing import List, Dict, Optional
from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.response import ModelResponse
from label_studio_sdk.converter import brush
from uuid import uuid4
import json
class DagsHubLSModel(LabelStudioMLBase):
    """Custom ML Backend model serving predictions for DagsHub datapoints.

    Instances are created without arguments and must be wired up with
    :meth:`configure` before :meth:`predict` is used. ``predict`` maps each
    Label Studio task to a locally downloaded file through the DagsHub
    datasource and returns ``post_hook(model.predict(pre_hook(paths)))``.
    """

    def __init__(self):
        # The base-class initializer is not called; all working state is
        # supplied later via configure().
        # NOTE(review): this means base-class attributes such as ``project_id``
        # are not set here — confirm the code instantiating this class sets them.
        pass

    def configure(self, model, pre_hook, post_hook, ds, dp_map):
        """Attach the model, hooks and data sources used by :meth:`predict`.

        Args:
            model: Object exposing ``predict``; it receives ``pre_hook``'s output.
            pre_hook: Callable taking the list of local POSIX file paths and
                returning the model input.
            post_hook: Callable taking the raw model output and returning the
                value that :meth:`predict` returns.
            ds: DagsHub datasource, queried as ``ds['path'] == <path>``.
            dp_map: Table with ``datapoint_id`` and ``path`` columns mapping
                Label Studio datapoint ids to datasource paths (indexed with
                ``.iloc``, so presumably a pandas DataFrame — confirm).
        """
        self.model = model
        self.pre_hook = pre_hook
        self.post_hook = post_hook
        self.ds = ds
        self.dp_map = dp_map
        print(f'''\
Configured with model: {model}
Dataset: {ds}
''')

    def setup(self):
        """Set ``model_version`` to "0.0.1" through the base class's ``set``."""
        self.set("model_version", "0.0.1")

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
        """Run the configured model on the files referenced by ``tasks``.

        Args:
            tasks: Label Studio tasks; each must carry ``task['meta']['datapoint_id']``.
            context: Optional Label Studio context; only printed.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Whatever ``post_hook`` returns for the model output.

        Raises:
            KeyError: if a task lacks ``meta['datapoint_id']``.
            IndexError: if a datapoint id is absent from ``dp_map`` or its path
                has no match in the datasource.
        """
        # Attribute reads go through _describe_attr so that a debug print can
        # never abort the prediction when base-class state was not initialized.
        print(f'''\
Run prediction on {tasks}
Received context: {context}
Project ID: {self._describe_attr('project_id')}
Label config: {self._describe_attr('label_config')}
Parsed JSON Label config: {self._describe_attr('parsed_label_config')}
Extra params: {self._describe_attr('extra_params')}''')
        # Resolve every task to a local file first; ``tasks`` is kept as the
        # caller's input rather than being rebound to paths.
        local_paths = [self._resolve_local_path(task) for task in tasks]
        res = self.post_hook(self.model.predict(self.pre_hook(local_paths)))
        print(f'''\
Returning: {res}
''')
        return res

    def _describe_attr(self, name):
        """Return ``getattr(self, name)`` for diagnostics, or a placeholder if unreadable."""
        try:
            return getattr(self, name)
        except (AttributeError, KeyError, TypeError, ValueError) as err:
            # Best-effort only: these values are printed, never used for logic.
            return f'<unavailable: {type(err).__name__}>'

    def _resolve_local_path(self, task):
        """Download the datasource file behind ``task`` and return its local POSIX path."""
        datapoint_id = task['meta']['datapoint_id']
        matches = self.dp_map[self.dp_map['datapoint_id'] == datapoint_id]
        if len(matches) == 0:
            raise IndexError(f'datapoint_id {datapoint_id!r} not found in dp_map')
        path = matches.iloc[0].path
        try:
            # First datapoint whose 'path' equals the mapped path.
            datapoint = (self.ds['path'] == path).head()[0]
        except IndexError as err:
            raise IndexError(f'no datapoint with path {path!r} found in datasource') from err
        return datapoint.download_file().as_posix()

    def fit(self, event, data, **kwargs):
        """No-op: this backend does not train on Label Studio events."""
        pass