-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
31 lines (25 loc) · 991 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from utils.utils import get_config, get_evaluator, get_guidance, get_network
from pipeline import BasePipeline
import torch
import logger
def main() -> None:
    """Build the sampling pipeline from the config, draw samples, log and evaluate them.

    All settings come from ``get_config()``; see utils/config.py for the
    complete argument list.

    Evaluation is optional: if ``get_evaluator`` raises
    ``NotImplementedError``, samples are still generated and logged, but no
    metrics are computed or written.
    """
    args = get_config()

    ## prepare core modules based on configs ##
    # Unconditional generative model
    network = get_network(args)
    # Guidance method encoded by the prediction model
    guider = get_guidance(args, network)
    # Evaluator for generated samples. get_evaluator may raise
    # NotImplementedError (presumably when the configuration has no
    # evaluator — confirm in utils.utils); treat that as "no evaluation".
    try:
        evaluator = get_evaluator(args)
    except NotImplementedError:
        evaluator = None

    pipeline = BasePipeline(args, network, guider, evaluator)
    samples = pipeline.sample(args.num_samples)
    logger.log_samples(samples)

    # Return cached-but-unused GPU memory held by torch's caching allocator
    # before evaluation (a no-op when CUDA was never initialized).
    torch.cuda.empty_cache()

    # Without an evaluator there is nothing to score; calling
    # evaluator.evaluate() here would raise AttributeError on None.
    if evaluator is None:
        return

    metrics = evaluator.evaluate(samples)
    # evaluate() may return None; only log (and write the JSON file) when
    # there are actual metrics, so an empty result never overwrites it.
    if metrics is not None:
        logger.log_metrics(metrics, save_json=True)


if __name__ == '__main__':
    main()