diff --git a/utilization/chat_templates.py b/utilization/chat_templates.py index abec6411..aba8727b 100644 --- a/utilization/chat_templates.py +++ b/utilization/chat_templates.py @@ -5,9 +5,9 @@ def add_space( msg: str, - auto_leading_space: bool, - remove_space_between: bool, context: str, + auto_leading_space: bool = True, + remove_space_between: bool = True, starts: Optional[List[str]] = None, ends: Optional[List[str]] = None ) -> str: @@ -15,8 +15,8 @@ def add_space( context_ends_special = False msg_starts_special = False else: - context_ends_special = any(context.endswith(e) for e in ends) - msg_starts_special = any(msg.startswith(s) for s in starts) + context_ends_special = any(context.endswith(e) for e in ends if len(e) > 0) + msg_starts_special = any(msg.startswith(s) for s in starts if len(s) > 0) if (auto_leading_space and msg and context)\ and not (context[-1].isspace() or msg[0].isspace())\ and not (context_ends_special and msg_starts_special): @@ -30,7 +30,14 @@ def smart_space(parts: List[str], auto_leading_space: bool, remove_space_between rendered = "" for part in parts: if part: - rendered += add_space(part, auto_leading_space, remove_space_between, rendered, starts, ends) + rendered += add_space( + part, + rendered, + auto_leading_space=auto_leading_space, + remove_space_between=remove_space_between, + starts=starts, + ends=ends + ) return rendered diff --git a/utilization/model/model_utils/conversation.py b/utilization/model/model_utils/conversation.py index 55efe4b6..94c34ed0 100644 --- a/utilization/model/model_utils/conversation.py +++ b/utilization/model/model_utils/conversation.py @@ -159,7 +159,7 @@ def _get_segs(self, conversations: List["Conversation"], max_turns: int = 1) -> for seg in (system, examples, source, target): if len(seg) > 0: if len(result) > 0: - seg = add_space(seg, True, result[-1]) + seg = add_space(seg, result[-1]) elif self.final_lstrip: seg = seg.lstrip() result += (seg,) diff --git a/utilization/utils/generation_args.py b/utilization/utils/generation_args.py index 78493873..3de81a8a 100644 --- a/utilization/utils/generation_args.py +++ b/utilization/utils/generation_args.py @@ -98,6 +98,8 @@ def resolve_generation_args( # overrides if key in extra_generation_args: extra = extra_generation_args.pop(key) + if value is None and not details.nullable: + continue if callable(extra): overrided = extra(value, details) for new_key, new_value in overrided.items():