Skip to content

Commit af2c5bf

Browse files
committed
Split plugin flatten into two functions to fix type issues
1 parent cff18a8 commit af2c5bf

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

src/hexdoc/cli/utils/info.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Properties,
1313
)
1414
from hexdoc.plugin import ModPlugin, PluginManager
15-
from hexdoc.plugin.manager import flatten, import_package
15+
from hexdoc.plugin.manager import flatten_hook_return, import_package
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -53,7 +53,7 @@ def _plugins(pm: PluginManager):
5353

5454

5555
def _jinja_template_roots(mod_plugin: ModPlugin):
56-
for package, folder in flatten([mod_plugin.jinja_template_root() or []]):
56+
for package, folder in flatten_hook_return(mod_plugin.jinja_template_root()):
5757
module_path = _get_package_path(package)
5858
folder_path = module_path / folder
5959
yield _relative_path(folder_path)

src/hexdoc/plugin/manager.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from .book_plugin import BookPlugin
3636
from .mod_plugin import DefaultRenderedTemplates, ModPlugin, ModPluginWithBook
3737
from .specs import HEXDOC_PROJECT_NAME, PluginSpec
38-
from .types import HookReturns
38+
from .types import HookReturn, HookReturns
3939

4040
_T = TypeVar("_T")
4141

@@ -117,19 +117,19 @@ def init_plugins(self):
117117

118118
def _init_book_plugins(self):
119119
caller = self._hook_caller(PluginSpec.hexdoc_book_plugin)
120-
for plugin in flatten(caller.try_call()):
120+
for plugin in flatten_hook_returns(caller.try_call()):
121121
self.book_plugins[plugin.modid] = plugin
122122

123123
def _init_mod_plugins(self):
124124
caller = self._hook_caller(PluginSpec.hexdoc_mod_plugin)
125-
for plugin in flatten(
125+
for plugin in flatten_hook_returns(
126126
caller.try_call(
127127
branch=self.branch,
128128
props=self.props,
129129
)
130130
):
131131
self.mod_plugins[plugin.modid] = plugin
132-
self.item_image_types += flatten([plugin.item_image_types()])
132+
self.item_image_types += flatten_hook_return(plugin.item_image_types())
133133

134134
def register(self, plugin: Any, name: str | None = None):
135135
self.inner.register(plugin, name)
@@ -207,7 +207,7 @@ def validate_format_tree(
207207
def update_context(self, context: dict[str, Any]) -> Iterator[ValidationContext]:
208208
caller = self._hook_caller(PluginSpec.hexdoc_update_context)
209209
if returns := caller.try_call(context=context):
210-
yield from flatten(returns)
210+
yield from flatten_hook_returns(returns)
211211

212212
def update_jinja_env(self, env: SandboxedEnvironment, modids: Sequence[str]):
213213
for modid in modids:
@@ -222,7 +222,7 @@ def update_template_args(self, template_args: dict[str, Any]):
222222

223223
def load_resources(self, modid: str) -> Iterator[ModuleType]:
224224
plugin = self.mod_plugin(modid)
225-
for package in flatten([plugin.resource_dirs()]):
225+
for package in flatten_hook_return(plugin.resource_dirs()):
226226
yield import_package(package)
227227

228228
def load_tagged_unions(self) -> Iterator[ModuleType]:
@@ -253,7 +253,7 @@ def _package_loaders_for(self, modids: Iterable[str]):
253253
package_name=import_package(package).__name__,
254254
package_path=package_path,
255255
)
256-
for package, package_path in flatten([result])
256+
for package, package_path in flatten_hook_return(result)
257257
]
258258
)
259259

@@ -289,7 +289,7 @@ def _import_from_hook(
289289
**kwargs: _P.kwargs,
290290
) -> Iterator[ModuleType]:
291291
packages = self._hook_caller(__spec)(*args, **kwargs)
292-
for package in flatten(packages):
292+
for package in flatten_hook_returns(packages):
293293
yield import_package(package)
294294

295295
@overload
@@ -305,12 +305,16 @@ def _hook_caller(self, spec: Callable[_P, _R | None]) -> TypedHookCaller[_P, _R]
305305
return TypedHookCaller(None, caller)
306306

307307

308-
def flatten(values: list[list[_T] | _T] | None) -> Iterator[_T]:
308+
def flatten_hook_returns(values: HookReturns[_T] | None) -> Iterator[_T]:
309309
for value in values or []:
310-
if isinstance(value, list):
311-
yield from value
312-
else:
313-
yield value
310+
yield from flatten_hook_return(value)
311+
312+
313+
def flatten_hook_return(values: HookReturn[_T] | None) -> Iterator[_T]:
314+
if isinstance(values, list):
315+
yield from values
316+
else:
317+
yield values # type: ignore
314318

315319

316320
def import_package(package: Package) -> ModuleType:

0 commit comments

Comments
 (0)