35
35
from .book_plugin import BookPlugin
36
36
from .mod_plugin import DefaultRenderedTemplates , ModPlugin , ModPluginWithBook
37
37
from .specs import HEXDOC_PROJECT_NAME , PluginSpec
38
- from .types import HookReturns
38
+ from .types import HookReturn , HookReturns
39
39
40
40
_T = TypeVar ("_T" )
41
41
@@ -117,19 +117,19 @@ def init_plugins(self):
117
117
118
118
def _init_book_plugins (self ):
119
119
caller = self ._hook_caller (PluginSpec .hexdoc_book_plugin )
120
- for plugin in flatten (caller .try_call ()):
120
+ for plugin in flatten_hook_returns (caller .try_call ()):
121
121
self .book_plugins [plugin .modid ] = plugin
122
122
123
123
def _init_mod_plugins (self ):
124
124
caller = self ._hook_caller (PluginSpec .hexdoc_mod_plugin )
125
- for plugin in flatten (
125
+ for plugin in flatten_hook_returns (
126
126
caller .try_call (
127
127
branch = self .branch ,
128
128
props = self .props ,
129
129
)
130
130
):
131
131
self .mod_plugins [plugin .modid ] = plugin
132
- self .item_image_types += flatten ([ plugin .item_image_types ()] )
132
+ self .item_image_types += flatten_hook_return ( plugin .item_image_types ())
133
133
134
134
def register (self , plugin : Any , name : str | None = None ):
135
135
self .inner .register (plugin , name )
@@ -207,7 +207,7 @@ def validate_format_tree(
207
207
def update_context (self , context : dict [str , Any ]) -> Iterator [ValidationContext ]:
208
208
caller = self ._hook_caller (PluginSpec .hexdoc_update_context )
209
209
if returns := caller .try_call (context = context ):
210
- yield from flatten (returns )
210
+ yield from flatten_hook_returns (returns )
211
211
212
212
def update_jinja_env (self , env : SandboxedEnvironment , modids : Sequence [str ]):
213
213
for modid in modids :
@@ -222,7 +222,7 @@ def update_template_args(self, template_args: dict[str, Any]):
222
222
223
223
def load_resources (self , modid : str ) -> Iterator [ModuleType ]:
224
224
plugin = self .mod_plugin (modid )
225
- for package in flatten ([ plugin .resource_dirs ()] ):
225
+ for package in flatten_hook_return ( plugin .resource_dirs ()):
226
226
yield import_package (package )
227
227
228
228
def load_tagged_unions (self ) -> Iterator [ModuleType ]:
@@ -253,7 +253,7 @@ def _package_loaders_for(self, modids: Iterable[str]):
253
253
package_name = import_package (package ).__name__ ,
254
254
package_path = package_path ,
255
255
)
256
- for package , package_path in flatten ([ result ] )
256
+ for package , package_path in flatten_hook_return ( result )
257
257
]
258
258
)
259
259
@@ -289,7 +289,7 @@ def _import_from_hook(
289
289
** kwargs : _P .kwargs ,
290
290
) -> Iterator [ModuleType ]:
291
291
packages = self ._hook_caller (__spec )(* args , ** kwargs )
292
- for package in flatten (packages ):
292
+ for package in flatten_hook_returns (packages ):
293
293
yield import_package (package )
294
294
295
295
@overload
@@ -305,12 +305,16 @@ def _hook_caller(self, spec: Callable[_P, _R | None]) -> TypedHookCaller[_P, _R]
305
305
return TypedHookCaller (None , caller )
306
306
307
307
308
- def flatten (values : list [ list [ _T ] | _T ] | None ) -> Iterator [_T ]:
308
+ def flatten_hook_returns (values : HookReturns [ _T ] | None ) -> Iterator [_T ]:
309
309
for value in values or []:
310
- if isinstance (value , list ):
311
- yield from value
312
- else :
313
- yield value
310
+ yield from flatten_hook_return (value )
311
+
312
+
313
+ def flatten_hook_return (values : HookReturn [_T ] | None ) -> Iterator [_T ]:
314
+ if isinstance (values , list ):
315
+ yield from values
316
+ else :
317
+ yield values # type: ignore
314
318
315
319
316
320
def import_package (package : Package ) -> ModuleType :
0 commit comments