Skip to content

Commit

Permalink
Merge pull request #8 from geniusrise/feat/concurrency
Browse files Browse the repository at this point in the history
Fix API concurrency lock
  • Loading branch information
ixaxaar authored Feb 24, 2024
2 parents ccffd90 + 1298427 commit 3d4724e
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions geniusrise_vision/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@
sequential_lock = threading.Lock()


def sequential_tool():
with sequential_lock:
# Yield to signal that the request can proceed
yield


# Register the custom tool
cherrypy.tools.sequential = cherrypy.Tool("before_handler", sequential_tool)


class VisionAPI(VisionBulk):
"""
The VisionAPI class inherits from VisionBulk and is designed to facilitate
Expand Down Expand Up @@ -93,6 +83,7 @@ def listen(
compile: bool = False,
flash_attention: bool = False,
better_transformers: bool = False,
concurrent_queries: bool = False,
endpoint: str = "*",
port: int = 3000,
cors_domain: str = "http://localhost:3000",
Expand All @@ -115,6 +106,7 @@ def listen(
compile (bool, optional): Whether to compile the model before fine-tuning. Defaults to False.
flash_attention (bool): Whether to use flash attention 2. Default is False.
better_transformers (bool): Flag to enable Better Transformers optimization for faster processing.
concurrent_queries: (bool): Whether the API supports concurrent API calls (usually false).
endpoint (str, optional): The network endpoint for the server. Defaults to "*".
port (int, optional): The network port for the server. Defaults to 3000.
cors_domain (str, optional): The domain to allow for CORS requests. Defaults to "http://localhost:3000".
Expand All @@ -134,6 +126,7 @@ def listen(
self.compile = compile
self.flash_attention = flash_attention
self.better_transformers = better_transformers
self.concurrent_queries = concurrent_queries
self.model_args = model_args
self.username = username
self.password = password
Expand Down Expand Up @@ -175,6 +168,14 @@ def listen(
# **self.model_args,
)

def sequential_locker():
if self.concurrent_queries:
sequential_lock.acquire()

def sequential_unlocker():
if self.concurrent_queries:
sequential_lock.release()

def CORS():
"""
Configures Cross-Origin Resource Sharing (CORS) for the server.
Expand Down Expand Up @@ -219,6 +220,8 @@ def CORS():
# Configure basic authentication
conf = {
"/": {
"tools.sequential_locker.on": True,
"tools.sequential_unlocker.on": True,
"tools.auth_basic.on": True,
"tools.auth_basic.realm": "geniusrise",
"tools.auth_basic.checkpassword": self.validate_password,
Expand All @@ -227,11 +230,19 @@ def CORS():
}
else:
# Configuration without authentication
conf = {"/": {"tools.CORS.on": True}}
conf = {
"/": {
"tools.sequential_locker.on": True,
"tools.sequential_unlocker.on": True,
"tools.CORS.on": True,
}
}

cherrypy.tools.sequential_locker = cherrypy.Tool("before_handler", sequential_locker)
cherrypy.tools.CORS = cherrypy.Tool("before_handler", CORS)
cherrypy.tree.mount(self, "/api/v1/", conf)
cherrypy.tools.CORS = cherrypy.Tool("before_finalize", CORS)
cherrypy.tools.sequential_unlocker = cherrypy.Tool("before_finalize", sequential_unlocker)
cherrypy.engine.start()
cherrypy.engine.block()

Expand Down

0 comments on commit 3d4724e

Please sign in to comment.