Skip to content

Commit

Permalink
0.4.1 release prep
Browse files Browse the repository at this point in the history
- Add support for multiple workers & CORS headers (`--workers` & `--cors_origin` cmdline option)
  • Loading branch information
uogbuji committed Aug 6, 2024
1 parent b81d122 commit 225afca
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 7 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
Notable changes to Format based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

<!--
## [Unreleased]
-->

## [Unreleased]
## [0.4.1] - 20240806

### Added

- demo `demo/zipcode.py`
- support for multiple workers & CORS headers (`--workers` & `--cors_origin` cmdline option)

### Fixed

Expand Down
1 change: 1 addition & 0 deletions demo/arithmetic_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def arithmetic_calc(num1=None, num2=None, op=None):

toolio_mm = model_manager(MLX_MODEL_PATH, tool_reg=[arithmetic_calc], trace=True)

# Use this to try parallel function calling
# PROMPT = 'Solve the following calculations: 42 * 42, 24 * 24, 5 * 5, 89 * 75, 42 * 46, 69 * 85, 422 * 420, 753 * 321, 72 * 55, 240 * 204, 789 * 654, 123 * 321, 432 * 89, 564 * 321?' # noqa: E501
PROMPT = 'Solve the following calculation: 4242 * 2424.2'
async def async_main(tmm):
Expand Down
37 changes: 32 additions & 5 deletions pylib/cli/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,29 @@
Note: you can also point `--model` at a downloaded or converted MLX model on local storage.
'''

import os
import json
import time
import os
from contextlib import asynccontextmanager
import warnings

from fastapi import FastAPI, Request, status
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
import click
import uvicorn

from llm_structured_output.util.output import info, warning, debug

from toolio.schema_helper import Model
from toolio.llm_helper import enrich_chat_for_tools, DEFAULT_FLAGS, FLAGS_LOOKUP
from toolio.http_schematics import V1ChatCompletionsRequest, V1ChatMessage, V1ResponseFormatType
from toolio.http_schematics import V1ChatCompletionsRequest, V1ResponseFormatType
from toolio.responder import (ToolCallStreamingResponder, ToolCallResponder,
ChatCompletionResponder, ChatCompletionStreamingResponder)


NUM_CPUS = int(os.cpu_count())
app_params = {}

# Context manager for the FastAPI app's lifespan: https://fastapi.tiangolo.com/advanced/events/
Expand All @@ -57,6 +59,9 @@ async def lifespan(app: FastAPI):
# print(app.state.model.model.__class__, app.state.model.model.model_type)
info(f'Model loaded in {(tdone - tstart)/1000000000.0:.3f}s. Type: {app.state.model.model.model_type}')
app.state.model_flags = FLAGS_LOOKUP.get(app.state.model.model.model_type, DEFAULT_FLAGS)
# Look into exposing control over methods & headers as well
app.add_middleware(CORSMiddleware, allow_origins=app_params['cors_origins'], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"])
yield
# Shutdown code here, if any

Expand Down Expand Up @@ -199,7 +204,29 @@ async def post_v1_chat_completions_impl(req_data: V1ChatCompletionsRequest):
help='Path to JSON schema to be used if not provided via API call.'
'Interpolated into {jsonschema} placeholder in prompts')
@click.option('--llmtemp', default='0.1', type=float, help='LLM sampling temperature')
def main(host, port, model, default_schema, default_schema_file, llmtemp):
@click.option('--workers', type=int, default=0,
help='Number of workers processes to spawn (each utilizes one CPU core).'
'Defaults to $WEB_CONCURRENCY environment variable if available, or 1')
@click.option('--cors_origin', multiple=True,
help='Origin to be permitted for CORS https://fastapi.tiangolo.com/tutorial/cors/')
def main(host, port, model, default_schema, default_schema_file, llmtemp, workers, cors_origin):
app_params.update(model=model, default_schema=default_schema, default_schema_fpath=default_schema_file,
llmtemp=llmtemp)
uvicorn.run('toolio.cli.server:app', host=host, port=port, reload=False)
llmtemp=llmtemp, cors_origins=list(cors_origin))
workers = workers or None
# logger.info(f'Host has {NUM_CPUS} CPU cores')
uvicorn.run('toolio.cli.server:app', host=host, port=port, reload=False, workers=workers)


# Implement log config when we
def UNUSED_log_setup(config):
# Set up logging
import logging
global logger # noqa: PLW0603

main_loglevel = config.get('log', {'level': 'INFO'})['level']
logging.config.dictConfig(config['log'])
# Following 2 lines configure the root logger, so all other loggers in this process space will inherit
# logging.basicConfig(level=main_loglevel, format='%(levelname)s:%(name)s: %(message)s')
logging.getLogger().setLevel(main_loglevel) # Seems redundant, but is necessary. Python logging is quirky
logger = logging.getLogger(__name__)
# logger.addFilter(LocalFilter())
3 changes: 2 additions & 1 deletion pylib/tool/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def calculator(expr=None):
'''
Make an arithmetical, mathematical calculation, using operations such as addition (+), subtraction (-),
multiplication (*), and division (/). Don't forget to use parenthesis for grouping.
**Always use this tool for calculations. Never try to do them yourself**.
**Always use this tool for calculations. Never try to do them yourself. Only use numbers and operators.
Do not include units in numbers!**.
'''
# print(repr(expr))
if not ALLOWED_EXPR_PAT.match(expr):
Expand Down

0 comments on commit 225afca

Please sign in to comment.