fix: dia-1287: single kafka producer for app #160

Draft · wants to merge 6 commits into base: master
61 changes: 44 additions & 17 deletions server/app.py
@@ -40,6 +40,9 @@

ResponseData = TypeVar("ResponseData")

# cache of per-topic producers, used when SINGLE_PRODUCER is "false"
kafka_producer = {}
# single shared producer, reused across requests when SINGLE_PRODUCER is "true"
single_kafka_producer = None


class Response(BaseModel, Generic[ResponseData]):
success: bool = True
@@ -166,24 +169,48 @@ async def submit_batch(batch: BatchData):
Returns:
Response: Generic response indicating status of request
"""

topic = get_input_topic_name(batch.job_id)
producer = AIOKafkaProducer(
bootstrap_servers=settings.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
)
await producer.start()

try:
for record in batch.data:
await producer.send_and_wait(topic, value=record)
except UnknownTopicOrPartitionError:
await producer.stop()
raise HTTPException(
status_code=500, detail=f"{topic=} for job {batch.job_id} not found"
)
finally:
await producer.stop()
    global kafka_producer
    # .get() avoids a KeyError when the variable is unset
    if os.environ.get("SINGLE_PRODUCER") == "true":
        logger.warning("SINGLE PRODUCER RUN")
        print("SINGLE PRODUCER RUN", flush=True)
        global single_kafka_producer
        if not single_kafka_producer:
            logger.warning("creating new single_kafka_producer")
            single_kafka_producer = AIOKafkaProducer(
                bootstrap_servers=settings.kafka_bootstrap_servers,
                value_serializer=lambda v: json.dumps(v).encode("utf-8"),
            )
            await single_kafka_producer.start()
        import asyncio
        try:
            for record in batch.data:
                await single_kafka_producer.send_and_wait(topic, value=record)
                # asyncio.sleep yields to the event loop; time.sleep(0.1) here
                # would block every other coroutine in the process
                await asyncio.sleep(0.1)
        except UnknownTopicOrPartitionError:
            raise HTTPException(
                status_code=500, detail=f"{topic=} for job {batch.job_id} not found"
            )
    else:
        logger.warning("MULTIPLE PRODUCER RUN")
        kafka_producer_for_topic = kafka_producer.get(topic)
        if not kafka_producer_for_topic:
            logger.warning("creating new kafka_producer_for_topic")
            kafka_producer_for_topic = AIOKafkaProducer(
                bootstrap_servers=settings.kafka_bootstrap_servers,
                value_serializer=lambda v: json.dumps(v).encode("utf-8"),
            )
            kafka_producer[topic] = kafka_producer_for_topic
            await kafka_producer_for_topic.start()
        import asyncio
        try:
            for record in batch.data:
                await kafka_producer_for_topic.send_and_wait(topic, value=record)
                await asyncio.sleep(0.1)
        except UnknownTopicOrPartitionError:
            raise HTTPException(
                status_code=500, detail=f"{topic=} for job {batch.job_id} not found"
            )

return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))

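Review note: the two branches above duplicate the lazy create-and-start logic, and the cached producers are never stopped on shutdown. A minimal sketch of one way to factor this out — the helper name `_get_producer` and the shutdown hook are hypothetical, not part of this diff:

```python
# Sketch only — assumes the same module-level caches and settings as above.
async def _get_producer(topic: str) -> AIOKafkaProducer:
    global single_kafka_producer
    if os.environ.get("SINGLE_PRODUCER") == "true":
        if single_kafka_producer is None:
            single_kafka_producer = AIOKafkaProducer(
                bootstrap_servers=settings.kafka_bootstrap_servers,
                value_serializer=lambda v: json.dumps(v).encode("utf-8"),
            )
            await single_kafka_producer.start()
        return single_kafka_producer
    producer = kafka_producer.get(topic)
    if producer is None:
        producer = AIOKafkaProducer(
            bootstrap_servers=settings.kafka_bootstrap_servers,
            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
        )
        kafka_producer[topic] = producer
        await producer.start()
    return producer


@app.on_event("shutdown")
async def _stop_producers():
    # stop every cached producer so buffered messages are flushed on exit
    for producer in list(kafka_producer.values()):
        await producer.stop()
    if single_kafka_producer is not None:
        await single_kafka_producer.stop()
```

`submit_batch` would then reduce to `producer = await _get_producer(topic)` followed by the send loop.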
3 changes: 3 additions & 0 deletions server/handlers/result_handlers.py
@@ -178,9 +178,12 @@ def write_header(self):

def __call__(self, result_batch: list[LSEBatchItem]):
logger.debug(f"\n\nHandler received batch: {result_batch}\n\n")
        for rec in result_batch:
            print(f"results handler, {rec}", flush=True)

# coerce dicts to LSEBatchItems for validation
result_batch = [LSEBatchItem(**record) for record in result_batch]


# open and write to file
with open(self.output_path, "a") as f:
1 change: 1 addition & 0 deletions server/utils.py
@@ -16,6 +16,7 @@ class Settings(BaseSettings):
kafka_input_consumer_timeout_ms: int = 1500 # 1.5 seconds
kafka_output_consumer_timeout_ms: int = 1500 # 1.5 seconds
task_time_limit_sec: int = 60 * 60 * 6 # 6 hours
single_kafka_producer: bool = True

model_config = SettingsConfigDict(
# have to use an absolute path here so celery workers can find it
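Note that `submit_batch` gates on the raw `SINGLE_PRODUCER` environment variable while this diff adds a `single_kafka_producer` setting, so the two flag names are currently out of sync. A sketch of how the flag would be read through the settings object instead (assuming pydantic-settings' default env-name mapping, where the field maps to `SINGLE_KAFKA_PRODUCER`):

```python
from server.utils import Settings

settings = Settings()
# pydantic-settings populates the field from the SINGLE_KAFKA_PRODUCER
# environment variable (case-insensitive), defaulting to True
if settings.single_kafka_producer:
    ...  # reuse the shared producer
```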
25 changes: 25 additions & 0 deletions tests/conftest.py
@@ -7,6 +7,9 @@
from fakeredis import FakeStrictRedis
from fastapi.testclient import TestClient
from server.app import _get_redis_conn
from server.utils import Settings
import os



@pytest.fixture(scope="module")
@@ -50,8 +53,21 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_server)




@pytest.fixture
def client():
os.environ['SINGLE_PRODUCER'] = 'true'
from server.app import app

with TestClient(app) as client:
yield client


@pytest.fixture
def multiclient():
os.environ['SINGLE_PRODUCER'] = 'false'

from server.app import app

with TestClient(app) as client:
@@ -61,7 +77,16 @@ def client():
@pytest_asyncio.fixture
async def async_client():
from server.app import app
os.environ['SINGLE_PRODUCER'] = 'true'
async with httpx.AsyncClient(
timeout=10, app=app, base_url="http://localhost:30001"
) as client:
yield client

@pytest_asyncio.fixture
async def multi_async_client():
    # mirror the `multiclient` fixture: exercise the per-topic producer path
    os.environ['SINGLE_PRODUCER'] = 'false'
    from server.app import app
async with httpx.AsyncClient(
timeout=10, app=app, base_url="http://localhost:30001"
) as client:
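These fixtures mutate `os.environ` for the whole process, so a flag set by one test leaks into the next. A hedged alternative (a sketch, not part of this diff) would use pytest's built-in `monkeypatch` fixture, which restores the variable at teardown:

```python
import pytest
from fastapi.testclient import TestClient


@pytest.fixture
def multiclient(monkeypatch):
    # setenv is undone automatically when the test finishes
    monkeypatch.setenv("SINGLE_PRODUCER", "false")
    from server.app import app

    with TestClient(app) as client:
        yield client
```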
1 change: 1 addition & 0 deletions tests/jsonrecords.json

Large diffs are not rendered by default.

188 changes: 187 additions & 1 deletion tests/test_server.py
@@ -55,6 +55,47 @@
"result_handler": {"type": "DummyHandler"},
}

SUBMIT_PAYLOAD_HUMOR = {
"agent": {
"environment": {
"type": "AsyncKafkaEnvironment",
"kafka_bootstrap_servers": "",
"kafka_input_topic": "",
"kafka_output_topic": "",
"timeout_ms": 1,
},
"skills": [
{
"type": "ClassificationSkill",
"name": "text_classifier",
"instructions": "",
"input_template": "{text}",
"output_template": "{output}",
"labels": {
"output": [
"humor",
"not humor",
]
},
}
],
"runtimes": {
"default": {
"type": "AsyncOpenAIChatRuntime",
"model": "gpt-3.5-turbo-0125",
"api_key": OPENAI_API_KEY,
"max_tokens": 10,
"temperature": 0,
"concurrent_clients": 100,
"batch_size": 100,
"timeout": 10,
"verbose": False,
}
},
},
"result_handler": {"type": "DummyHandler"},
}


async def arun_job(
client: httpx.AsyncClient,
@@ -98,7 +139,7 @@ async def arun_job_and_get_output(
client: httpx.AsyncClient,
streaming_payload_agent: dict,
batch_payload_datas: list[list[dict]],
timeout_sec=10,
    timeout_sec=60 * 60 * 2,  # 2 hours
poll_interval_sec=1,
) -> pd.DataFrame:

@@ -202,6 +243,75 @@ def test_streaming(client):
output["output"] == expected_output
).all(), "adala did not return expected output"

@pytest.mark.use_openai
@pytest.mark.use_server
def test_streaming_10000(multiclient):
    client = multiclient
    import json

    # load the 10k-record fixture as a list of dicts
    with open("jsonrecords.json") as f:
        datarecord = json.load(f)

    data = pd.DataFrame.from_records(datarecord)
batch_data = data.drop("output", axis=1).to_dict(orient="records")
expected_output = data.set_index("task_id")["output"]

with NamedTemporaryFile(mode="r") as f:

print("filename", f.name, flush=True)

SUBMIT_PAYLOAD_HUMOR["result_handler"] = {
"type": "CSVHandler",
"output_path": f.name,
}

resp = client.post("/jobs/submit-streaming", json=SUBMIT_PAYLOAD_HUMOR)
resp.raise_for_status()
job_id = resp.json()["data"]["job_id"]
        # submit the 10,000 records as ten batches of 1,000
        for start in range(0, len(batch_data), 1000):
            batch_payload = {
                "job_id": job_id,
                "data": batch_data[start : start + 1000],
            }
            resp = client.post("/jobs/submit-batch", json=batch_payload)
            resp.raise_for_status()

        timeout_sec = 60 * 60 * 3  # 3 hours
poll_interval_sec = 1
terminal_statuses = ["Completed", "Failed", "Canceled"]
for _ in range(int(timeout_sec / poll_interval_sec)):
resp = client.get(f"/jobs/{job_id}")
status = resp.json()["data"]["status"]
if status in terminal_statuses:
print("terminal polling ", status, flush=True)
break
print("polling ", status, flush=True)
time.sleep(poll_interval_sec)
assert status == "Completed", status

output = pd.read_csv(f.name).set_index("task_id")
print(f"dataframe length, {len(output.index)}")
output.to_json('outputresult.json', orient='records', lines=True)
assert not output["error"].any(), "adala returned errors"
assert (
output["output"] == expected_output
).all(), "adala did not return expected output"


@pytest.mark.use_openai
@pytest.mark.use_server
Expand Down Expand Up @@ -239,6 +349,82 @@ async def test_streaming_n_concurrent_requests(async_client):
output["output"] == expected_output
).all(), "adala did not return expected output"

@pytest.mark.use_openai
@pytest.mark.use_server
@pytest.mark.asyncio
async def test_streaming_2_concurrent_requests_100000_single_producer(async_client):
client = async_client

# TODO test with n_requests > number of celery workers
n_requests = 2

    import json

    # load the 10k-record fixture as a list of dicts
    with open("jsonrecords.json") as f:
        datarecord = json.load(f)

    data = pd.DataFrame.from_records(datarecord)
    batch_payload_data = data.drop("output", axis=1).to_dict(orient="records")
    # split into ten batches of 1,000 records each
    batch_payload_datas = [
        batch_payload_data[i : i + 1000] for i in range(0, 10_000, 1000)
    ]
expected_output = data.set_index("task_id")["output"]

outputs = await asyncio.gather(
*[
arun_job_and_get_output(
client, SUBMIT_PAYLOAD["agent"], batch_payload_datas
)
for _ in range(n_requests)
]
)

for output in outputs:
assert not output["error"].any(), "adala returned errors"
assert (
output["output"] == expected_output
).all(), "adala did not return expected output"


@pytest.mark.use_openai
@pytest.mark.use_server
@pytest.mark.asyncio
async def test_streaming_2_concurrent_requests_100000(multi_async_client):
client = multi_async_client

# TODO test with n_requests > number of celery workers
n_requests = 2

    import json

    # load the 10k-record fixture as a list of dicts
    with open("jsonrecords.json") as f:
        datarecord = json.load(f)

    data = pd.DataFrame.from_records(datarecord)
    batch_payload_data = data.drop("output", axis=1).to_dict(orient="records")
    # split into ten batches of 1,000 records each
    batch_payload_datas = [
        batch_payload_data[i : i + 1000] for i in range(0, 10_000, 1000)
    ]
expected_output = data.set_index("task_id")["output"]

outputs = await asyncio.gather(
*[
arun_job_and_get_output(
client, SUBMIT_PAYLOAD["agent"], batch_payload_datas
)
for _ in range(n_requests)
]
)

for output in outputs:
assert not output["error"].any(), "adala returned errors"
assert (
output["output"] == expected_output
).all(), "adala did not return expected output"



@pytest.mark.use_openai
@pytest.mark.use_server
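The three new tests repeat the fixture-loading and batch-slicing boilerplate; a small helper along these lines (the name `load_fixture_batches` is hypothetical, not in this diff) would remove the duplication:

```python
import json

import pandas as pd


def load_fixture_batches(path="jsonrecords.json", batch_size=1000):
    """Load the JSON fixture and split it into equal submit-batch payloads."""
    with open(path) as f:
        data = pd.DataFrame.from_records(json.load(f))
    records = data.drop("output", axis=1).to_dict(orient="records")
    batches = [
        records[i : i + batch_size] for i in range(0, len(records), batch_size)
    ]
    expected_output = data.set_index("task_id")["output"]
    return batches, expected_output
```

Each test would then open with `batch_payload_datas, expected_output = load_fixture_batches()` before submitting.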