def get_workdir_context():
  """Build the context manager used to obtain a compilation working directory.

  The choice depends on the keep_temps flag:
    * flag has a value: a NonTemporaryDirectory is created with that value
      passed as `dir`, so temporary directories are stored under keep_temps.
    * flag is None: a regular tempfile.TemporaryDirectory is used.

  Returns:
    The selected context manager. It is not entered here; callers are expected
    to use it in a `with` statement.
  """
  keep_temps_dir = _KEEP_TEMPS.value
  if keep_temps_dir is None:
    # The caller enters the context, hence the lint suppression.
    return tempfile.TemporaryDirectory()  # pylint: disable=consider-using-with
  return NonTemporaryDirectory(dir=keep_temps_dir)
""" - if _KEEP_TEMPS.present: - tempdir_context = NonTemporaryDirectory(dir=_KEEP_TEMPS.value) - else: - tempdir_context = tempfile.TemporaryDirectory() + tempdir_context = get_workdir_context() with tempdir_context as tempdir: final_cmd_line = loaded_module_spec.build_command_line(tempdir) diff --git a/compiler_opt/rl/env.py b/compiler_opt/rl/env.py index 904fd388..0b40f1b9 100644 --- a/compiler_opt/rl/env.py +++ b/compiler_opt/rl/env.py @@ -24,13 +24,13 @@ import contextlib import io import os -import tempfile from typing import Callable, Generator, List, Optional, Tuple, Type import numpy as np from compiler_opt.rl import corpus from compiler_opt.rl import log_reader +from compiler_opt.rl import compilation_runner class StepType(Enum): @@ -47,6 +47,7 @@ class TimeStep: score_default: Optional[dict[str, float]] context: Optional[str] module_name: str + working_dir: str obs_id: Optional[int] step_type: StepType @@ -115,10 +116,12 @@ class ClangProcess: """ def __init__(self, proc: subprocess.Popen, - get_scores_fn: Callable[[], dict[str, float]], module_name): + get_scores_fn: Callable[[], dict[str, float]], module_name: str, + working_dir: str): self._proc = proc self._get_scores_fn = get_scores_fn self._module_name = module_name + self._working_dir = working_dir def get_scores(self, timeout: Optional[int] = None): self._proc.wait(timeout=timeout) @@ -133,10 +136,11 @@ def __init__( proc: subprocess.Popen, get_scores_fn: Callable[[], dict[str, float]], module_name: str, + working_dir: str, reader_pipe: io.BufferedReader, writer_pipe: io.BufferedWriter, ): - super().__init__(proc, get_scores_fn, module_name) + super().__init__(proc, get_scores_fn, module_name, working_dir) self._reader_pipe = reader_pipe self._writer_pipe = writer_pipe self._obs_gen = log_reader.read_log_from_file(self._reader_pipe) @@ -150,6 +154,7 @@ def __init__( score_default=None, context=None, module_name=module_name, + working_dir=working_dir, obs_id=None, step_type=StepType.LAST, ) @@ 
@mock.patch('subprocess.Popen')
def test_interactive_clang_temp_dir(self, mock_popen):
  """Checks the lifetime of the working directory reported by observations.

  Two scenarios are exercised with an interactive clang session:
    1. keep_temps unset: the working directory exists while the session is
       open and no longer exists once the session has exited.
    2. keep_temps pointing at a fresh temporary directory: the working
       directory still exists after the session has exited.

  Args:
    mock_popen: the mock installed in place of subprocess.Popen.
  """
  mock_popen.side_effect = mock_interactive_clang

  def run_session_and_get_working_dir():
    # Steps through one full interactive session, asserting that the working
    # directory exists at every step; returns the last directory observed.
    working_dir = None
    with env.clang_session(
        _CLANG_PATH, _MOCK_MODULE, MockTask,
        interactive=True) as clang_session:
      for _ in range(_NUM_STEPS):
        working_dir = clang_session.get_observation().working_dir
        self.assertTrue(os.path.exists(working_dir))
    return working_dir

  # Scenario 1: no keep_temps, so the directory must be gone after the session.
  working_dir = run_session_and_get_working_dir()
  self.assertFalse(os.path.exists(working_dir))

  # Scenario 2: keep_temps set, so the directory must outlive the session.
  # The check happens while `td` still exists.
  with tempfile.TemporaryDirectory() as td:
    with flagsaver.flagsaver((env.compilation_runner._KEEP_TEMPS, td)):  # pylint: disable=protected-access
      working_dir = run_session_and_get_working_dir()
      self.assertTrue(os.path.exists(working_dir))