diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 024c272b..e138f9fb 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -53,6 +53,7 @@ def test_patch_and_unpatch(patch_size, batch_size, C, H, W): assert torch.allclose(image_recon[i], image[i], atol=1e-6) +@pytest.mark.gpu def test_t2i_transformer_forward(): # fp16 vae does not run on cpu device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -75,6 +76,7 @@ def test_t2i_transformer_forward(): assert outputs['targets'].shape == output_shape +@pytest.mark.gpu @pytest.mark.parametrize('guidance_scale', [0.0, 3.0]) @pytest.mark.parametrize('negative_prompt', [None, 'so cool']) def test_t2i_transformer_generate(guidance_scale, negative_prompt): @@ -84,8 +86,8 @@ def test_t2i_transformer_generate(guidance_scale, negative_prompt): negative_prompt=negative_prompt, num_inference_steps=1, num_images_per_prompt=1, - height=32, - width=32, + height=64, + width=64, guidance_scale=guidance_scale, progress_bar=False, )