Skip to content

Commit

Permalink
Hot-fix: do not share tags between ModelHubMixin siblings (#2394)
Browse files Browse the repository at this point in the history
* Hot-fix: do not share tags between ModelHubMixin sibligs

* reference regression test
  • Loading branch information
Wauplin authored Jul 16, 2024
1 parent de160a4 commit 05fdb76
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
23 changes: 17 additions & 6 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,21 @@ def __init_subclass__(
tags.append("model_hub_mixin")

# Initialize MixinInfo if not existent
if not hasattr(cls, "_hub_mixin_info"):
cls._hub_mixin_info = MixinInfo(
model_card_template=model_card_template,
model_card_data=ModelCardData(),
)
info = cls._hub_mixin_info
info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())

# If parent class has a MixinInfo, inherit from it as a copy
if hasattr(cls, "_hub_mixin_info"):
# Inherit model card template from parent class if not explicitly set
if model_card_template == DEFAULT_MODEL_CARD:
info.model_card_template = cls._hub_mixin_info.model_card_template

# Inherit from parent model card data
info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())

# Inherit other info
info.docs_url = cls._hub_mixin_info.docs_url
info.repo_url = cls._hub_mixin_info.repo_url
cls._hub_mixin_info = info

if languages is not None:
warnings.warn(
Expand Down Expand Up @@ -269,6 +278,8 @@ def __init_subclass__(
else:
info.model_card_data.tags = tags

info.model_card_data.tags = sorted(set(info.model_card_data.tags))

# Handle encoders/decoders for args
cls._hub_mixin_coders = coders or {}
cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())
Expand Down
26 changes: 26 additions & 0 deletions tests/test_hub_mixin_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,20 @@ def __init__(self, config: Namespace):
super().__init__()
self.config = config

class DummyModelWithTag1(nn.Module, PyTorchModelHubMixin, tags=["tag1"]):
"""Used to test tags not shared between sibling classes (only inheritance)."""

class DummyModelWithTag2(nn.Module, PyTorchModelHubMixin, tags=["tag2"]):
"""Used to test tags not shared between sibling classes (only inheritance)."""

else:
DummyModel = None
DummyModelWithModelCard = None
DummyModelNoConfig = None
DummyModelWithConfigAndKwargs = None
DummyModelWithModelCardAndCustomKwargs = None
DummyModelWithTag1 = None
DummyModelWithTag2 = None


@requires("torch")
Expand Down Expand Up @@ -451,3 +459,21 @@ def test_config_with_custom_coders(self):
assert isinstance(reloaded.config, Namespace)
assert reloaded.config.a == 1
assert reloaded.config.b == 2

def test_inheritance_and_sibling_classes(self):
"""
Test tags are not shared between sibling classes.
Regression test for #2394.
See https://github.com/huggingface/huggingface_hub/pull/2394.
"""
assert DummyModelWithTag1._hub_mixin_info.model_card_data.tags == [
"model_hub_mixin",
"pytorch_model_hub_mixin",
"tag1",
]
assert DummyModelWithTag2._hub_mixin_info.model_card_data.tags == [
"model_hub_mixin",
"pytorch_model_hub_mixin",
"tag2",
]

0 comments on commit 05fdb76

Please sign in to comment.