Skip to content

Commit

Permalink
Support LocalEnv (#53)
Browse files Browse the repository at this point in the history
* Support LocalEnv

* resolve issues

* resolve issues

* resolve issues

* resolve issues

* resolve issues

* resolve issues
  • Loading branch information
SunsetWolf authored Jul 9, 2024
1 parent c85d11c commit 14de9cf
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 9 deletions.
48 changes: 42 additions & 6 deletions rdagent/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
"""
import os
from abc import abstractmethod
from pathlib import Path
from typing import Generic, TypeVar

import sys
import docker
import subprocess
from abc import abstractmethod
from pydantic import BaseModel
from typing import Generic, TypeVar, Optional, Dict
from pathlib import Path

ASpecificBaseModel = TypeVar("ASpecificBaseModel", bound=BaseModel)

Expand Down Expand Up @@ -62,15 +63,50 @@ def run(self, entry: str | None, local_path: str | None = None, env: dict | None


class LocalConf(BaseModel):
py_entry: str # where you can find your python path
py_bin: str
default_entry: str


class LocalEnv(Env[LocalConf]):
"""
Sometimes local environment may be more convinient for testing
"""
def prepare(self):
if not (Path("~/.qlib/qlib_data/cn_data").expanduser().resolve().exists()):
self.run(
entry="python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn",
)
else:
print("Data already exists. Download skipped.")

def run(self,
entry: str | None = None,
local_path: Optional[str] = None,
env: dict | None = None) -> str:
if env is None:
env = {}

if entry is None:
entry = self.conf.default_entry

command = str(Path(self.conf.py_bin).joinpath(entry)).split(" ")

cwd = None
if local_path:
cwd = Path(local_path).resolve()
print(command)
result = subprocess.run(
command,
cwd=cwd,
env={**os.environ, **env},
capture_output=True,
text=True
)

if result.returncode != 0:
raise RuntimeError(f"Error while running the command: {result.stderr}")

conf: LocalConf
return result.stdout


## Docker Environment -----
Expand Down
18 changes: 15 additions & 3 deletions test/utils/test_env.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import subprocess
import sys
import unittest
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))
from rdagent.utils.env import QTDockerEnv, LocalEnv, LocalConf
import shutil

from rdagent.utils.env import QTDockerEnv

DIRNAME = Path(__file__).absolute().resolve().parent

Expand All @@ -23,6 +21,20 @@ def tearDown(self):
# shutil.rmtree(mlrun_p)
...

# NOTE: Since I don't know the exact environment in which it will be used, here's just an example.
# NOTE: Because you need to download the data during the prepare process. So you need to have pyqlib in your environment.
# def test_local(self):
# local_conf = LocalConf(
# py_bin="/home/v-linlanglv/miniconda3/envs/RD-Agent-310/bin",
# default_entry="qrun conf.yaml",
# )
# qle = LocalEnv(conf=local_conf)
# qle.prepare()
# conf_path = str(DIRNAME / "env_tpl" / "conf.yaml")
# qle.run(entry="qrun " + conf_path)
# mlrun_p = DIRNAME / "env_tpl" / "mlruns"
# self.assertTrue(mlrun_p.exists(), f"Expected output file {mlrun_p} not found")

def test_docker(self):
"""
We will mount `env_tpl` into the docker image.
Expand Down

0 comments on commit 14de9cf

Please sign in to comment.