Skip to content

Commit 03002a0

Browse files
committed
configured for use with customizable mlflow model
1 parent d182572 commit 03002a0

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

label_studio_ml/api.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import hmac
2+
import json
23
import logging
34
import os
5+
import dagshub
6+
import mlflow
7+
import base64
8+
import cloudpickle
9+
from dagshub.data_engine import datasources
410

511
from flask import Flask, request, jsonify, Response
612

@@ -11,25 +17,34 @@
1117
logger = logging.getLogger(__name__)
1218

1319
_server = Flask(__name__)
14-
MODEL_CLASS = LabelStudioMLBase
1520
BASIC_AUTH = None
1621

1722

18-
def init_app(model_class, basic_auth_user=None, basic_auth_pass=None):
19-
global MODEL_CLASS
23+
def init_app(model_instance, basic_auth_user=None, basic_auth_pass=None):
24+
global model
2025
global BASIC_AUTH
2126

22-
if not issubclass(model_class, LabelStudioMLBase):
23-
raise ValueError('Inference class should be the subclass of ' + LabelStudioMLBase.__class__.__name__)
24-
25-
MODEL_CLASS = model_class
27+
model = model_instance
2628
basic_auth_user = basic_auth_user or os.environ.get('BASIC_AUTH_USER')
2729
basic_auth_pass = basic_auth_pass or os.environ.get('BASIC_AUTH_PASS')
2830
if basic_auth_user and basic_auth_pass:
2931
BASIC_AUTH = (basic_auth_user, basic_auth_pass)
3032

3133
return _server
3234

35+
@_server.post('/configure')
36+
@exception_handler
37+
def _configure():
38+
args = json.loads(request.get_json())
39+
dagshub.init(args['repo'], args['username']) # user-level privileged auth token
40+
ls_model = mlflow.pyfunc.load_model(f'models:/{args["model"]}/{args["version"]}')
41+
42+
model.configure(ls_model, *[cloudpickle.loads(base64.b64decode(args[hook])) for hook in ['pre_hook', 'post_hook']])
43+
# model.api = dagshub.common.api.repo.RepoAPI(f'https://dagshub.com/{args["username"]}/{args["repo"]}', host=args['host'])
44+
45+
model.ds = datasources.get_datasource(args['datasource_repo'], args['datasource_name'])
46+
model.dp_map = model.ds.all().dataframe[['path', 'datapoint_id']]
47+
return []
3348

3449
@_server.route('/predict', methods=['POST'])
3550
@exception_handler
@@ -61,8 +76,8 @@ def _predict():
6176
params = data.get('params', {})
6277
context = params.pop('context', {})
6378

64-
model = MODEL_CLASS(project_id=project_id,
65-
label_config=label_config)
79+
model.project_id = project_id
80+
model.use_label_config(label_config)
6681

6782
# model.use_label_config(label_config)
6883

@@ -96,8 +111,8 @@ def _setup():
96111
project_id = data.get('project').split('.', 1)[0]
97112
label_config = data.get('schema')
98113
extra_params = data.get('extra_params')
99-
model = MODEL_CLASS(project_id=project_id,
100-
label_config=label_config)
114+
model.project_id = project_id
115+
model.use_label_config(label_config)
101116

102117
if extra_params:
103118
model.set_extra_params(extra_params)
@@ -122,7 +137,8 @@ def webhook():
122137
return jsonify({'status': 'Unknown event'}), 200
123138
project_id = str(data['project']['id'])
124139
label_config = data['project']['label_config']
125-
model = MODEL_CLASS(project_id, label_config=label_config)
140+
model.project_id = project_id
141+
model.use_label_config(label_config)
126142
model.fit(event, data)
127143
return jsonify({}), 201
128144

@@ -133,7 +149,6 @@ def webhook():
133149
def health():
134150
return jsonify({
135151
'status': 'UP',
136-
'model_class': MODEL_CLASS.__name__
137152
})
138153

139154

0 commit comments

Comments
 (0)