Fix model type detection in tests
Co-authored-by: Alexander Borzunov <[email protected]>
mryab and borzunov committed Oct 8, 2023
1 parent ea1239d commit 400685b
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tests/test_optimized_layers.py
```diff
@@ -177,10 +177,6 @@ def _reorder_cache_from_llama_to_bloom(
     return (key_states, value_states)
 
 
-@pytest.mark.skipif(
-    all(model_name not in MODEL_NAME for model_name in ("falcon", "llama")),
-    reason="This test is applicable only to Falcon and LLaMa models",
-)
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
 @pytest.mark.forked
 def test_falcon(device):
@@ -196,10 +192,12 @@ def test_falcon(device):
     block = config.block_class(config).to(dtype)
     block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-    if "falcon" in MODEL_NAME:
+    if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
-    elif "llama" in MODEL_NAME:
+    elif config.model_type == "llama":
         unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+    else:
+        pytest.skip(f"This test is not applicable to {config.model_type} models")
 
     unopt_block = convert_block(
         unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
```
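The motivation behind the change: substring checks against `MODEL_NAME` break for checkpoints whose repository names do not mention the architecture, whereas `config.model_type` is fixed by the architecture's config class in `transformers`. A minimal sketch of the difference, assuming a LLaMA-based checkpoint such as `petals-team/StableBeluga2` (a LLaMA-2 fine-tune whose repo name contains no "llama"; any similarly named checkpoint illustrates the same point):

```python
from transformers import AutoConfig

# Assumption for illustration: a LLaMA-architecture checkpoint whose
# repo name does not contain the substring "llama".
MODEL_NAME = "petals-team/StableBeluga2"

config = AutoConfig.from_pretrained(MODEL_NAME)

# Brittle: depends on how the repo happens to be named -> False here.
print("llama" in MODEL_NAME.lower())

# Robust: set by the architecture's config class -> "llama".
print(config.model_type)
```

Replacing the collect-time `@pytest.mark.skipif` decorator with a runtime `pytest.skip(...)` call fits the same fix: presumably `config` is only constructed inside the test body, so the skip decision has to move there as well.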
