Skip to content

Commit fbe67b7

Browse files
Limiting FILE_HEADER (#114)
Why === Now that we're generating modules, not all imports apply in all places, and it imports are not able to be pruned in `__init__.py`, so let's not just import a bunch of stuff we don't need. What changed ============ Breaking `FILE_HEADER` up into three different variants. Test plan ========= _Describe what you did to test this change to a level of detail that allows your reviewer to test it_
1 parent d85becf commit fbe67b7

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

replit_river/codegen/client.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,32 @@ def ensure_literal_type(value: TypeExpression) -> TypeName:
109109

110110
_NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9_]+")
111111

112+
# Literal is here because HandshakeType can be Literal[None]
113+
ROOT_FILE_HEADER = dedent(
114+
"""\
115+
# Code generated by river.codegen. DO NOT EDIT.
116+
from pydantic import BaseModel
117+
from typing import Literal
118+
119+
import replit_river as river
120+
121+
"""
122+
)
123+
124+
SERVICE_FILE_HEADER = dedent(
125+
"""\
126+
# Code generated by river.codegen. DO NOT EDIT.
127+
from collections.abc import AsyncIterable, AsyncIterator
128+
from typing import Any
129+
130+
from pydantic import TypeAdapter
131+
132+
from replit_river.error_schema import RiverError
133+
import replit_river as river
134+
135+
"""
136+
)
137+
112138
FILE_HEADER = dedent(
113139
"""\
114140
# ruff: noqa
@@ -709,7 +735,7 @@ def generate_common_client(
709735
handshake_chunks: Sequence[str],
710736
modules: list[Tuple[ModuleName, ClassName]],
711737
) -> FileContents:
712-
chunks: list[str] = [FILE_HEADER]
738+
chunks: list[str] = [ROOT_FILE_HEADER]
713739
chunks.extend(
714740
[
715741
f"from .{model_name} import {class_name}"
@@ -1072,7 +1098,7 @@ async def {name}(
10721098
]
10731099

10741100
emitted_files[RenderedPath(str(Path(f"{schema_name}/__init__.py")))] = FileContents(
1075-
"\n".join([FILE_HEADER] + rendered_imports + in_root + current_chunks)
1101+
"\n".join([SERVICE_FILE_HEADER] + rendered_imports + in_root + current_chunks)
10761102
)
10771103
return (
10781104
ModuleName(schema_name),

0 commit comments

Comments
 (0)