diff --git a/optimum/tpu/generation/logits_process.py b/optimum/tpu/generation/logits_process.py index 2e40c67f..9e8bd088 100644 --- a/optimum/tpu/generation/logits_process.py +++ b/optimum/tpu/generation/logits_process.py @@ -48,8 +48,6 @@ def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper" Returns: a `FusedLogitsWarper` or None if neither top-k nor top-p are configured. """ - if generation_config.do_sample and generation_config.top_k == 0 and generation_config.top_p == 1.0: - raise ValueError("Multinomial sampling requires at least top-k or top-p to be specified.") return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p) def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: @@ -59,9 +57,6 @@ def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch. do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1] do_top_p = self.top_p < 1.0 and self.top_p > 0.0 - if not do_top_k and not do_top_p: - return logits, None - if do_top_k: sorted_logits, sorted_indices = torch.topk(logits, self.top_k) else: diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py index 4bd3c7ec..6078244a 100644 --- a/optimum/tpu/version.py +++ b/optimum/tpu/version.py @@ -15,5 +15,5 @@ from pkg_resources import parse_version -__version__ = "0.1.3" +__version__ = "0.1.4" VERSION = parse_version(__version__) diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py index 9ee11539..16913b8c 100644 --- a/text-generation-inference/server/text_generation_server/version.py +++ b/text-generation-inference/server/text_generation_server/version.py @@ -1,5 +1,5 @@ from pkg_resources import parse_version -__version__ = "0.1.3" +__version__ = "0.1.4" VERSION = parse_version(__version__)