Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-level nested create/update with model full_clean() #659

96 changes: 79 additions & 17 deletions strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
)

if TYPE_CHECKING:
from django.db.models.manager import ManyToManyRelatedManager, RelatedManager
from django.db.models.manager import (
BaseManager,
ManyToManyRelatedManager,
RelatedManager,
)
from strawberry.types.info import Info


Expand Down Expand Up @@ -222,6 +226,7 @@
data: dict[str, Any],
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
exclude_m2m: list[str] | None = None,
) -> tuple[
Model,
dict[str, object],
Expand All @@ -237,6 +242,7 @@
fields = get_model_fields(model)
m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = []
direct_field_values: dict[str, object] = {}
exclude_m2m = exclude_m2m or []

if dataclasses.is_dataclass(data):
data = vars(data)
Expand All @@ -256,6 +262,8 @@
# (but only if the instance is already saved and we are updating it)
value = False # noqa: PLW2901
elif isinstance(field, (ManyToManyField, ForeignObjectRel)):
if name in exclude_m2m:
continue
# m2m will be processed later
m2m.append((field, value))
direct_field_value = False
Expand Down Expand Up @@ -309,6 +317,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M: ...


Expand All @@ -321,10 +330,10 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M]: ...


@transaction.atomic
def create(
info: Info,
model: type[_M],
Expand All @@ -333,12 +342,43 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M] | _M:
return _create(
info,
model._default_manager,
data,
key_attr=key_attr,
full_clean=full_clean,
pre_save_hook=pre_save_hook,
exclude_m2m=exclude_m2m,
)


@transaction.atomic
def _create(
info: Info,
manager: BaseManager,
data: dict[str, Any] | list[dict[str, Any]],
*,
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M] | _M:
model = manager.model
# Before creating your instance, verify this is not a bulk create
# if so, add them one by one. Otherwise, get to work.
if isinstance(data, list):
return [
create(info, model, d, key_attr=key_attr, full_clean=full_clean)
create(
info,
model,
d,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
for d in data
]

Expand All @@ -365,6 +405,7 @@
data=data,
full_clean=full_clean,
key_attr=key_attr,
exclude_m2m=exclude_m2m,
)

# Creating the instance directly via create() without full-clean will
Expand All @@ -372,11 +413,11 @@
# full-clean() to trigger form-validation style error messages.
full_clean_options = full_clean if isinstance(full_clean, dict) else {}
if full_clean:
dummy_instance.full_clean(**full_clean_options) # type: ignore

Check warning on line 416 in strawberry_django/mutations/resolvers.py

View workflow job for this annotation

GitHub Actions / Typing

Unnecessary "# type: ignore" comment (reportUnnecessaryTypeIgnoreComment)

# Create the instance using the manager create method to respect
# manager create overrides. This also ensures support for proxy-models.
instance = model._default_manager.create(**create_kwargs)
instance = manager.create(**create_kwargs)

for field, value in m2m:
update_m2m(info, instance, field, value, key_attr)
Expand All @@ -393,6 +434,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M: ...


Expand All @@ -405,6 +447,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M]: ...


Expand All @@ -417,6 +460,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M | list[_M]:
# Unwrap lazy objects since they have a proxy __iter__ method that will make
# them iterables even if the wrapped object isn't
Expand All @@ -433,6 +477,7 @@
key_attr=key_attr,
full_clean=full_clean,
pre_save_hook=pre_save_hook,
exclude_m2m=exclude_m2m,
)
for instance in instances
]
Expand All @@ -443,6 +488,7 @@
data=data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)

if pre_save_hook is not None:
Expand Down Expand Up @@ -554,15 +600,22 @@
use_remove = True
if isinstance(field, ManyToManyField):
manager = cast("RelatedManager", getattr(instance, field.attname))
reverse_field_name = field.remote_field.related_name # type: ignore
else:
assert isinstance(field, (ManyToManyRel, ManyToOneRel))
accessor_name = field.get_accessor_name()
reverse_field_name = field.field.name
assert accessor_name
manager = cast("RelatedManager", getattr(instance, accessor_name))
if field.one_to_many:
# remove if field is nullable, otherwise delete
use_remove = field.remote_field.null is True

