diff --git a/.gitignore b/.gitignore
index fbca225..a60da8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
results/
+cache.db
\ No newline at end of file
diff --git a/label/.gitignore b/label/.gitignore
new file mode 100644
index 0000000..9332d63
--- /dev/null
+++ b/label/.gitignore
@@ -0,0 +1,3 @@
+.file-cache/
+__pycache__/
+cache.db
\ No newline at end of file
diff --git a/label/README.md b/label/README.md
new file mode 100644
index 0000000..c00bf77
--- /dev/null
+++ b/label/README.md
@@ -0,0 +1,127 @@
+
+
+# Surya model connection
+
+The [Surya](https://github.com/VikParuchuri/surya) model connection is a powerful tool that integrates the capabilities of Surya with Label Studio. It is designed to assist in machine learning labeling tasks, specifically those involving Optical Character Recognition (OCR).
+
+The primary function of this connection is to recognize and extract text from images, which can be a crucial step in many machine learning workflows. By automating this process, the Surya model connection can significantly increase efficiency, reducing the time and effort required for manual text extraction.
+
+In the context of Label Studio, this connection enhances the platform's labeling capabilities, allowing users to automatically generate labels for text in images. This can be particularly useful in tasks such as data annotation, document digitization, and more.
+
+## Before you begin
+
+Before you begin, you must install the [Label Studio ML backend](https://github.com/HumanSignal/label-studio-ml-backend?tab=readme-ov-file#quickstart).
+
+This tutorial uses the [`surya` example](https://github.com/xiaoyao9184/docker-surya/tree/master/label).
+
+## Labeling configuration
+
+The Surya model connection can be used with the default labeling configuration for OCR in Label Studio. This configuration typically involves defining the types of labels to be used (e.g., text, handwriting, etc.) and the regions of the image where these labels should be applied.
+
+When setting the labeling configuration, select the **Computer Vision > Optical Character Recognition** template. It is pre-configured for OCR tasks and includes the necessary elements for labeling text in images:
+
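+```xml
+<View>
+  <Image name="image" value="$ocr"/>
+  <Labels name="label" toName="image">
+    <Label value="Text" background="green"/>
+    <Label value="Handwriting" background="blue"/>
+  </Labels>
+  <Rectangle name="bbox" toName="image" strokeWidth="3"/>
+  <Polygon name="poly" toName="image" strokeWidth="3"/>
+  <TextArea name="transcription" toName="image"
+            editable="true" perRegion="true" required="true"
+            maxSubmissions="1" rows="5" placeholder="Recognized Text"
+            displayMode="region-list"/>
+</View>
+```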
+
+
+> Warning! Please note that the current implementation of the Surya model connection does not support images that are directly uploaded to Label Studio. It is designed to work with images that are hosted publicly on the internet. Therefore, to use this connection, you should ensure that your images are publicly accessible via a URL.
+
+
+## Running with Docker (recommended)
+
+1. Start the Machine Learning backend on `http://localhost:9090` with the prebuilt image:
+
+```bash
+cd docker/up.label@gpu-online
+docker-compose up
+```
+
+2. Validate that the backend is running:
+
+```bash
+$ curl http://localhost:9090/
+{"status":"UP"}
+```
+
+3. Create a project in Label Studio. Then from the **Model** page in the project settings, [connect the model](https://labelstud.io/guide/ml#Connect-the-model-to-Label-Studio). The default URL is `http://localhost:9090`.
+
+
+## Building from source (advanced)
+
+To build the ML backend from source, you have to clone the repository and build the Docker image:
+
+```bash
+docker build -t xiaoyao9184/surya:master -f ./docker/build@source/dockerfile .
+```
+
+## Running without Docker (advanced)
+
+To run the ML backend without Docker, you have to clone the repository and install all dependencies using conda:
+
+```bash
+conda env create -f ./environment.yml
+```
+
+Then you can start the ML backend:
+
+```bash
+conda activate surya
+label-studio-ml start --root-dir . label
+```
+
+The Surya model connection offers several configuration options that can be set in the `docker-compose.yml` file:
+
+- `BASIC_AUTH_USER`: Specifies the basic auth user for the model server.
+- `BASIC_AUTH_PASS`: Specifies the basic auth password for the model server.
+- `LOG_LEVEL`: Sets the log level for the model server.
+- `WORKERS`: Specifies the number of workers for the model server.
+- `THREADS`: Specifies the number of threads for the model server.
+- `MODEL_DIR`: Specifies the model directory.
+- `LANG_LIST`: Specifies the list of languages to be used by the OCR model, separated by commas (for example `mn,en`). If the variable is unset, `model.py` falls back to an empty list.
+- `SCORE_THRESHOLD`: Sets the score threshold to filter out noisy results (default: `0.3`).
+- `LABEL_MAPPINGS_FILE`: Specifies the file with mappings from COCO labels to custom labels.
+- `LABEL_STUDIO_ACCESS_TOKEN`: Specifies the Label Studio access token.
+- `LABEL_STUDIO_HOST`: Specifies the Label Studio host.
+
+These options allow you to customize the behavior of the Surya model connection to suit your specific needs.
+
+# Customization
+
+The ML backend can be customized by adding your own models and logic inside the `./label` directory.
\ No newline at end of file
diff --git a/label/_wsgi.py b/label/_wsgi.py
new file mode 100644
index 0000000..7af3690
--- /dev/null
+++ b/label/_wsgi.py
@@ -0,0 +1,122 @@
+import os
+import argparse
+import json
+import logging
+import logging.config
+
+logging.config.dictConfig({
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "standard": {
+ "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
+ }
+ },
+ "handlers": {
+ "console": {
+ "class": "logging.StreamHandler",
+ "level": os.getenv('LOG_LEVEL'),
+ "stream": "ext://sys.stdout",
+ "formatter": "standard"
+ }
+ },
+ "root": {
+ "level": os.getenv('LOG_LEVEL'),
+ "handlers": [
+ "console"
+ ],
+ "propagate": True
+ }
+})
+
+from label_studio_ml.api import init_app
+from model import SuryaOCR
+
+
+_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
+
+
+def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
+ if not os.path.exists(config_path):
+ return dict()
+ with open(config_path) as f:
+ config = json.load(f)
+ assert isinstance(config, dict)
+ return config
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Label studio')
+ parser.add_argument(
+ '-p', '--port', dest='port', type=int, default=9090,
+ help='Server port')
+ parser.add_argument(
+ '--host', dest='host', type=str, default='0.0.0.0',
+ help='Server host')
+ parser.add_argument(
+ '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
+ help='Additional LabelStudioMLBase model initialization kwargs')
+ parser.add_argument(
+ '-d', '--debug', dest='debug', action='store_true',
+ help='Switch debug mode')
+ parser.add_argument(
+ '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
+ help='Logging level')
+ parser.add_argument(
+ '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
+ help='Directory where models are stored (relative to the project directory)')
+ parser.add_argument(
+ '--check', dest='check', action='store_true',
+ help='Validate model instance before launching server')
+ parser.add_argument('--basic-auth-user',
+ default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
+ help='Basic auth user')
+
+ parser.add_argument('--basic-auth-pass',
+ default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
+ help='Basic auth pass')
+
+ args = parser.parse_args()
+
+ # setup logging level
+ if args.log_level:
+ logging.root.setLevel(args.log_level)
+
+ def isfloat(value):
+ try:
+ float(value)
+ return True
+ except ValueError:
+ return False
+
+ def parse_kwargs():
+ param = dict()
+ for k, v in args.kwargs:
+ if v.isdigit():
+ param[k] = int(v)
+ elif v == 'True' or v == 'true':
+ param[k] = True
+ elif v == 'False' or v == 'false':
+ param[k] = False
+ elif isfloat(v):
+ param[k] = float(v)
+ else:
+ param[k] = v
+ return param
+
+ kwargs = get_kwargs_from_config()
+
+ if args.kwargs:
+ kwargs.update(parse_kwargs())
+
+ if args.check:
+ print('Check "' + SuryaOCR.__name__ + '" instance creation..')
+ model = SuryaOCR(**kwargs)
+
+ app = init_app(model_class=SuryaOCR, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)
+
+ app.run(host=args.host, port=args.port, debug=args.debug)
+
+else:
+ # for uWSGI use
+ app = init_app(model_class=SuryaOCR)
diff --git a/label/label_mappings.json b/label/label_mappings.json
new file mode 100644
index 0000000..9e26dfe
--- /dev/null
+++ b/label/label_mappings.json
@@ -0,0 +1 @@
+{}
\ No newline at end of file
diff --git a/label/model.py b/label/model.py
new file mode 100644
index 0000000..729f685
--- /dev/null
+++ b/label/model.py
@@ -0,0 +1,203 @@
+import os
+import json
+import boto3
+import logging
+
+from uuid import uuid4
+
+from surya.detection import batch_text_detection
+from surya.input.pdflines import get_page_text_lines, get_table_blocks
+from surya.layout import batch_layout_detection
+from surya.model.detection.model import load_model, load_processor
+from surya.model.layout.model import load_model as load_layout_model
+from surya.model.layout.processor import load_processor as load_layout_processor
+from surya.model.recognition.model import load_model as load_rec_model
+from surya.model.recognition.processor import load_processor as load_rec_processor
+from surya.model.table_rec.model import load_model as load_table_model
+from surya.model.table_rec.processor import load_processor as load_table_processor
+from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
+from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
+from surya.ocr import run_ocr
+from surya.postprocessing.text import draw_text_on_image
+from PIL import Image
+from surya.languages import CODE_TO_LANGUAGE
+from surya.input.langs import replace_lang_with_code
+from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
+from surya.settings import settings
+from surya.tables import batch_table_recognition
+from surya.postprocessing.util import rescale_bbox
+from pdftext.extraction import plain_text_output
+from surya.ocr_error import batch_ocr_error_detection
+
+from typing import List, Dict, Optional
+from label_studio_ml.model import LabelStudioMLBase
+from label_studio_ml.response import ModelResponse
+from label_studio_ml.utils import get_image_size, DATA_UNDEFINED_NAME
+from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path
+from botocore.exceptions import ClientError
+from urllib.parse import urlparse
+
+logger = logging.getLogger(__name__)
+
+# load models
+det_model, det_processor = load_model(), load_processor()
+rec_model, rec_processor = load_rec_model(), load_rec_processor()
+
+class SuryaOCR(LabelStudioMLBase):
+ """Custom ML Backend model
+ """
+ LANG_LIST = [lang for lang in os.getenv('LANG_LIST', '').split(',') if lang]
+
+ # score threshold to wipe out noisy results
+ SCORE_THRESHOLD = float(os.getenv('SCORE_THRESHOLD', 0.3))
+ # file with mappings from COCO labels to custom labels {"airplane": "Boeing"}
+ LABEL_MAPPINGS_FILE = os.getenv('LABEL_MAPPINGS_FILE', 'label_mappings.json')
+
+ # Label Studio image upload folder:
+ # should be used only in case you use direct file upload into Label Studio instead of URLs
+ LABEL_STUDIO_ACCESS_TOKEN = (
+ os.environ.get("LABEL_STUDIO_ACCESS_TOKEN") or os.environ.get("LABEL_STUDIO_API_KEY")
+ )
+ LABEL_STUDIO_HOST = (
+ os.environ.get("LABEL_STUDIO_HOST") or os.environ.get("LABEL_STUDIO_URL")
+ )
+
+ MODEL_DIR = os.getenv('MODEL_DIR', '.')
+
+ _label_map = {}
+
+ def setup(self):
+ """Configure any paramaters of your model here
+ """
+ self.set("model_version", f'{self.__class__.__name__}-v0.0.1')
+
+ if self.LABEL_MAPPINGS_FILE and os.path.exists(self.LABEL_MAPPINGS_FILE):
+ with open(self.LABEL_MAPPINGS_FILE, 'r') as f:
+ self._label_map = json.load(f)
+
+ def _get_image_url(self, task, value):
+ # TODO: warning! currently only s3 presigned urls are supported with the default keys
+ # also it seems not be compatible with file directly uploaded to Label Studio
+ # check RND-2 for more details and fix it later
+ image_url = task['data'].get(value) or task['data'].get(DATA_UNDEFINED_NAME)
+
+ if image_url.startswith('s3://'):
+ # presign s3 url
+ r = urlparse(image_url, allow_fragments=False)
+ bucket_name = r.netloc
+ key = r.path.lstrip('/')
+ client = boto3.client('s3')
+ try:
+ image_url = client.generate_presigned_url(
+ ClientMethod='get_object',
+ Params={'Bucket': bucket_name, 'Key': key}
+ )
+ except ClientError as exc:
+ logger.warning(f'Can\'t generate presigned URL for {image_url}. Reason: {exc}')
+ return image_url
+
+ def predict_single(self, task):
+ logger.debug('Task data: %s', task['data'])
+ from_name_poly, to_name, value = self.get_first_tag_occurence('Polygon', 'Image')
+ from_name_labels, _, _ = self.get_first_tag_occurence('Polygon', 'Image')
+ from_name_trans, _, _ = self.get_first_tag_occurence('TextArea', 'Image')
+ labels = self.label_interface.labels
+ labels = sum([list(l) for l in labels], [])
+ if len(labels) > 1:
+ logger.warning('More than one label in the tag. Only the first one will be used: %s', labels[0])
+ label = labels[0]
+
+ image_url = self._get_image_url(task, value)
+ cache_dir = os.path.join(self.MODEL_DIR, '.file-cache')
+ os.makedirs(cache_dir, exist_ok=True)
+ logger.debug(f'Using cache dir: {cache_dir}')
+ image_path = get_local_path(
+ image_url,
+ cache_dir=cache_dir,
+ hostname=self.LABEL_STUDIO_HOST,
+ access_token=self.LABEL_STUDIO_ACCESS_TOKEN,
+ task_id=task.get('id')
+ )
+
+ # run ocr
+ img_pil = Image.open(image_path).convert("RGB")
+ langs = self.LANG_LIST.copy()
+ replace_lang_with_code(langs)
+ model_results = run_ocr([img_pil], [langs],
+ det_model, det_processor,
+ rec_model, rec_processor,
+ highres_images=[img_pil])[0]
+
+ if not model_results:
+ return
+ img_width, img_height = get_image_size(image_path)
+ result = []
+ all_scores = []
+ for line in model_results.text_lines:
+ if not line:
+ logger.warning('Empty result from the model')
+ continue
+ score = line.confidence
+ if score < self.SCORE_THRESHOLD:
+ logger.info(f'Skipping result with low score: {score}')
+ continue
+
+ rel_pnt = []
+ for rp in line.polygon:
+ if rp[0] > img_width or rp[1] > img_height:
+ continue
+ rel_pnt.append([(rp[0] / img_width) * 100, (rp[1] / img_height) * 100])
+
+ # must add one for the polygon
+ id_gen = str(uuid4())[:4]
+ result.append({
+ 'original_width': img_width,
+ 'original_height': img_height,
+ 'image_rotation': 0,
+ 'value': {
+ 'points': rel_pnt,
+ },
+ 'id': id_gen,
+ 'from_name': from_name_poly,
+ 'to_name': to_name,
+ 'type': 'polygon',
+ 'origin': 'manual',
+ 'score': score,
+ })
+ # and one for the transcription
+ result.append({
+ 'original_width': img_width,
+ 'original_height': img_height,
+ 'image_rotation': 0,
+ 'value': {
+ 'points': rel_pnt,
+ 'labels': [label],
+ "text": [line.text]
+ },
+ 'id': id_gen,
+ 'from_name': from_name_trans,
+ 'to_name': to_name,
+ 'type': 'textarea',
+ 'origin': 'manual',
+ 'score': score,
+ })
+ all_scores.append(score)
+
+ return {
+ 'result': result,
+ 'score': sum(all_scores) / max(len(all_scores), 1),
+ 'model_version': self.get('model_version'),
+ }
+
+ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
+ predictions = []
+ for task in tasks:
+ # TODO: implement is_skipped() function
+ # if is_skipped(task):
+ # continue
+
+ prediction = self.predict_single(task)
+ if prediction:
+ predictions.append(prediction)
+
+ return ModelResponse(predictions=predictions, model_versions=self.get('model_version'))
diff --git a/label/requirements.txt b/label/requirements.txt
new file mode 100644
index 0000000..5cb6ac9
--- /dev/null
+++ b/label/requirements.txt
@@ -0,0 +1,2 @@
+surya-ocr==0.8.1
+boto3==1.35.91
\ No newline at end of file
diff --git a/label/requirements_base.txt b/label/requirements_base.txt
new file mode 100644
index 0000000..68ce357
--- /dev/null
+++ b/label/requirements_base.txt
@@ -0,0 +1,2 @@
+gunicorn==22.0.0
+label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git
\ No newline at end of file
diff --git a/label/requirements_dev.txt b/label/requirements_dev.txt
new file mode 100644
index 0000000..243c482
--- /dev/null
+++ b/label/requirements_dev.txt
@@ -0,0 +1 @@
+boto3==1.35.91
\ No newline at end of file
diff --git a/label/requirements_test.txt b/label/requirements_test.txt
new file mode 100644
index 0000000..fb73904
--- /dev/null
+++ b/label/requirements_test.txt
@@ -0,0 +1,3 @@
+pytest
+pytest-cov
+responses==0.13.0
\ No newline at end of file
diff --git a/label/test_api.py b/label/test_api.py
new file mode 100644
index 0000000..9dd7163
--- /dev/null
+++ b/label/test_api.py
@@ -0,0 +1,94 @@
+"""
+This file contains tests for the API of your model. You can run these tests by installing test requirements:
+
+ ```bash
+    pip install -r requirements_test.txt
+ ```
+Then execute `pytest` in the directory of this file.
+
+- Change `NewModel` to the name of the class in your model.py file.
+- Change the `request` and `expected_response` variables to match the input and output of your model.
+"""
+import os.path
+
+import pytest
+import json
+from model import SuryaOCR
+import responses
+
+
+@pytest.fixture
+def client():
+ from _wsgi import init_app
+ app = init_app(model_class=SuryaOCR)
+ app.config['TESTING'] = True
+ with app.test_client() as client:
+ yield client
+
+
+@pytest.fixture
+def model_dir_env(tmp_path, monkeypatch):
+ model_dir = tmp_path / "model_dir"
+ model_dir.mkdir()
+ monkeypatch.setattr(SuryaOCR, 'MODEL_DIR', str(model_dir))
+ return model_dir
+
+
+@responses.activate
+def test_predict(client, model_dir_env):
+ responses.add(
+ responses.GET,
+ 'http://test_predict.surya.ml-backend.com/image.jpeg',
+ body=open(os.path.join(os.path.dirname(__file__), 'test_images', 'image.jpeg'), 'rb').read(),
+ status=200
+ )
+ request = {
+ 'tasks': [{
+ 'data': {
+ 'image': 'http://test_predict.surya.ml-backend.com/image.jpeg'
+ }
+ }],
+ # Your labeling configuration here
+ 'label_config': '''
+
+
+
+
+
+
+
+
+
+
+
+
+
+'''
+ }
+
+ response = client.post('/predict', data=json.dumps(request), content_type='application/json')
+ assert response.status_code == 200
+ response = json.loads(response.data)
+ expected_texts = {
+ 'IZIN SOLUTIONS',
+ 'Swipe >',
+ 'I',
+ 'PUNY',
+ 'KENAPA',
+ 'Kenapa Harus Punya IMB?',
+ 'HARUS'
+ }
+ texts_response = set()
+ for r in response['results'][0]['result']:
+ if r['from_name'] == 'transcription':
+ assert r['value']['labels'][0] == 'Text'
+ texts_response.add(r['value']['text'][0])
+ assert texts_response == expected_texts
diff --git a/label/test_images/image.jpeg b/label/test_images/image.jpeg
new file mode 100644
index 0000000..e164d01
Binary files /dev/null and b/label/test_images/image.jpeg differ