Skip to content

Commit

Permalink
clean up typing in eval and handle NoneType edge case for headers
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzhang13 committed Feb 1, 2025
1 parent 104c4b6 commit 12b3a0e
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions src/discord-cluster-manager/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import subprocess
import time
from pathlib import Path
from typing import Optional

from consts import CUDA_FLAGS, ExitCode

Expand Down Expand Up @@ -55,31 +56,34 @@ def _limit_length(text: str, max_len: int = 16384):
for i, line in enumerate(lines):
size += len(line) + 1
if size + 100 > max_len:
lines = lines[:i] + [f"[...] {len(lines)-i} lines omitted"]
lines = lines[:i] + [f"[...] {len(lines) - i} lines omitted"]
return "\n".join(lines)
return text


def _create_files(files: dict[str, str]):
def _create_files(files: Optional[dict[str, str]]):
"""
Create text files
Args:
files: A dictionary mapping file names to their contents.
Raises:
AssertionError, if the file is not within the current working directory.
"""
if files is None:
return

for name, content in files.items():
assert Path(name).resolve().is_relative_to(Path.cwd())
Path(name).write_text(content)


def compile_cuda_script( # # noqa: C901
files: list[str],
arch: int = None,
include_dirs: list[str] = None,
defines: dict[str, str] = None,
libraries: list[str] = None,
flags: list[str] = None,
arch: Optional[int] = None,
include_dirs: Optional[list[str]] = None,
defines: Optional[dict[str, str]] = None,
libraries: Optional[list[str]] = None,
flags: Optional[list[str]] = None,
verbose: bool = False,
) -> CompileResult:
"""
Expand Down Expand Up @@ -227,12 +231,12 @@ def run_program(args: list[str], seed: int) -> RunResult:

def run_cuda_script( # # noqa: C901
sources: dict[str, str],
headers: dict[str, str] = None,
arch: int = None,
include_dirs: list[str] = None,
defines: dict[str, str] = None,
libraries: list[str] = None,
flags: list[str] = None,
headers: Optional[dict[str, str]] = None,
arch: Optional[int] = None,
defines: Optional[dict[str, str]] = None,
include_dirs: Optional[list[str]] = None,
libraries: Optional[list[str]] = None,
flags: Optional[list[str]] = None,
seed: int = 42,
) -> tuple[CompileResult, RunResult]:
"""
Expand Down Expand Up @@ -283,7 +287,7 @@ def run_cuda_script( # # noqa: C901
# cleaning up all source files _before_ we let the user code run, just in
# case there's something in there that the user isn't supposed to snoop
finally:
tmp_files = list(sources.keys()) + list(headers.keys())
tmp_files = list(sources.keys()) + list((headers or {}).keys())
for f in tmp_files:
if os.path.exists(f):
os.remove(f)
Expand Down

0 comments on commit 12b3a0e

Please sign in to comment.