Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Live Mode #17

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,13 @@ RESERVATIONS_GPU='["r_gpu1", "r_gpu2"]'
MAX_TIME_GPU="1:00:00"
SUBMISSION_SSH_KEY="~/.ssh/id_rsa"
FORWARD_PORTS='["8888:8888"]'

# Simulator: path of the trained model passed to the PCA step as load_model_path
TRAINED_MODEL_URI="/path/to/trained/model"

# WebSocket server
# WEBSOCKET_URL is a host name or IP address (not a full ws:// URL); the
# simulator also uses it as the RabbitMQ host.
WEBSOCKET_PORT=8765
WEBSOCKET_URL="127.0.0.1"

# Publisher: Python file run in the mlex_rabbitmq_publisher conda environment
PUBLISHER_PYTHON_FILE="path/to/publisher.py"
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ __pycache__/
*$py.class
test.py
**cache**
mlex_store/*
*.pkl

# output dir
results/
Expand Down
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"dash==2.9.3",
"dash-core-components==2.0.0",
"dash-bootstrap-components==1.0.2",
"dash-extensions==0.0.52",
"dash-html-components==2.0.0",
"dash-iconify==0.1.2",
"plotly==5.14.1",
Expand All @@ -30,7 +31,8 @@ dependencies = [
"pandas",
"numpy",
"python-dotenv",
"prefect-client==2.14.21"
"prefect-client==2.14.21",
"tiled[client]==0.1.0a118",
]

[project.optional-dependencies]
Expand All @@ -41,7 +43,13 @@ dev = [
"flake8",
"pre-commit",
"pytest-mock",
"tiled[all]",
"tiled[all]==0.1.0a118",
]

# Extras needed to run the scripts under simulator/.
# simulator/websocket.py and simulator/websocket_listener.py import the
# "websockets" package; "websocket" is a different distribution.
simulator = [
    "websockets",
    "pika",
    "aio_pika",
]

[project.urls]
Expand All @@ -50,3 +58,8 @@ Issues = "https://github.com/mlexchange/mlex_latent_explorer/issues/"

[tool.isort]
profile = "black"

# pytest configuration: add "src" to the import path used during test runs.
[tool.pytest.ini_options]
pythonpath = ["src"]
Empty file added simulator/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions simulator/data_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
import os
import time
from datetime import datetime

import pytz
from dotenv import load_dotenv
from tiled.client import from_uri

from src.utils_prefect import schedule_prefect_flow

load_dotenv(".env")

# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Input data (Tiled) ------------------------------------------------------
DATA_TILED_URI = os.environ.get("DEFAULT_TILED_URI")
# NOTE(review): the data API key is read from RESULT_TILED_API_KEY, the same
# variable used for the result server below - confirm this is intended.
DATA_TILED_API_KEY = os.environ.get("RESULT_TILED_API_KEY")

# --- Results (Tiled + local output) ------------------------------------------
RESULT_TILED_URI = os.environ.get("RESULT_TILED_URI")
RESULT_TILED_API_KEY = os.environ.get("RESULT_TILED_API_KEY")
WRITE_DIR = os.environ.get("WRITE_DIR")

# --- Prefect scheduling ------------------------------------------------------
FLOW_NAME = os.environ.get("FLOW_NAME", "")
PREFECT_TAGS = ["latent-space-explorer-live"]
# Timezone name used to timestamp flow-run names; defaults to UTC.
TIMEZONE = os.environ.get("TIMEZONE", "UTC")

# --- Job inputs ----------------------------------------------------------------
TRAINED_MODEL_URI = os.environ.get("TRAINED_MODEL_URI")
PUBLISHER_PYTHON_FILE = os.environ.get("PUBLISHER_PYTHON_FILE")

# Step 1: PCA dimension reduction, reading from and writing to Tiled.
# "data_uris" is left empty here; the script entry point fills it in for each
# data item before scheduling the flow.
_pca_step = {
    "conda_env_name": "mlex_dimension_reduction_pca",
    "python_file_name": "mlex_dimension_reduction_pca/pca_run.py",
    "params": {
        "io_parameters": {
            "uid_retrieve": "",
            "data_uris": [],
            "data_tiled_api_key": DATA_TILED_API_KEY,
            "data_type": "tiled",
            "root_uri": DATA_TILED_URI,
            "output_dir": f"{WRITE_DIR}/feature_vectors",
            "result_tiled_uri": RESULT_TILED_URI,
            "result_tiled_api_key": RESULT_TILED_API_KEY,
            "load_model_path": TRAINED_MODEL_URI,
        },
        "model_parameters": {
            "n_components": 2,
        },
    },
}

# Step 2: run the publisher script in its own conda environment.
_publisher_step = {
    "conda_env_name": "mlex_rabbitmq_publisher",
    "python_file_name": PUBLISHER_PYTHON_FILE,
    "params": {"io_parameters": {"uid_retrieve": ""}},
}

# Template for the two-step conda flow submitted for every data item.
flow = {
    "flow_type": "conda",
    "params_list": [_pca_step, _publisher_step],
}


def get_data_list(tiled_uri, tiled_api_key=None, limit=10):
    """Return the first ``limit`` keys found at a Tiled URI.

    Args:
        tiled_uri: URI handed to ``tiled.client.from_uri``.
        tiled_api_key: Optional API key forwarded to the Tiled client.
        limit: Maximum number of keys to return. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        The slice ``client.keys()[0:limit]`` of the opened Tiled client.
    """
    client = from_uri(tiled_uri, api_key=tiled_api_key)
    return client.keys()[0:limit]


if __name__ == "__main__":
    # Imported here because only the script entry point needs it.
    from copy import deepcopy

    data_list = get_data_list(DATA_TILED_URI, DATA_TILED_API_KEY)
    # The timezone does not change between iterations, so resolve it once.
    local_tz = pytz.timezone(TIMEZONE)

    for data_uri in data_list:
        logger.info(f"Sending URI {data_uri} for processing.")

        # Deep copy: the per-run "data_uris" value lives several levels down,
        # so a shallow copy would write into the shared module-level template.
        new_flow = deepcopy(flow)
        new_flow["params_list"][0]["params"]["io_parameters"]["data_uris"] = [
            data_uri
        ]
        current_time = datetime.now(local_tz).strftime("%Y/%m/%d %H:%M:%S")
        job_name = f"Live model training for {data_uri}"
        schedule_prefect_flow(
            FLOW_NAME,
            new_flow,
            flow_run_name=f"{job_name} {current_time}",
            tags=PREFECT_TAGS + ["train"],
        )

        # Pause between submissions.
        # NOTE(review): assumed to sit inside the loop - confirm.
        time.sleep(10)
63 changes: 63 additions & 0 deletions simulator/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import asyncio
import os

import aio_pika
import websockets
from dotenv import load_dotenv

load_dotenv(".env")

# Port the WebSocket server listens on. Environment values are always strings,
# so convert explicitly: the constant is an int whether or not the variable is
# set (previously it was an int by default but a str when configured).
WEBSOCKET_PORT = int(os.getenv("WEBSOCKET_PORT", "8765"))
# Host the WebSocket server binds to; also used as the RabbitMQ host in main().
WEBSOCKET_URL = os.getenv("WEBSOCKET_URL", "localhost")

# Set of connected WebSocket clients. register() adds to it, unregister()
# removes from it, and forward_messages() iterates over it to send messages.
clients = set()


async def register(websocket):
    """Add ``websocket`` to the module-level ``clients`` set."""
    clients.add(websocket)


async def unregister(websocket):
    """Remove ``websocket`` from the module-level ``clients`` set.

    Removing a connection that is not registered is a no-op. This runs as
    cleanup in ``handler``'s ``finally`` block, where raising ``KeyError``
    would hide whatever error ended the connection.
    """
    clients.discard(websocket)


async def handler(websocket, path=None):
    """Keep one client connection registered for as long as it stays open.

    Args:
        websocket: The connection object supplied by the WebSocket server.
        path: Unused. Optional so the handler works both when the server
            calls it with ``(websocket, path)`` and with ``(websocket)`` only.
    """
    await register(websocket)
    try:
        # Incoming messages are read and dropped; the loop only serves to
        # wait until the connection ends.
        async for _ in websocket:
            pass
    finally:
        await unregister(websocket)


async def main():
    """Connect to RabbitMQ, then run the WebSocket server and the forwarder.

    Declares the ``latent_space_explorer`` queue and awaits the WebSocket
    server start-up together with ``forward_messages`` on that queue.
    """
    amqp_url = f"amqp://guest:guest@{WEBSOCKET_URL}/"
    connection = await aio_pika.connect_robust(amqp_url)

    async with connection:
        channel = await connection.channel()
        queue = await channel.declare_queue(
            "latent_space_explorer", auto_delete=True
        )

        # Start serving WebSocket clients and consume the queue concurrently.
        await asyncio.gather(
            websockets.serve(handler, WEBSOCKET_URL, WEBSOCKET_PORT),
            forward_messages(queue),
        )


async def forward_messages(queue):
    """Forward every message from ``queue`` to all connected WebSocket clients.

    Args:
        queue: Queue object providing ``iterator()``; each message yielded by
            it exposes ``process()`` and a bytes ``body``.
    """
    async with queue.iterator() as queue_iter:
        async for message in queue_iter:
            async with message.process():
                # Decode once per message instead of once per client.
                payload = message.body.decode()
                # Iterate over a snapshot: clients can register or unregister
                # while a send is awaited, and changing a set during iteration
                # raises RuntimeError.
                for websocket in tuple(clients):
                    try:
                        await websocket.send(payload)
                    except websockets.ConnectionClosed:
                        # One closed client must not stop delivery to the
                        # others; handler() unregisters it when it finishes.
                        continue


# Start the event loop and run main() until it completes.
# NOTE(review): this call is not behind an `if __name__ == "__main__":` guard,
# so importing this module also runs it - confirm that is intended.
asyncio.run(main())
14 changes: 14 additions & 0 deletions simulator/websocket_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio

import websockets


async def client(uri="ws://localhost:8765"):
    """Connect to a WebSocket server and print every message received.

    Args:
        uri: WebSocket URI to connect to. Defaults to the address that was
            previously hard-coded.
    """
    async with websockets.connect(uri) as websocket:
        async for message in websocket:
            print(f"Received message: {message}")


# Start the event loop and run client() until it completes.
# NOTE(review): this call is not behind an `if __name__ == "__main__":` guard,
# so importing this module also runs it - confirm that is intended.
asyncio.run(client())
Loading
Loading