diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 58723e51..68f12aa4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,8 @@ Added - Experimental support for sub-classing ``ArgumentParser`` to customize ``add_argument`` (`#661 `__). +- Support for partialmethods (`#665 + `__). Fixed ^^^^^ diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py index 5fb98bd1..1e7679b0 100644 --- a/jsonargparse/_parameter_resolvers.py +++ b/jsonargparse/_parameter_resolvers.py @@ -8,7 +8,7 @@ from contextlib import contextmanager, suppress from contextvars import ContextVar from copy import deepcopy -from functools import partial +from functools import partial, partialmethod from importlib import import_module from types import MethodType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union @@ -102,6 +102,10 @@ def is_method(attr) -> bool: ) +def is_partial_method(attr) -> bool: + return isinstance(attr, partialmethod) + + def is_property(attr) -> bool: return isinstance(attr, property) @@ -509,6 +513,8 @@ def get_component_and_parent( component = getattr(function_or_class, "__new__") elif is_method(attr): component = attr + elif is_partial_method(attr): + component = getattr(function_or_class, method_or_property) elif is_property(attr): component = attr.fget elif isinstance(attr, classmethod): diff --git a/jsonargparse_tests/test_parameter_resolvers.py b/jsonargparse_tests/test_parameter_resolvers.py index 368e58fe..97229fda 100644 --- a/jsonargparse_tests/test_parameter_resolvers.py +++ b/jsonargparse_tests/test_parameter_resolvers.py @@ -4,6 +4,7 @@ import inspect import xml.dom from calendar import Calendar +from functools import partialmethod from random import shuffle from typing import Any, Callable, Dict, List, Optional, Union from unittest.mock import patch @@ -33,6 +34,8 @@ def method_a(self, pma1: int, pma2: float, kma1: str = "x"): kma1: help for kma1 """ + partial_method_a = partialmethod(method_a, pma1=1, pma2=0.5) + class ClassB(ClassA): def __init__(self, pkb1: str, kb1: int = 3, kb2: str = "4", **kwargs): @@ -899,6 +902,13 @@ def test_get_params_some_ignored(): assert_params(get_params(func_given_kwargs), ["p", "p1"], help=False) +def test_partialmethod(): + ClassA.partial_method_a = partialmethod(ClassA.method_a, pma1=1, pma2=0.5) + assert_params(get_params(ClassA, "partial_method_a"), ["pma1", "pma2", "kma1"]) + with source_unavailable(): + assert_params(get_params(ClassA, "partial_method_a"), ["pma1", "pma2", "kma1"]) + + # unsupported cases