-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
45 lines (34 loc) · 1.24 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import json
import logging
# Kept for compatibility with older versions of TensorFlow and RDKit.
# If you don't use TensorFlow or kGCN, comment out the import below.
import tensorflow as tf
from rdkit import RDLogger
from lib.calculators import CalculatorFactory
from lib.config import Config
from lib.data_providers import MoleculeLoader
from lib.filters import FilterFactory
from lib.helpers import Sketcher
from lib.models import MonteCarloTreeSearch
# Raise RDKit's logger threshold so only CRITICAL messages get through.
RDLogger.logger().setLevel(RDLogger.CRITICAL)

# Command line: one positional argument, handed to Config.load
# (presumably the path of a configuration file).
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("config", type=str)
cli_args = arg_parser.parse_args()

# Load the run configuration; its `logging` attribute supplies the log level.
config = Config.load(cli_args.config)
logging.basicConfig(format="%(message)s", level=config.logging)

# Assemble the search components from the configuration.
molecule_loader = MoleculeLoader(file_path=config.dataset, threshold=config.threshold)
reward_calculator = CalculatorFactory.create(
    config.reward_calculator,
    config.reward_weights,
    config,
)
# One filter instance per entry in config.filters, built eagerly and in order.
molecule_filters = list(map(FilterFactory.create, config.filters))
model = MonteCarloTreeSearch(
    data_provider=molecule_loader,
    calculator=reward_calculator,
    filters=molecule_filters,
    config=config,
)
# Print every result yielded by the search as 4-space-indented JSON.
# Results that come back as None are skipped without printing anything.
for molecules in model.start(config.generate, config.monte_carlo_iterations):
    if molecules is not None:
        print(json.dumps(molecules, indent=4))