[FEATURE] Use permanent processes to handle fixed commits in the compilation server (#489)

Introduce permanent processes to handle fixed commits in the compilation
server.

Every high-level worker may create up to 5 permanent processes, each compiling a fixed
version of hidet. This saves the time otherwise spent on `import hidet` in every job
(see the sketch below).

Tested manually:
 - A 6th hidet version stops an old worker process and starts a new one.
 - Compilation time for individual jobs improved.
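
The saving comes from paying the `import hidet` cost once per worker instead of once per job. A minimal sketch of measuring that one-off cost (assumes hidet is installed; the exact figure depends on the machine):

import time

start = time.time()
import hidet  # the first import in a process pays the full package-load cost
print(f'import hidet took {time.time() - start:.2f}s')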
vadiklyutiy authored Jan 20, 2025
1 parent 7f4039b commit 712a9cd
Showing 3 changed files with 122 additions and 53 deletions.
4 changes: 3 additions & 1 deletion apps/compile_server/Dockerfile
@@ -1,6 +1,7 @@
FROM nvidia/cuda:12.2.0-devel-ubuntu22.04

COPY ./run.py /app/run.py
COPY ./requirements.txt /app/requirements.txt
WORKDIR /app

ENV TZ=America/Toronto
@@ -16,7 +17,8 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/* \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& python -m pip install --upgrade pip \
&& python -m pip install filelock requests gunicorn flask cmake
&& python -m pip install filelock requests gunicorn flask cmake \
&& python -m pip install -r ./requirements.txt

EXPOSE 3281

63 changes: 40 additions & 23 deletions apps/compile_server/resources/compilation.py
@@ -1,11 +1,9 @@
from typing import Dict, Any, List, Tuple
from typing import Dict, Any, Tuple
import time
import re
import sys
import os
import traceback
import threading
import requests
import subprocess
import zipfile
import logging
@@ -17,16 +15,32 @@
from hashlib import sha256
from filelock import FileLock

from .compile_worker import CompilationWorkers

'''
The compilation server launches as many Flask applications as there are vCPUs (using gunicorn).
Each Flask application (i.e., each compilation server process) handles requests, so at most
vCPU-many requests are processed at the same time.
Each process maintains a pool of compilation workers (independent processes) with max_workers=5,
where each worker has imported one specific version of hidet.
A job is first dispatched to a worker with the matching hidet version; if no such worker
exists, a new one is created, replacing an existing worker when the pool is full.
Increasing `max_workers` in the `CompilationWorkers` constructor consumes more memory
(kept moderate by fork's copy-on-write) and creates more processes (max_workers * vCPU in total).
Reducing `max_workers` lowers the chance that nearby jobs find a worker that has already
imported their hidet version.
'''
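
# For illustration, a worked instance of the process count above (assuming a
# hypothetical 16-vCPU machine): gunicorn forks 16 server processes, each of
# which may keep up to max_workers=5 resident workers, so at most
# 16 * 5 = 80 worker processes in total, each with one imported hidet version.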

lock = threading.Lock()
logger = logging.Logger(__name__)

pid = os.getpid()
jobs_dir = os.path.join(os.getcwd(), 'jobs')
repos_dir = os.path.join(os.getcwd(), 'repos')
commits_dir = os.path.join(os.getcwd(), 'commits')
results_dir = os.path.join(os.getcwd(), 'results')
JOBS_DIR = os.path.join(os.getcwd(), 'jobs')
REPOS_DIR = os.path.join(os.getcwd(), 'repos')
COMMITS_DIR = os.path.join(os.getcwd(), 'commits')
RESULTS_DIR = os.path.join(os.getcwd(), 'results')

compile_script = os.path.join(os.path.dirname(__file__), 'compile_worker.py')
compilation_workers = CompilationWorkers(max_workers=5)


def should_update(repo_timestamp) -> bool:
@@ -39,10 +53,10 @@ def should_update(repo_timestamp) -> bool:


def clone_github_repo(owner: str, repo: str, version: str) -> str:
repo_dir = os.path.join(repos_dir, "{}_{}".format(owner, repo))
repo_timestamp = os.path.join(repos_dir, "{}_{}_timestamp".format(owner, repo))
repo_dir = os.path.join(REPOS_DIR, "{}_{}".format(owner, repo))
repo_timestamp = os.path.join(REPOS_DIR, "{}_{}_timestamp".format(owner, repo))
os.makedirs(repo_dir, exist_ok=True)
with FileLock(os.path.join(repos_dir, '{}_{}.lock'.format(owner, repo))):
with FileLock(os.path.join(REPOS_DIR, '{}_{}.lock'.format(owner, repo))):
if not os.path.exists(os.path.join(repo_dir, '.git')):
repo = git.Repo.clone_from(
url="https://github.com/{}/{}.git".format(owner, repo),
@@ -76,12 +90,12 @@ def clone_github_repo(owner: str, repo: str, version: str) -> str:
repo.git.checkout(version)
commit_id = repo.head.commit.hexsha

commit_dir = os.path.join(commits_dir, commit_id)
commit_dir = os.path.join(COMMITS_DIR, commit_id)
if os.path.exists(commit_dir):
return commit_id
with FileLock(os.path.join(commits_dir, commit_id + '.lock')):
repo.git.archive(commit_id, format='zip', output=os.path.join(commits_dir, f'{commit_id}.zip'))
with zipfile.ZipFile(os.path.join(commits_dir, f'{commit_id}.zip'), 'r') as zip_ref:
with FileLock(os.path.join(COMMITS_DIR, commit_id + '.lock')):
repo.git.archive(commit_id, format='zip', output=os.path.join(COMMITS_DIR, f'{commit_id}.zip'))
with zipfile.ZipFile(os.path.join(COMMITS_DIR, f'{commit_id}.zip'), 'r') as zip_ref:
os.makedirs(commit_dir, exist_ok=True)
zip_ref.extractall(commit_dir)
# build the hidet
@@ -139,8 +153,8 @@ def post(self):
}

job_id: str = sha256(commit_id.encode() + workload).hexdigest()
job_path = os.path.join(jobs_dir, job_id + '.pickle')
job_response_path = os.path.join(jobs_dir, job_id + '.response')
job_path = os.path.join(JOBS_DIR, job_id + '.pickle')
job_response_path = os.path.join(JOBS_DIR, job_id + '.response')

print('[{}] Received a job: {}'.format(pid, job_id[:16]))

@@ -151,22 +165,25 @@ def post(self):
return pickle.load(f)

# write the job to the disk
job_lock = os.path.join(jobs_dir, job_id + '.lock')
job_lock = os.path.join(JOBS_DIR, job_id + '.lock')
with FileLock(job_lock):
if not os.path.exists(job_path):
with open(job_path, 'wb') as f:
pickle.dump(job, f)

version_path = os.path.join(COMMITS_DIR, commit_id)
with lock: # Only one thread can access the following code at the same time
print('[{}] Start compiling: {}'.format(pid, job_id[:16]))
ret = subprocess.run([sys.executable, compile_script, '--job_id', job_id])
print('[{}] Start compiling: {}'.format(pid, job_id[:16]), flush=True)
start_time = time.time()
compilation_workers.run_and_wait_job(job_id, version_path)
end_time = time.time()

# respond to the client
response_path = os.path.join(jobs_dir, job_id + '.response')
response_path = os.path.join(JOBS_DIR, job_id + '.response')
if not os.path.exists(response_path):
raise RuntimeError('Can not find the response file:\n{}{}'.format(ret.stderr, ret.stdout))
raise RuntimeError('Can not find the response file')
else:
print('[{}] Finish compiling: {}'.format(pid, job_id[:16]))
print(f'[{pid}] Finish compiling: {job_id[:16]} in {end_time - start_time:.2f}s', flush=True)
with open(response_path, 'rb') as f:
response: Tuple[Dict, int] = pickle.load(f)
return response
108 changes: 79 additions & 29 deletions apps/compile_server/resources/compile_worker.py
@@ -1,23 +1,21 @@
from typing import Dict, Any, List, Tuple, Sequence, Union
from typing import Dict, Any, Sequence, Union
import os
import traceback
import argparse
import sys
import re
import subprocess
import zipfile
import logging
import pickle
import git
from hashlib import sha256
from filelock import FileLock
import multiprocessing
import importlib

logger = logging.Logger(__name__)

jobs_dir = os.path.join(os.getcwd(), 'jobs')
repos_dir = os.path.join(os.getcwd(), 'repos')
commits_dir = os.path.join(os.getcwd(), 'commits')
results_dir = os.path.join(os.getcwd(), 'results')
JOBS_DIR = os.path.join(os.getcwd(), 'jobs')
REPOS_DIR = os.path.join(os.getcwd(), 'repos')
COMMITS_DIR = os.path.join(os.getcwd(), 'commits')
RESULTS_DIR = os.path.join(os.getcwd(), 'results')


def save_response(response, response_file: str):
@@ -27,14 +25,14 @@ def save_response(response, response_file: str):

def compile_job(job_id: str):
try:
job_file = os.path.join(jobs_dir, job_id + '.pickle')
job_file = os.path.join(JOBS_DIR, job_id + '.pickle')
if not os.path.exists(job_file):
# job not found
return 1

job_lock = os.path.join(jobs_dir, job_id + '.lock')
job_lock = os.path.join(JOBS_DIR, job_id + '.lock')
with FileLock(job_lock):
response_file = os.path.join(jobs_dir, job_id + '.response')
response_file = os.path.join(JOBS_DIR, job_id + '.response')
if os.path.exists(response_file):
# job already compiled
return 0
@@ -45,11 +43,8 @@ def compile_job(job_id: str):

# import the hidet from the commit
commit_id: str = job['commit_id']
commit_dir = os.path.join(commits_dir, commit_id)
sys.path.insert(0, os.path.join(commit_dir, 'python'))
import hidet # import the hidet from the commit

# load the workload
import hidet
# load the workload
workload: Dict[str, Any] = pickle.loads(job['workload'])
ir_module: Union[hidet.ir.IRModule, Sequence[hidet.ir.IRModule]] = workload['ir_module']
target: str = workload['target']
@@ -59,10 +54,10 @@
module_string = str(ir_module)
key = module_string + target + output_kind + commit_id
hash_digest: str = sha256(key.encode()).hexdigest()
zip_file_path: str = os.path.join(results_dir, hash_digest + '.zip')
zip_file_path: str = os.path.join(RESULTS_DIR, hash_digest + '.zip')
if not os.path.exists(zip_file_path):
output_dir: str = os.path.join(results_dir, hash_digest)
with FileLock(os.path.join(results_dir, f'{hash_digest}.lock')):
output_dir: str = os.path.join(RESULTS_DIR, hash_digest)
with FileLock(os.path.join(RESULTS_DIR, f'{hash_digest}.lock')):
if not os.path.exists(os.path.join(output_dir, 'lib.so')):
hidet.drivers.build_ir_module(
ir_module,
@@ -88,12 +83,67 @@
return 0


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--job_id', type=str, required=True)
args = parser.parse_args()
exit(compile_job(args.job_id))


if __name__ == '__main__':
main()
# Worker process function to handle compilation jobs using a specific version of the 'hidet' module.
def worker_process(version, job_queue, result_queue, parent_pid):
sys.path.insert(0, os.path.join(version, 'python')) # Ensure the version path is first in sys.path
print(f"[{parent_pid}] Worker loaded hidet version from {version}", flush=True)

while True:
job = job_queue.get()
if job == "STOP":
print(f"[{parent_pid}] Shutting down worker for version: {version}", flush=True)
break

# Compile
job_id = job
print(f"[{parent_pid}] Worker processing job {job_id[:16]} with hidet version {version}", flush=True)
compile_job(job_id)
result_queue.put((job_id, 'DONE'))


class CompilationWorkers:
"""
A class to manage a pool of compilation workers.
It is needed to avoid the overhead of loading the hidet module for every job.
Every worker processes a compilation with a fixed version of hidet (fixed commit hash).
One worker per version.
Only one worker is compiling at the same time. No concurrent compilation.
Concurrency compilation is processed on upper level.
"""
def __init__(self, max_workers: int = 5):
self.max_workers = max_workers
self.workers = {} # {version_path: (worker_process, job_queue)}
self.result_queue = multiprocessing.Queue()

def _get_or_create_worker(self, version_path):
# If a worker for the version exists, return it
if version_path in self.workers:
return self.workers[version_path]

# If the worker pool is full, evict one worker (dict.popitem() removes the most recently inserted entry)
if len(self.workers) >= self.max_workers:
_, (worker, job_queue) = self.workers.popitem()
job_queue.put("STOP") # Send a shutdown signal to the evicted worker
worker.join() # Wait for it to exit

# Create a new worker for the version
job_queue = multiprocessing.Queue()
worker = multiprocessing.Process(target=worker_process,
args=(version_path, job_queue, self.result_queue, os.getpid())
)
worker.start()
self.workers[version_path] = (worker, job_queue)
return self.workers[version_path]


def run_and_wait_job(self, job_id, version_path):
# Run the job and wait until it is finished
_, job_queue = self._get_or_create_worker(version_path)
job_queue.put(job_id)
self.result_queue.get() # multiprocessing.Queue.get() waits until a new item is available

def shutdown(self):
for _, (worker, job_queue) in self.workers.items():
job_queue.put("STOP")
worker.join()
print(f"[{os.getpgid}] All compilation workers are shuted down.", flush=True)
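
A minimal usage sketch of the pool (the job id, commit path, and import path below are hypothetical; `run_and_wait_job` also assumes the job pickle already exists under `jobs/`):

from resources.compile_worker import CompilationWorkers

workers = CompilationWorkers(max_workers=5)
# Dispatches to the worker that has already imported this hidet version, if
# any; blocks until the worker signals completion through the result queue.
workers.run_and_wait_job('0123abcd' * 8, '/app/commits/0123abcd')
workers.shutdown()  # stop all resident worker processes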
