Skip to content

Commit

Permalink
fix measure
Browse files Browse the repository at this point in the history
  • Loading branch information
zhcm committed Jun 17, 2024
1 parent 560b400 commit 59bd277
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion stereo/modeling/trainer_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def save_ckpt(self, current_epoch):
ckpt_list = glob.glob(os.path.join(self.args.ckpt_dir, 'checkpoint_epoch_*.pth'))
ckpt_list.sort(key=os.path.getmtime)
if len(ckpt_list) >= self.cfgs.TRAINER.MAX_CKPT_SAVE_NUM:
for cur_file_idx in range(0, len(ckpt_list) - self.args.max_ckpt_save_num + 1):
for cur_file_idx in range(0, len(ckpt_list) - self.cfgs.TRAINER.MAX_CKPT_SAVE_NUM + 1):
os.remove(ckpt_list[cur_file_idx])
ckpt_name = os.path.join(self.args.ckpt_dir, 'checkpoint_epoch_%d.pth' % current_epoch)
common_utils.save_checkpoint(self.model, self.optimizer, self.scheduler, self.scaler,
Expand Down
7 changes: 4 additions & 3 deletions tools/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,29 @@
import argparse
import sys
import thop
import numpy as np
from easydict import EasyDict
from tqdm import tqdm

sys.path.insert(0, './')
from stereo.utils import common_utils
from stereo.modeling import build_network
from stereo.modeling import build_trainer


def parse_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--dist_mode', action='store_true', default=False, help='torchrun ddp multi gpu')
parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

args = parser.parse_args()
yaml_config = common_utils.config_loader(args.cfg_file)
cfgs = EasyDict(yaml_config)
args.run_mode = 'measure'
return args, cfgs


def main():
args, cfgs = parse_config()
model = build_network(model_cfg=cfgs.MODEL).cuda()
model = build_trainer(args, cfgs, local_rank=0, global_rank=0, logger=None, tb_writer=None).model

shape = [1, 3, 544, 960]
infer_time(model, shape)
Expand Down

0 comments on commit 59bd277

Please sign in to comment.