# Create a data dict containing the reference to the instance and exclude it from
# nested m2m creation (to break circular references)
ref_instance_data = {reverse_field_name: instance}
exclude_m2m = [reverse_field_name]

to_add = []
to_remove = []
to_delete = []
Expand Down Expand Up @@ -621,14 +674,17 @@

existing.discard(obj)
else:
if key_attr not in data: # we have a Input Type
obj, _ = manager.get_or_create(**data)
else:
data.pop(key_attr)
obj = manager.create(**data)

if full_clean:
obj.full_clean(**full_clean_options)
# If we've reached here, the key_attr should be UNSET or missing. So
# let's remove it if it is there.
data.pop(key_attr, None)
obj = _create(
info,
manager,
data | ref_instance_data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
existing.discard(obj)

for remaining in existing:
Expand Down Expand Up @@ -656,11 +712,17 @@
data.pop(key_attr, None)
to_add.append(obj)
elif data:
if key_attr not in data:
manager.get_or_create(**data)
else:
data.pop(key_attr)
manager.create(**data)
# If we've reached here, the key_attr should be UNSET or missing. So
# let's remove it if it is there.
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
data.pop(key_attr, None)
_create(
info,
manager,
data | ref_instance_data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
else:
raise AssertionError

Expand Down
12 changes: 12 additions & 0 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ class MilestoneIssueInput:
name: strawberry.auto


@strawberry_django.partial(Issue)
class MilestoneIssueInputPartial:
name: strawberry.auto
tags: Optional[list[TagInputPartial]]


@strawberry_django.partial(Project)
class ProjectInputPartial(NodeInputPartial):
name: strawberry.auto
Expand All @@ -351,6 +357,7 @@ class MilestoneInput:
@strawberry_django.partial(Milestone)
class MilestoneInputPartial(NodeInputPartial):
name: strawberry.auto
issues: Optional[list[MilestoneIssueInputPartial]]


@strawberry.type
Expand Down Expand Up @@ -519,6 +526,11 @@ class Mutation:
argument_name="input",
key_attr="name",
)
create_project_with_milestones: ProjectType = mutations.create(
ProjectInputPartial,
handle_django_errors=True,
argument_name="input",
)
update_project: ProjectType = mutations.update(
ProjectInputPartial,
handle_django_errors=True,
Expand Down
9 changes: 9 additions & 0 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ input CreateProjectInput {

union CreateProjectPayload = ProjectType | OperationInfo

union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo

input CreateQuizInput {
title: String!
fullCleanOptions: Boolean! = false
Expand Down Expand Up @@ -365,12 +367,18 @@ input MilestoneInput {
input MilestoneInputPartial {
id: GlobalID
name: String
issues: [MilestoneIssueInputPartial!]
}

input MilestoneIssueInput {
name: String!
}

input MilestoneIssueInputPartial {
name: String
tags: [TagInputPartial!]
}

input MilestoneOrder {
name: Ordering
project: ProjectOrder
Expand Down Expand Up @@ -431,6 +439,7 @@ type Mutation {
updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload!
deleteIssue(input: NodeInput!): DeleteIssuePayload!
deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload!
createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload!
updateProject(input: ProjectInputPartial!): UpdateProjectPayload!
createMilestone(input: MilestoneInput!): CreateMilestonePayload!
createProject(
Expand Down
11 changes: 11 additions & 0 deletions tests/projects/snapshots/schema_with_inheritance.gql
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ input MilestoneFilter {
input MilestoneInputPartial {
id: GlobalID
name: String
issues: [MilestoneIssueInputPartial!]
}

input MilestoneIssueInputPartial {
name: String
tags: [TagInputPartial!]
}

input MilestoneOrder {
Expand Down Expand Up @@ -401,6 +407,11 @@ input StrFilterLookup {
iRegex: String
}

input TagInputPartial {
id: GlobalID
name: String
}

type TagType implements Node & Named {
"""The Globally Unique ID of this object"""
id: GlobalID!
Expand Down
Loading
Loading