Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BC] Refactor _KEEP_TEMPS for reusability #376

Merged
merged 20 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ def __exit__(self, exc, value, tb):
pass


def get_workdir_context():
"""Return a context which manages how the temperory directories are handled.

When the flag keep_temps is specified temporary directories are stored in
keep_temps."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The """ string terminator should be on a separate line.

if _KEEP_TEMPS.value is not None:
tempdir_context = NonTemporaryDirectory(dir=_KEEP_TEMPS.value)
else:
tempdir_context = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
return tempdir_context


def _overwrite_trajectory_reward(sequence_example: tf.train.SequenceExample,
reward: float) -> tf.train.SequenceExample:
"""Overwrite the reward in the trace (sequence_example) with the given one.
Expand Down Expand Up @@ -401,10 +413,7 @@ def collect_data(self,
compilation_runner.ProcessKilledException is passed through.
ValueError if example under default policy and ml policy does not match.
"""
if _KEEP_TEMPS.present:
tempdir_context = NonTemporaryDirectory(dir=_KEEP_TEMPS.value)
else:
tempdir_context = tempfile.TemporaryDirectory()
tempdir_context = get_workdir_context()

with tempdir_context as tempdir:
final_cmd_line = loaded_module_spec.build_command_line(tempdir)
Expand Down
17 changes: 13 additions & 4 deletions compiler_opt/rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import contextlib
import io
import os
import tempfile
from typing import Callable, Generator, List, Optional, Tuple, Type

import numpy as np

from compiler_opt.rl import corpus
from compiler_opt.rl import log_reader
from compiler_opt.rl import compilation_runner


class StepType(Enum):
Expand All @@ -47,6 +47,7 @@ class TimeStep:
score_default: Optional[dict[str, float]]
context: Optional[str]
module_name: str
working_dir: str
obs_id: Optional[int]
step_type: StepType

Expand Down Expand Up @@ -115,10 +116,12 @@ class ClangProcess:
"""

def __init__(self, proc: subprocess.Popen,
get_scores_fn: Callable[[], dict[str, float]], module_name):
get_scores_fn: Callable[[], dict[str, float]], module_name: str,
working_dir: str):
self._proc = proc
self._get_scores_fn = get_scores_fn
self._module_name = module_name
self._working_dir = working_dir

def get_scores(self, timeout: Optional[int] = None):
self._proc.wait(timeout=timeout)
Expand All @@ -133,10 +136,11 @@ def __init__(
proc: subprocess.Popen,
get_scores_fn: Callable[[], dict[str, float]],
module_name: str,
working_dir: str,
reader_pipe: io.BufferedReader,
writer_pipe: io.BufferedWriter,
):
super().__init__(proc, get_scores_fn, module_name)
super().__init__(proc, get_scores_fn, module_name, working_dir)
self._reader_pipe = reader_pipe
self._writer_pipe = writer_pipe
self._obs_gen = log_reader.read_log_from_file(self._reader_pipe)
Expand All @@ -150,6 +154,7 @@ def __init__(
score_default=None,
context=None,
module_name=module_name,
working_dir=working_dir,
obs_id=None,
step_type=StepType.LAST,
)
Expand Down Expand Up @@ -180,6 +185,7 @@ def _get_step_type() -> StepType:
score_default=None,
context=obs.context,
module_name=self._module_name,
working_dir=self._working_dir,
obs_id=obs.observation_id,
step_type=_get_step_type(),
)
Expand Down Expand Up @@ -235,7 +241,8 @@ def clang_session(
Yields:
Either the constructed InteractiveClang or DefaultClang object.
"""
with tempfile.TemporaryDirectory() as td:
tempdir_context = compilation_runner.get_workdir_context()
with tempdir_context as td:
task_working_dir = os.path.join(td, '__task_working_dir__')
os.mkdir(task_working_dir)
task = task_type()
Expand Down Expand Up @@ -264,6 +271,7 @@ def _get_scores() -> dict[str, float]:
proc,
_get_scores,
module.name,
task_working_dir,
reader_pipe,
writer_pipe,
)
Expand All @@ -272,6 +280,7 @@ def _get_scores() -> dict[str, float]:
proc,
_get_scores,
module.name,
task_working_dir,
)

finally:
Expand Down
27 changes: 27 additions & 0 deletions compiler_opt/rl/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import ctypes
from unittest import mock
import subprocess
import os
import tempfile
from absl.testing import flagsaver

from typing import Dict, List, Optional

Expand Down Expand Up @@ -161,6 +164,30 @@ def test_interactive_clang_session(self, mock_popen):
self.assertEqual(obs.context, f'context_{idx}')
mock_popen.assert_called_once()

@mock.patch('subprocess.Popen')
def test_interactive_clang_temp_dir(self, mock_popen):
mock_popen.side_effect = mock_interactive_clang
working_dir = None

with env.clang_session(
_CLANG_PATH, _MOCK_MODULE, MockTask, interactive=True) as clang_session:
for _ in range(_NUM_STEPS):
obs = clang_session.get_observation()
working_dir = obs.working_dir
self.assertEqual(os.path.exists(working_dir), True)
self.assertEqual(os.path.exists(working_dir), False)

with tempfile.TemporaryDirectory() as td:
with flagsaver.flagsaver((env.compilation_runner._KEEP_TEMPS, td)): # pylint: disable=protected-access
with env.clang_session(
_CLANG_PATH, _MOCK_MODULE, MockTask,
interactive=True) as clang_session:
for _ in range(_NUM_STEPS):
obs = clang_session.get_observation()
working_dir = obs.working_dir
self.assertEqual(os.path.exists(working_dir), True)
self.assertEqual(os.path.exists(working_dir), True)


class MLGOEnvironmentTest(tf.test.TestCase):

Expand Down