-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
85 lines (67 loc) · 2.54 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pytest
import time
import os
import sys
import shutil
import subprocess
import numpy as np
import lib
from loguru import logger
from pathlib import Path
def test_datasets_dtype():
    """Check that every dataset under ``lib.DATA_DIR`` stores arrays with the expected dtypes.

    For each dataset directory:

    * ``X_num.npy`` and ``X_bin.npy`` (when present) must be float32;
    * ``X_cat.npy`` (when present) must be int64;
    * ``Y.npy`` must be int64 for binary/multiclass tasks and float32 otherwise,
      where the task type is read from the directory's ``info.json``.
    """
    expected_feature_dtypes = {
        'X_num.npy': np.float32,
        'X_bin.npy': np.float32,
        'X_cat.npy': np.int64,
    }
    for d in lib.DATA_DIR.iterdir():
        # Only directories are datasets; a stray file would otherwise crash
        # on the info.json lookup below.
        if not d.is_dir():
            continue
        for filename, expected in expected_feature_dtypes.items():
            path = d / filename
            if path.exists():
                # mmap_mode='r' parses only the .npy header instead of reading
                # the whole (possibly large) array into memory.
                actual = np.load(path, mmap_mode='r').dtype
                assert actual == expected, f'{path}: dtype {actual}, expected {np.dtype(expected)}'
        info = lib.load_json(d / 'info.json')
        task_type = lib.TaskType(info['task_type'])
        if task_type in (lib.TaskType.BINCLASS, lib.TaskType.MULTICLASS):
            expected_y = np.int64
        else:
            expected_y = np.float32
        y_path = d / 'Y.npy'
        actual_y = np.load(y_path, mmap_mode='r').dtype
        assert actual_y == expected_y, f'{y_path}: dtype {actual_y}, expected {np.dtype(expected_y)}'
def _wait_for_marker(process, marker):
    """Read *process* stdout line by line until a line contains *marker* (case-insensitive).

    Returns ``(found, output)``: ``found`` is True if the marker was seen before
    EOF, and ``output`` is all text read so far (kept for failure messages).
    Stops at EOF instead of spinning forever when the child exits early.
    """
    assert process.stdout is not None
    marker = marker.lower()
    lines = []
    # readline() returns b'' only at EOF, i.e. once the child closed its output.
    for raw in iter(process.stdout.readline, b''):
        line = raw.decode(errors='replace')
        lines.append(line)
        if marker in line.lower():
            return True, ''.join(lines)
    return False, ''.join(lines)


def test_all_runs_start_successfull(tmp_path: Path):
    """Smoke-test that every tuning config under ``lib.EXP_DIR`` starts training.

    For each ``tuning.toml`` (except those of a few skipped datasets), copy it
    into ``tmp_path``, run ``bin/go.py`` on the copy with ``--force``, and wait
    until the child's output contains a marker line showing training has begun.
    The child is then killed. The test fails, showing the child's output, if
    the child exits before printing the marker.
    """
    # Presumably puts the per-config log lines on a fresh line under `pytest -s`.
    print()
    skipped_datasets = {
        'cooking-time',
        'delivery-eta',
        'homesite-insurance',
        'maps-routing',
        'weather',
    }
    # All algorithms are using cuda devices. Pin the children to GPU 0 without
    # mutating this process's environment for the rest of the test session.
    child_env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}
    for tuning_config in lib.EXP_DIR.glob('**/tuning.toml'):
        if tuning_config.parent.name in skipped_datasets:
            continue
        # if tuning_config.parent.parent.name in [
        #     "xgboost_",
        #     "catboost_",
        #     "lightgbm_",
        #     "mlp",
        #     "mlp-plr",
        #     "resnet",
        #     "snn",
        #     "dcn2",
        #     "ft_transformer",
        # ]:
        #     continue
        config_copy = tmp_path / tuning_config.name
        shutil.copy(tuning_config, config_copy)
        if tuning_config.parent.parent.name in ('xgboost_', 'catboost_', 'lightgbm_'):
            # All boostings start training with this
            wait_on = 'training...'
        else:
            # This appears in nn logs
            wait_on = 'new best epoch!'
        # stderr is merged into stdout: a separate stderr pipe that is never
        # drained can fill up and block the child indefinitely. The argv list
        # (not a split string) survives paths with spaces, and sys.executable
        # runs the child under the same interpreter as pytest.
        with subprocess.Popen(
            [sys.executable, 'bin/go.py', str(config_copy), '--force'],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=child_env,
        ) as process:
            found, output = _wait_for_marker(process, wait_on)
            if found:
                logger.info(f'{tuning_config.relative_to(lib.EXP_DIR)} OK, killing')
                process.kill()
        # Leaving the `with` closes the pipe and waits, so the child is reaped
        # and process.returncode holds its real exit status.
        assert found, (
            f'{tuning_config.relative_to(lib.EXP_DIR)} fails '
            f'(exit code {process.returncode}):\n{output}'
        )