1
1
import hmac
2
+ import json
2
3
import logging
3
4
import os
5
+ import dagshub
6
+ import mlflow
7
+ import base64
8
+ import cloudpickle
9
+ from dagshub .data_engine import datasources
4
10
5
11
from flask import Flask , request , jsonify , Response
6
12
11
17
logger = logging .getLogger (__name__ )
12
18
13
19
_server = Flask (__name__ )
14
- MODEL_CLASS = LabelStudioMLBase
15
20
BASIC_AUTH = None
16
21
17
22
18
- def init_app (model_class , basic_auth_user = None , basic_auth_pass = None ):
19
- global MODEL_CLASS
23
+ def init_app (model_instance , basic_auth_user = None , basic_auth_pass = None ):
24
+ global model
20
25
global BASIC_AUTH
21
26
22
- if not issubclass (model_class , LabelStudioMLBase ):
23
- raise ValueError ('Inference class should be the subclass of ' + LabelStudioMLBase .__class__ .__name__ )
24
-
25
- MODEL_CLASS = model_class
27
+ model = model_instance
26
28
basic_auth_user = basic_auth_user or os .environ .get ('BASIC_AUTH_USER' )
27
29
basic_auth_pass = basic_auth_pass or os .environ .get ('BASIC_AUTH_PASS' )
28
30
if basic_auth_user and basic_auth_pass :
29
31
BASIC_AUTH = (basic_auth_user , basic_auth_pass )
30
32
31
33
return _server
32
34
35
+ @_server .post ('/configure' )
36
+ @exception_handler
37
+ def _configure ():
38
+ args = json .loads (request .get_json ())
39
+ dagshub .init (args ['repo' ], args ['username' ]) # user-level privileged auth token
40
+ ls_model = mlflow .pyfunc .load_model (f'models:/{ args ["model" ]} /{ args ["version" ]} ' )
41
+
42
+ model .configure (ls_model , * [cloudpickle .loads (base64 .b64decode (args [hook ])) for hook in ['pre_hook' , 'post_hook' ]])
43
+ # model.api = dagshub.common.api.repo.RepoAPI(f'https://dagshub.com/{args["username"]}/{args["repo"]}', host=args['host'])
44
+
45
+ model .ds = datasources .get_datasource (args ['datasource_repo' ], args ['datasource_name' ])
46
+ model .dp_map = model .ds .all ().dataframe [['path' , 'datapoint_id' ]]
47
+ return []
33
48
34
49
@_server .route ('/predict' , methods = ['POST' ])
35
50
@exception_handler
@@ -61,8 +76,8 @@ def _predict():
61
76
params = data .get ('params' , {})
62
77
context = params .pop ('context' , {})
63
78
64
- model = MODEL_CLASS ( project_id = project_id ,
65
- label_config = label_config )
79
+ model . project_id = project_id
80
+ model . use_label_config ( label_config )
66
81
67
82
# model.use_label_config(label_config)
68
83
@@ -96,8 +111,8 @@ def _setup():
96
111
project_id = data .get ('project' ).split ('.' , 1 )[0 ]
97
112
label_config = data .get ('schema' )
98
113
extra_params = data .get ('extra_params' )
99
- model = MODEL_CLASS ( project_id = project_id ,
100
- label_config = label_config )
114
+ model . project_id = project_id
115
+ model . use_label_config ( label_config )
101
116
102
117
if extra_params :
103
118
model .set_extra_params (extra_params )
@@ -122,7 +137,8 @@ def webhook():
122
137
return jsonify ({'status' : 'Unknown event' }), 200
123
138
project_id = str (data ['project' ]['id' ])
124
139
label_config = data ['project' ]['label_config' ]
125
- model = MODEL_CLASS (project_id , label_config = label_config )
140
+ model .project_id = project_id
141
+ model .use_label_config (label_config )
126
142
model .fit (event , data )
127
143
return jsonify ({}), 201
128
144
@@ -133,7 +149,6 @@ def webhook():
133
149
def health ():
134
150
return jsonify ({
135
151
'status' : 'UP' ,
136
- 'model_class' : MODEL_CLASS .__name__
137
152
})
138
153
139
154
0 commit comments