Skip to content

Commit

Permalink
Add cuda check for fp16 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hmorimitsu committed Jan 30, 2024
1 parent 30ec867 commit 466e91b
Showing 1 changed file with 30 additions and 29 deletions.
59 changes: 30 additions & 29 deletions tests/ptlflow/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,35 +100,36 @@ def test_forward() -> None:


def test_forward_fp16() -> None:
model_names = ptlflow.models_dict.keys()
for mname in model_names:
if mname in EXCLUDE_MODELS_FP16:
continue

print(mname)
model_ref = ptlflow.get_model_reference(mname)
parser = model_ref.add_model_specific_args()
args = parser.parse_args([])

if mname in MODEL_ARGS_FP16:
for name, val in MODEL_ARGS_FP16[mname].items():
setattr(args, name, val)

model = ptlflow.get_model(mname, args=args)
model = model.eval()
model = model.half()

s = make_divisible(256, model.output_stride)
num_images = 2
if mname in ["videoflow_bof", "videoflow_mof"]:
num_images = 3
inputs = {"images": torch.rand(1, num_images, 3, s, s)}

if torch.cuda.is_available():
model = model.cuda()
inputs["images"] = inputs["images"].cuda().half()

model(inputs)
if torch.cuda.is_available():
model_names = ptlflow.models_dict.keys()
for mname in model_names:
if mname in EXCLUDE_MODELS_FP16:
continue

print(mname)
model_ref = ptlflow.get_model_reference(mname)
parser = model_ref.add_model_specific_args()
args = parser.parse_args([])

if mname in MODEL_ARGS_FP16:
for name, val in MODEL_ARGS_FP16[mname].items():
setattr(args, name, val)

model = ptlflow.get_model(mname, args=args)
model = model.eval()
model = model.half()

s = make_divisible(256, model.output_stride)
num_images = 2
if mname in ["videoflow_bof", "videoflow_mof"]:
num_images = 3
inputs = {"images": torch.rand(1, num_images, 3, s, s)}

if torch.cuda.is_available():
model = model.cuda()
inputs["images"] = inputs["images"].cuda().half()

model(inputs)


@pytest.mark.skip(
Expand Down

0 comments on commit 466e91b

Please sign in to comment.