Skip to content

Commit

Permalink
deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
traderpedroso committed Oct 20, 2024
1 parent 7182bff commit e2c0626
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 47 deletions.
3 changes: 1 addition & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ RUN pip install --no-cache-dir \
termcolor \
uvicorn \
griffe==0.48.0 \
python-dotenv

RUN pip install --no-cache-dir git+https://github.com/InternLM/lagent.git

Expand Down
94 changes: 49 additions & 45 deletions mindsearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,48 @@

def parse_arguments():
    """Parse command-line options for the MindSearch API server.

    Returns:
        argparse.Namespace with attributes:
            lang (str): response language code, default "en".
            model_format (str): backend model format, default "internlm_server".
            search_engine (str): search backend name, default "DuckDuckGoSearch".
    """
    import argparse

    parser = argparse.ArgumentParser(description="MindSearch API")
    parser.add_argument("--lang", default="en", type=str, help="Language")
    parser.add_argument(
        "--model_format", default="internlm_server", type=str, help="Model format"
    )
    parser.add_argument(
        "--search_engine", default="DuckDuckGoSearch", type=str, help="Search engine"
    )
    return parser.parse_args()


# Parse CLI options once at import time; the /solve handler reads them
# when constructing the agent.
args = parse_arguments()
app = FastAPI(docs_url="/")

# Allow cross-origin access from any frontend host.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentials require an explicit
# origin) — confirm whether credentialed requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class GenerationParams(BaseModel):
    """Request body for the /solve endpoint."""

    # Either a raw prompt string or a chat-history-style list of message
    # dicts (shape of the dicts is defined by the agent, not visible here).
    inputs: Union[str, List[Dict]]
    # Optional agent configuration overrides; defaults to empty.
    # NOTE(review): mutable dict() default — presumably safe because pydantic
    # copies field defaults per instance; confirm.
    agent_cfg: Dict = dict()


@app.post('/solve')
@app.post("/solve")
async def run(request: GenerationParams):

def convert_adjacency_to_tree(adjacency_input, root_name):

def build_tree(node_name):
node = {'name': node_name, 'children': []}
node = {"name": node_name, "children": []}
if node_name in adjacency_input:
for child in adjacency_input[node_name]:
child_node = build_tree(child['name'])
child_node['state'] = child['state']
child_node['id'] = child['id']
node['children'].append(child_node)
child_node = build_tree(child["name"])
child_node["state"] = child["state"]
child_node["id"] = child["id"]
node["children"].append(child_node)
return node

return build_tree(root_name)
Expand All @@ -73,8 +74,7 @@ def sync_generator_wrapper():
for response in agent.stream_chat(inputs):
queue.sync_q.put(response)
except Exception as e:
logging.exception(
f'Exception in sync_generator_wrapper: {e}')
logging.exception(f"Exception in sync_generator_wrapper: {e}")
finally:
# Notify async_generator_wrapper that the data generation is complete.
queue.sync_q.put(None)
Expand All @@ -87,9 +87,10 @@ async def async_generator_wrapper():
if response is None: # Ensure that all elements are consumed
break
yield response
if not isinstance(
response,
tuple) and response.state == AgentStatusCode.END:
if (
not isinstance(response, tuple)
and response.state == AgentStatusCode.END
):
break
stop_event.set() # Inform sync_generator_wrapper to stop

Expand All @@ -101,36 +102,39 @@ async def async_generator_wrapper():
node_name = None
origin_adj = deepcopy(agent_return.adjacency_list)
adjacency_list = convert_adjacency_to_tree(
agent_return.adjacency_list, 'root')
assert adjacency_list[
'name'] == 'root' and 'children' in adjacency_list
agent_return.adjacency_list = adjacency_list['children']
agent_return.adjacency_list, "root"
)
assert adjacency_list["name"] == "root" and "children" in adjacency_list
agent_return.adjacency_list = adjacency_list["children"]
agent_return = asdict(agent_return)
agent_return['adj'] = origin_adj
response_json = json.dumps(dict(response=agent_return,
current_node=node_name),
ensure_ascii=False)
yield {'data': response_json}
agent_return["adj"] = origin_adj
response_json = json.dumps(
dict(response=agent_return, current_node=node_name),
ensure_ascii=False,
)
yield {"data": response_json}
# yield f'data: {response_json}\n\n'
except Exception as exc:
msg = 'An error occurred while generating the response.'
msg = "An error occurred while generating the response."
logging.exception(msg)
response_json = json.dumps(
dict(error=dict(msg=msg, details=str(exc))),
ensure_ascii=False)
yield {'data': response_json}
dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
)
yield {"data": response_json}
# yield f'data: {response_json}\n\n'
finally:
await stop_event.wait(
) # Waiting for async_generator_wrapper to stop
await stop_event.wait() # Waiting for async_generator_wrapper to stop
queue.close()
await queue.wait_closed()

inputs = request.inputs
agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
agent = init_agent(
lang=args.lang, model_format=args.model_format, search_engine=args.search_engine
)
return EventSourceResponse(generate())


if __name__ == '__main__':
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')

uvicorn.run(app, host="0.0.0.0", port=80, log_level="info")

0 comments on commit e2c0626

Please sign in to comment.