diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index 20c9cb6..ee1b803 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -80,6 +80,7 @@ Literal, Optional, Mapping, + NotRequired, Union, Tuple, TypedDict, @@ -507,7 +508,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: encoder_names.add(encoder_name) typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])") if name not in type.required: - typeddict_encoder.append(f"if x[{repr(name)}] else None") + typeddict_encoder.append( + f"if {repr(name)} in x and x[{repr(name)}] else None" + ) elif isinstance(prop, RiverIntersectionType): encoder_name = TypeName( f"encode_{ensure_literal_type(type_name)}" @@ -541,7 +544,12 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: items = cast(RiverConcreteType, prop).items assert items, "Somehow items was none" if is_literal(cast(RiverType, items)): - typeddict_encoder.append(f"x[{repr(name)}]") + if name in prop.required: + typeddict_encoder.append(f"x[{repr(name)}]") + else: + typeddict_encoder.append( + f"x.get({repr(safe_name)})" + ) else: match type_name: case ListTypeExpr(inner_type_name): @@ -606,12 +614,24 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: ) else: if name not in type.required: - value = "" - if base_model != "TypedDict": - value = " = None" - current_chunks.append( - f" {name}: Optional[{render_type_expr(type_name)}]{value}" - ) + if base_model == "TypedDict": + current_chunks.append( + reindent( + " ", + f"""\ + {name}: NotRequired[Optional[{render_type_expr(type_name)}]] + """, + ) + ) + else: + current_chunks.append( + reindent( + " ", + f"""\ + {name}: Optional[{render_type_expr(type_name)}] = None + """, + ) + ) else: current_chunks.append( f" {name}: {render_type_expr(type_name)}"