Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix has many through relationship not working #903

Open
wants to merge 9 commits into
base: 2.0
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,7 +1969,7 @@ def _register_relationships_to_model(
def _map_related(self, related_result, related):
if related.__class__.__name__ == 'MorphTo':
return related_result
elif related.__class__.__name__ == 'HasOneThrough':
elif related.__class__.__name__ in ['HasOneThrough', 'HasManyThrough']:
return related_result.group_by(related.local_key)

return related_result.group_by(related.foreign_key)
Expand Down
226 changes: 144 additions & 82 deletions src/masoniteorm/relationships/HasManyThrough.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .BaseRelationship import BaseRelationship
from ..collection import Collection
from .BaseRelationship import BaseRelationship


class HasManyThrough(BaseRelationship):
Expand Down Expand Up @@ -57,33 +57,46 @@ def __get__(self, instance, owner):
if attribute in instance._relationships:
return instance._relationships[attribute]

result = self.apply_query(
result = self.apply_related_query(
self.distant_builder, self.intermediary_builder, instance
circulon marked this conversation as resolved.
Show resolved Hide resolved
)
return result
else:
return self

def apply_query(self, distant_builder, intermediary_builder, owner):
"""Apply the query and return a dictionary to be hydrated.
Used during accessing a relationship on a model
def apply_related_query(self, distant_builder, intermediary_builder, owner):
"""
Apply the query to return a Collection of data for the distant models to be hydrated with.

Arguments:
query {oject} -- The relationship object
owner {object} -- The current model oject.
Method is used when accessing a relationship on a model if its not
already eager loaded

Returns:
dict -- A dictionary of data which will be hydrated.
Arguments
distant_builder (QueryBuilder): QueryBuilder attached to the distant table
intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table
owner (Any): the model this relationship is starting from

Returns
Collection: Collection of dicts which will be used for hydrating models.
"""
# select * from `countries` inner join `ports` on `ports`.`country_id` = `countries`.`country_id` where `ports`.`port_id` is null and `countries`.`deleted_at` is null and `ports`.`deleted_at` is null
result = distant_builder.join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
"=",
f"{distant_builder.get_table_name()}.{self.other_owner_key}",
).where(f"{self.intermediary_builder.get_table_name()}.{self.local_owner_key}", getattr(owner, self.other_owner_key)).get()

return result
dist_table = distant_builder.get_table_name()
int_table = intermediary_builder.get_table_name()
circulon marked this conversation as resolved.
Show resolved Hide resolved

return (
self.distant_builder.select(f"{dist_table}.*, {int_table}.{self.local_key}")
.join(
f"{int_table}",
f"{int_table}.{self.foreign_key}",
"=",
f"{dist_table}.{self.other_owner_key}",
)
.where(
f"{int_table}.{self.local_key}",
getattr(owner, self.local_owner_key),
)
.get()
)

def relate(self, related_model):
return self.distant_builder.join(
Expand All @@ -104,51 +117,144 @@ def make_builder(self, eagers=None):

return builder

def get_related(self, query, relation, eagers=None, callback=None):
builder = self.distant_builder
def register_related(self, key, model, collection):
"""
Attach the related model to source models attribute

Arguments
key (str): The attribute name
model (Any): The model instance
collection (Collection): The data for the related models

Returns
None
"""
related = collection.get(getattr(model, self.local_owner_key), None)
if related and not isinstance(related, Collection):
related = Collection(related)

model.add_relation({key: related if related else None})

def get_related(self, current_builder, relation, eagers=None, callback=None):
"""
Get a Collection to hydrate the models for the distant table with
Used when eager loading the model attribute

Arguments
current_builder (QueryBuilder): The source models QueryBuilder object
relation (HasManyThrough): this relationship object
eagers (Any):
callback (Any):

Returns
Collection the collection of dicts to hydrate the distant models with
"""

dist_table = self.distant_builder.get_table_name()
int_table = self.intermediary_builder.get_table_name()

if callback:
callback(builder)
callback(current_builder)

(
self.distant_builder.select(
f"{dist_table}.*, {int_table}.{self.local_key}"
).join(
f"{int_table}",
f"{int_table}.{self.foreign_key}",
"=",
f"{dist_table}.{self.other_owner_key}",
)
)

if isinstance(relation, Collection):
return builder.where_in(
f"{builder.get_table_name()}.{self.foreign_key}",
Collection(relation._get_value(self.local_key)).unique(),
return self.distant_builder.where_in(
f"{int_table}.{self.local_key}",
Collection(relation._get_value(self.local_owner_key)).unique(),
).get()
else:
return builder.where(
f"{builder.get_table_name()}.{self.foreign_key}",
return self.distant_builder.where(
f"{int_table}.{self.local_key}",
getattr(relation, self.local_owner_key),
circulon marked this conversation as resolved.
Show resolved Hide resolved
).get()

def get_with_count_query(self, builder, callback):
query = self.distant_builder
def attach(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach method"
)

def attach_related(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach_related method"
)

if not builder._columns:
builder = builder.select("*")
def query_has(self, current_builder, method="where_exists"):
dist_table = self.distant_builder.get_table_name()
int_table = self.intermediary_builder.get_table_name()

circulon marked this conversation as resolved.
Show resolved Hide resolved
return_query = builder.add_select(
getattr(current_builder, method)(
self.distant_builder.join(
f"{int_table}",
f"{int_table}.{self.foreign_key}",
"=",
f"{dist_table}.{self.other_owner_key}",
).where_column(
f"{int_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
)
)

return self.distant_builder

def query_where_exists(self, current_builder, callback, method="where_exists"):
dist_table = self.distant_builder.get_table_name()
int_table = self.intermediary_builder.get_table_name()

getattr(current_builder, method)(
self.distant_builder.join(
f"{int_table}",
f"{int_table}.{self.foreign_key}",
"=",
f"{dist_table}.{self.other_owner_key}",
circulon marked this conversation as resolved.
Show resolved Hide resolved
)
.where_column(
f"{int_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
)
.when(callback, lambda q: (callback(q)))
)

def get_with_count_query(self, current_builder, callback):
dist_table = self.distant_builder.get_table_name()
int_table = self.intermediary_builder.get_table_name()

if not current_builder._columns:
current_builder.select("*")

return_query = current_builder.add_select(
f"{self.attribute}_count",
lambda q: (
(
q.count("*")
.join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
f"{int_table}",
f"{int_table}.{self.foreign_key}",
"=",
f"{query.get_table_name()}.{self.other_owner_key}",
circulon marked this conversation as resolved.
Show resolved Hide resolved
f"{dist_table}.{self.other_owner_key}",
)
.where_column(
f"{builder.get_table_name()}.{self.local_owner_key}",
f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
f"{int_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
)
.table(query.get_table_name())
.table(dist_table)
.when(
callback,
lambda q: (
q.where_in(
self.foreign_key,
callback(query.select(self.other_owner_key)),
callback(
self.distant_builder.select(self.other_owner_key)
),
)
),
)
Expand All @@ -157,47 +263,3 @@ def get_with_count_query(self, builder, callback):
)

return return_query

def attach(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach method"
)

def attach_related(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach_related method"
)

def query_has(self, current_query_builder, method="where_exists"):
related_builder = self.get_builder()

getattr(current_query_builder, method)(
self.distant_builder.where_column(
f"{current_query_builder.get_table_name()}.{self.local_owner_key}",
f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
).join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
"=",
f"{self.distant_builder.get_table_name()}.{self.other_owner_key}",
)
)

return related_builder

def query_where_exists(
self, current_query_builder, callback, method="where_exists"
):
query = self.distant_builder

getattr(current_query_builder, method)(
query.join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
"=",
f"{query.get_table_name()}.{self.other_owner_key}",
).where_column(
f"{current_query_builder.get_table_name()}.{self.local_owner_key}",
f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
)
).when(callback, lambda q: (callback(q)))
12 changes: 6 additions & 6 deletions tests/mysql/relationships/test_has_many_through.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ def test_has_query(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""",
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
)

def test_or_has(self):
sql = InboundShipment.where("name", "Joe").or_has("from_country").to_sql()

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""",
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
)

def test_where_has_query(self):
Expand All @@ -49,7 +49,7 @@ def test_where_has_query(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""",
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
)

def test_or_where_has(self):
Expand All @@ -61,15 +61,15 @@ def test_or_where_has(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""",
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
)

def test_doesnt_have(self):
sql = InboundShipment.doesnt_have("from_country").to_sql()

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""",
"""SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
)

def test_or_where_doesnt_have(self):
Expand All @@ -83,5 +83,5 @@ def test_or_where_doesnt_have(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""",
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
)
Loading
Loading