diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 2b9a895602..75715dfed8 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -230,12 +230,21 @@ def __init_subclass__( tags.append("model_hub_mixin") # Initialize MixinInfo if not existent - if not hasattr(cls, "_hub_mixin_info"): - cls._hub_mixin_info = MixinInfo( - model_card_template=model_card_template, - model_card_data=ModelCardData(), - ) - info = cls._hub_mixin_info + info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData()) + + # If parent class has a MixinInfo, inherit from it as a copy + if hasattr(cls, "_hub_mixin_info"): + # Inherit model card template from parent class if not explicitly set + if model_card_template == DEFAULT_MODEL_CARD: + info.model_card_template = cls._hub_mixin_info.model_card_template + + # Inherit from parent model card data + info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict()) + + # Inherit other info + info.docs_url = cls._hub_mixin_info.docs_url + info.repo_url = cls._hub_mixin_info.repo_url + cls._hub_mixin_info = info if languages is not None: warnings.warn( @@ -269,6 +278,8 @@ def __init_subclass__( else: info.model_card_data.tags = tags + info.model_card_data.tags = sorted(set(info.model_card_data.tags)) + # Handle encoders/decoders for args cls._hub_mixin_coders = coders or {} cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys()) diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py index 09f4a67b47..b9646b9f4f 100644 --- a/tests/test_hub_mixin_pytorch.py +++ b/tests/test_hub_mixin_pytorch.py @@ -111,12 +111,20 @@ def __init__(self, config: Namespace): super().__init__() self.config = config + class DummyModelWithTag1(nn.Module, PyTorchModelHubMixin, tags=["tag1"]): + """Used to test tags not shared between sibling classes (only inheritance).""" + + class DummyModelWithTag2(nn.Module, PyTorchModelHubMixin, tags=["tag2"]): + """Used to test tags not shared between sibling classes (only inheritance).""" + else: DummyModel = None DummyModelWithModelCard = None DummyModelNoConfig = None DummyModelWithConfigAndKwargs = None DummyModelWithModelCardAndCustomKwargs = None + DummyModelWithTag1 = None + DummyModelWithTag2 = None @requires("torch") @@ -451,3 +459,21 @@ def test_config_with_custom_coders(self): assert isinstance(reloaded.config, Namespace) assert reloaded.config.a == 1 assert reloaded.config.b == 2 + + def test_inheritance_and_sibling_classes(self): + """ + Test tags are not shared between sibling classes. + + Regression test for #2394. + See https://github.com/huggingface/huggingface_hub/pull/2394. + """ + assert DummyModelWithTag1._hub_mixin_info.model_card_data.tags == [ + "model_hub_mixin", + "pytorch_model_hub_mixin", + "tag1", + ] + assert DummyModelWithTag2._hub_mixin_info.model_card_data.tags == [ + "model_hub_mixin", + "pytorch_model_hub_mixin", + "tag2", + ]