|
1 | 1 | import itertools
|
2 | 2 | import random
|
3 | 3 | from typing import Dict, List, Optional, Tuple
|
4 |
| -from unittest.mock import patch |
| 4 | +from unittest.mock import Mock, patch |
5 | 5 |
|
6 | 6 | import pytest
|
7 | 7 | import torch
|
@@ -703,3 +703,28 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
|
703 | 703 |
|
704 | 704 | assert tokens1[0] == tokens2[1]
|
705 | 705 | assert tokens1[1] == tokens2[0]
|


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """With ``include_gpu_probs_tensor=True`` the sampler must attach the
    GPU-side probability/logprob/token-id tensors to its output, and with
    ``should_modify_greedy_probs_inplace=False`` the greedy in-place
    modification helper must never be invoked.
    """
    set_random_seed(42)
    torch.set_default_device(device)
    num_seqs = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(num_seqs)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    # temperature=0 selects the greedy path, which is exactly the path that
    # would call _modify_greedy_probs_inplace if the flag were enabled.
    greedy_params = SamplingParams(temperature=0)

    inplace_spy = Mock()
    with patch(
            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
            inplace_spy):

        output = _do_sample(num_seqs, fake_logits, sampler, greedy_params,
                            device)
        # The flag disables in-place greedy prob modification entirely.
        inplace_spy.assert_not_called()

    # Tensors requested via include_gpu_probs_tensor must be populated.
    assert output.sampled_token_probs is not None
    assert output.logprobs is not None
    assert output.sampled_token_ids is not None
0 commit comments