feat: Basic Auth support, Azure OpenAI, Semver, and using LabelInterface (#437)

* Adding basic authentication and updating how model_version is
handled. Also fixing links to the docs. Tests to come

* Adding azure openai, plus response model

* Adding a few more tests :> fixing typos

* fixing response

* Merge azure into llm_interactive

* git sdk install

* Fix errors

* Updating readme

* Change versions

* Add pytest verbosity

* Change functional test CI location

---------

Co-authored-by: Mikhail Maluyk <[email protected]>
Co-authored-by: Michael Malyuk <[email protected]>
Co-authored-by: nik <[email protected]>
4 people authored Mar 6, 2024
1 parent 3fcf6f6 commit 5ff2d29
Showing 34 changed files with 708 additions and 251 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/tests.yml
@@ -104,8 +104,7 @@ jobs:
- name: Run general functional tests
run: |
cd label_studio_ml/
pytest --ignore-glob='**/logs/*' --ignore-glob='**/data/*' --cov=. --cov-report=xml
pytest -vvv --ignore-glob='**/logs/*' --ignore-glob='**/data/*' --cov=. --cov-report=xml
- name: Pull the logs
if: always()
7 changes: 7 additions & 0 deletions Makefile
@@ -0,0 +1,7 @@
SHELL := /bin/bash

install:
pip install -r requirements.txt

test:
pytest tests
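
With this Makefile in place, the common workflows reduce to:

```bash
make install   # pip install -r requirements.txt
make test      # run the test suite with pytest
```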
2 changes: 1 addition & 1 deletion label_studio_ml/__init__.py
@@ -2,4 +2,4 @@
package_name = 'label-studio-ml'

# Package version
__version__ = '2.0.0.dev0'
__version__ = '2.0.1dev0'
74 changes: 57 additions & 17 deletions label_studio_ml/api.py
@@ -1,24 +1,30 @@
import hmac
import logging

from flask import Flask, request, jsonify
from flask import Flask, request, jsonify, Response

from .response import ModelResponse
from .model import LabelStudioMLBase
from .exceptions import exception_handler


logger = logging.getLogger(__name__)

_server = Flask(__name__)
MODEL_CLASS = LabelStudioMLBase
BASIC_AUTH = None


def init_app(model_class):
def init_app(model_class, basic_auth_user=None, basic_auth_pass=None):
global MODEL_CLASS
global BASIC_AUTH

if not issubclass(model_class, LabelStudioMLBase):
raise ValueError('Inference class should be a subclass of ' + LabelStudioMLBase.__name__)

MODEL_CLASS = model_class
if basic_auth_user and basic_auth_pass:
BASIC_AUTH = (basic_auth_user, basic_auth_pass)

return _server


@@ -46,20 +52,38 @@ def _predict():
"""
data = request.json
tasks = data.get('tasks')
params = data.get('params') or {}
project = data.get('project')
if project:
project_id = data.get('project').split('.', 1)[0]
else:
project_id = None
label_config = data.get('label_config')
project = data.get('project')
project_id = project.split('.', 1)[0] if project else None
params = data.get('params', {})
context = params.pop('context', {})

model = MODEL_CLASS(project_id)
model.use_label_config(label_config)
model = MODEL_CLASS(project_id=project_id,
label_config=label_config)

# model.use_label_config(label_config)

response = model.predict(tasks, context=context, **params)

# if there is no model version we will take the default
if isinstance(response, ModelResponse):
if not response.has_model_version:
mv = model.model_version
if mv:
response.set_version(mv)
else:
response.update_predictions_version()

response = response.serialize()

res = response
if res is None:
res = []

predictions = model.predict(tasks, context=context, **params)
return jsonify({'results': predictions})
if isinstance(res, dict):
res = response.get("predictions", response)

return jsonify({'results': res})


@_server.route('/setup', methods=['POST'])
@@ -68,8 +92,13 @@ def _setup():
data = request.json
project_id = data.get('project').split('.', 1)[0]
label_config = data.get('schema')
model = MODEL_CLASS(project_id)
model.use_label_config(label_config)
extra_params = data.get('extra_params')
model = MODEL_CLASS(project_id=project_id,
label_config=label_config)

if extra_params:
model.set_extra_params(extra_params)

model_version = model.get('model_version')
return jsonify({'model_version': model_version})

@@ -90,8 +119,7 @@ def webhook():
return jsonify({'status': 'Unknown event'}), 200
project_id = str(data['project']['id'])
label_config = data['project']['label_config']
model = MODEL_CLASS(project_id)
model.use_label_config(label_config)
model = MODEL_CLASS(project_id, label_config=label_config)
model.fit(event, data)
return jsonify({}), 201

@@ -130,6 +158,18 @@ def index_error(error):
return str(error), 500


def safe_str_cmp(a, b):
return hmac.compare_digest(a, b)


@_server.before_request
def check_auth():
if BASIC_AUTH is not None:
auth = request.authorization
if not auth or not (safe_str_cmp(auth.username, BASIC_AUTH[0]) and safe_str_cmp(auth.password, BASIC_AUTH[1])):
return Response('Unauthorized', 401, {'WWW-Authenticate': 'Basic realm="Login required"'})


@_server.before_request
def log_request_info():
logger.debug('Request headers: %s', request.headers)
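
Taken together, the api.py changes mean that a server configured with credentials rejects unauthenticated requests with a 401 before any route runs. A minimal smoke test, assuming the default port 9090 and the standard `/predict` route (the task payload below is illustrative only):

```python
import requests

resp = requests.post(
    "http://localhost:9090/predict",
    json={
        "tasks": [{"data": {"text": "Great product!"}}],
        "project": "1.1709700000",   # "<project_id>.<timestamp>"; _predict splits on "."
        "label_config": "<View>...</View>",
        "params": {"context": {}},
    },
    auth=("admin", "secret"),        # must match the credentials passed to init_app
)
print(resp.status_code)  # 401 without valid credentials
print(resp.json())       # {"results": [...]} on success
```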
9 changes: 8 additions & 1 deletion label_studio_ml/default_configs/_wsgi.py.tmpl
@@ -67,7 +67,14 @@ if __name__ == "__main__":
parser.add_argument(
'--check', dest='check', action='store_true',
help='Validate model instance before launching server')

parser.add_argument('--basic-auth-user',
default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
help='Basic auth user')

parser.add_argument('--basic-auth-pass',
default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
help='Basic auth pass')

args = parser.parse_args()

# setup logging level
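
With these flags in the generated `_wsgi.py`, credentials can be supplied on the command line or via the environment (values below are placeholders, and this assumes the template forwards them to `init_app`, which the visible hunk does not show):

```bash
python _wsgi.py --basic-auth-user admin --basic-auth-pass secret
# equivalently:
ML_SERVER_BASIC_AUTH_USER=admin ML_SERVER_BASIC_AUTH_PASS=secret python _wsgi.py
```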
32 changes: 29 additions & 3 deletions label_studio_ml/default_configs/model.py
@@ -3,21 +3,47 @@


class NewModel(LabelStudioMLBase):
"""Custom ML Backend model
"""

def setup(self):
"""Configure any paramaters of your model here
"""
self.set("model_version", "0.0.1")


def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
""" Write your inference logic here
:param tasks: [Label Studio tasks in JSON format](https://labelstud.io/guide/task_format.html)
:param context: [Label Studio context in JSON format](https://labelstud.io/guide/ml.html#Passing-data-to-ML-backend)
:return predictions: [Predictions array in JSON format](https://labelstud.io/guide/export.html#Raw-JSON-format-of-completed-tasks)
:param context: [Label Studio context in JSON format](https://labelstud.io/guide/ml_create#Implement-prediction-logic)
:return predictions: [Predictions array in JSON format](https://labelstud.io/guide/export.html#Label-Studio-JSON-format-of-annotated-tasks)
"""
print(f'''\
Run prediction on {tasks}
Received context: {context}
Project ID: {self.project_id}
Label config: {self.label_config}
Parsed JSON Label config: {self.parsed_label_config}''')
Parsed JSON Label config: {self.parsed_label_config}
Extra params: {self.extra_params}''')

# example for simple classification
# return [{
# "model_version": self.get("model_version"),
# "score": 0.12,
# "result": [{
# "id": "vgzE336-a8",
# "from_name": "sentiment",
# "to_name": "text",
# "type": "choices",
# "value": {
# "choices": [ "Negative" ]
# }
# }]
# }]

return []


def fit(self, event, data, **kwargs):
"""
This method is called each time an annotation is created or updated
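
Since `_predict` now special-cases `ModelResponse`, a model's `predict` can return one instead of a bare list. A sketch under the assumption, suggested by the api.py hunk, that `ModelResponse` carries an optional `model_version` plus a `predictions` list:

```python
from typing import Dict, List, Optional

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.response import ModelResponse


class SentimentModel(LabelStudioMLBase):

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
        predictions = [{
            "score": 0.12,
            "result": [{
                "from_name": "sentiment",
                "to_name": "text",
                "type": "choices",
                "value": {"choices": ["Negative"]},
            }],
        }]
        # model_version is deliberately omitted: _predict sees
        # has_model_version == False and falls back to set_version()
        # with this model's own version.
        return ModelResponse(predictions=predictions)
```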
3 changes: 1 addition & 2 deletions label_studio_ml/examples/easyocr/easyocr_labeling.py
@@ -10,8 +10,7 @@
import easyocr

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import get_image_size, \
get_single_tag_keys, DATA_UNDEFINED_NAME
from label_studio_ml.utils import get_image_size, DATA_UNDEFINED_NAME
from label_studio_tools.core.utils.io import get_data_dir
from botocore.exceptions import ClientError
from urllib.parse import urlparse
17 changes: 14 additions & 3 deletions label_studio_ml/examples/llm_interactive/README.md
@@ -1,6 +1,6 @@
## Interactive LLM labeling

This example server connects Label Studio to OpenAI's API to interact with GPT chat models (gpt-3.5-turbo, gpt-4, etc.).
This example server connects Label Studio to the [OpenAI](https://platform.openai.com/) or [Azure](https://azure.microsoft.com/en-us/products/ai-services/openai-service) API to interact with GPT chat models (gpt-3.5-turbo, gpt-4, etc.).

The interactive flow allows you to perform the following scenarios:

@@ -231,7 +231,18 @@ When deploying the server, you can specify the following parameters as environment variables:
- `PROMPT_TEMPLATE` (default: `"Source Text: {text}\n\nTask Directive: {prompt}"`): The prompt template to use. If `USE_INTERNAL_PROMPT_TEMPLATE` is set to `1`, the server will use
the default internal prompt template. If `USE_INTERNAL_PROMPT_TEMPLATE` is set to `0`, the server will use the prompt template provided
in the input prompt (i.e. the user input from `<TextArea name="my-prompt" ...>`). In the latter case, the user has to provide placeholders that match the input task fields. For example, to use the `input_text` and `instruction` fields from the input task `{"input_text": "user text", "instruction": "user instruction"}`, provide a prompt template like this: `"Source Text: {input_text}, Custom instruction: {instruction}"`.
- `OPENAI_MODEL` (default: `gpt-3.5-turbo`): The OpenAI model to use.
- `OPENAI_PROVIDER` (available options: `openai`, `azure`; default: `openai`): The OpenAI provider to use.
- `TEMPERATURE` (default: `0.7`): The temperature to use for the model.
- `NUM_RESPONSES` (default: `1`): The number of responses to generate in `<TextArea>` output fields. Useful if you want to generate multiple responses and let the user rank the best one.
- `OPENAI_API_KEY`: The OpenAI API key to use. Must be set before deploying the server.
- `OPENAI_API_KEY`: The OpenAI or Azure API key to use. Must be set before deploying the server.

### Azure Configuration

If you are using Azure as your OpenAI provider (`OPENAI_PROVIDER=azure`), you also need to set the following environment variables:

- `AZURE_RESOURCE_ENDPOINT`: The endpoint of your Azure OpenAI resource.

- `AZURE_DEPLOYMENT_NAME`: The name of your Azure deployment. It must match the deployment name you configured in Azure.

- `AZURE_API_VERSION`: The Azure API version to use. Defaults to `2023-05-15`.
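
For reference, a minimal sketch (an assumption about the backend's internals, not code from this commit) of how these three variables typically map onto a v1-style `AzureOpenAI` client from the `openai` Python package:

```python
import os

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
    api_version=os.environ.get("AZURE_API_VERSION", "2023-05-15"),
    azure_endpoint=os.environ["AZURE_RESOURCE_ENDPOINT"],
)

# With Azure, the deployment name is passed where a plain OpenAI call
# would pass a model name.
completion = client.chat.completions.create(
    model=os.environ["AZURE_DEPLOYMENT_NAME"],
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)
```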
22 changes: 20 additions & 2 deletions label_studio_ml/examples/llm_interactive/docker-compose.yml
@@ -7,12 +7,30 @@ services:
build: .
environment:
- MODEL_DIR=/data/models
# Specify openai model provider: "openai" or "azure"
- OPENAI_PROVIDER=openai
# Specify API key for openai or azure
- OPENAI_API_KEY=
- OPENAI_MODEL=gpt-4
- PROMPT_PREFIX=
# Specify model name for openai or azure (by default it uses "gpt-3.5-turbo-instruct")
- OPENAI_MODEL=
# Internal prompt template for the model is:
# **Source Text**:\n\n"{text}"\n\n**Task Directive**:\n\n"{prompt}"
# if you want to specify task data keys in the prompt (i.e. input <TextArea name="$PROMPT_PREFIX..."/>), set this to 0
- USE_INTERNAL_PROMPT_TEMPLATE=1
# Prompt prefix for the TextArea component in the frontend to be used for the user input
- PROMPT_PREFIX=prompt
# Log level for the server
- LOG_LEVEL=DEBUG
# Number of responses to generate for each request
- NUM_RESPONSES=1
# Temperature for the model
- TEMPERATURE=0.7
# Azure resource endpoint (in case OPENAI_PROVIDER=azure)
- AZURE_RESOURCE_ENDPOINT=
# Azure deployment name (in case OPENAI_PROVIDER=azure)
- AZURE_DEPLOYMENT_NAME=
# Azure API version (in case OPENAI_PROVIDER=azure)
- AZURE_API_VERSION=2023-05-15
ports:
- 9090:9090
volumes: