Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Rename Component.async_run to Component.run_async for better readablility #8370

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:

@staticmethod
def _parse_and_set_output_sockets(instance: Any):
has_async_run = hasattr(instance, "async_run")
has_async_run = hasattr(instance, "run_async")

# If `component.set_output_types()` was called in the component constructor,
# `__haystack_output__` is already populated, no need to do anything.
Expand All @@ -200,10 +200,10 @@ def _parse_and_set_output_sockets(instance: Any):
# won't share this data.

run_output_types = getattr(instance.run, "_output_types_cache", {})
async_run_output_types = getattr(instance.async_run, "_output_types_cache", {}) if has_async_run else {}
async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

if has_async_run and run_output_types != async_run_output_types:
raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same")
raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
output_types_cache = run_output_types

instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)
Expand Down Expand Up @@ -243,7 +243,7 @@ def inner(method, sockets):
inner(getattr(component_cls, "run"), instance.__haystack_input__)

# Ensure that the sockets are the same for the async method, if it exists.
async_run = getattr(component_cls, "async_run", None)
async_run = getattr(component_cls, "run_async", None)
if async_run is not None:
run_sockets = Sockets(instance, {}, InputSocket)
async_run_sockets = Sockets(instance, {}, InputSocket)
Expand All @@ -254,7 +254,7 @@ def inner(method, sockets):
async_run_sig = inner(async_run, async_run_sockets)

if async_run_sockets != run_sockets or run_sig != async_run_sig:
raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same")
raise ComponentError("Parameters of 'run' and 'run_async' methods must be the same")

def __call__(cls, *args, **kwargs):
"""
Expand All @@ -279,9 +279,9 @@ def __call__(cls, *args, **kwargs):

# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
has_async_run = hasattr(instance, "async_run")
if has_async_run and not inspect.iscoroutinefunction(instance.async_run):
raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine")
has_async_run = hasattr(instance, "run_async")
if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
instance.__haystack_supports_async__ = has_async_run

ComponentMeta._parse_and_set_input_sockets(cls, instance)
Expand Down Expand Up @@ -478,8 +478,8 @@ class available here, we temporarily store the output types as an attribute of
sockets at instance creation time.
"""
method_name = run_method.__name__
if method_name not in ("run", "async_run"):
raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods")
if method_name not in ("run", "run_async"):
raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

setattr(
run_method,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
features:
- |
Extend core component machinery to support an optional asynchronous `async_run` method in components.
Extend core component machinery to support an optional asynchronous `run_async` method in components.
If it's present, it should have the same parameters (and output types) as the run method and must be
implemented as a coroutine.
18 changes: 9 additions & 9 deletions test/core/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(self, input_value: int):
return {"output_value": input_value}

@component.output_types(output_value=int)
async def async_run(self, input_value: int):
async def run_async(self, input_value: int):
return {"output_value": input_value}

# Verifies also instantiation works with no issues
Expand Down Expand Up @@ -130,7 +130,7 @@ def run(self, value: int):
return {"value": 1}

@component.output_types(value=int)
def async_run(self, value: int):
def run_async(self, value: int):
return {"value": 1}

with pytest.raises(ComponentError, match=r"must be a coroutine"):
Expand All @@ -145,15 +145,15 @@ def run(self, value: int):
return {"value": 1}

@component.output_types(value=int)
async def async_run(self, value: int):
async def run_async(self, value: int):
yield {"value": 1}

with pytest.raises(ComponentError, match=r"must be a coroutine"):
comp = MockComponent()


def test_parameters_mismatch_run_and_async_run():
err_msg = r"Parameters of 'run' and 'async_run' methods must be the same"
err_msg = r"Parameters of 'run' and 'run_async' methods must be the same"

@component
class MockComponentMismatchingInputTypes:
Expand All @@ -162,7 +162,7 @@ def run(self, value: int):
return {"value": 1}

@component.output_types(value=int)
async def async_run(self, value: str):
async def run_async(self, value: str):
return {"value": "1"}

with pytest.raises(ComponentError, match=err_msg):
Expand All @@ -175,7 +175,7 @@ def run(self, value: int, **kwargs):
return {"value": 1}

@component.output_types(value=int)
async def async_run(self, value: int):
async def run_async(self, value: int):
return {"value": "1"}

with pytest.raises(ComponentError, match=err_msg):
Expand All @@ -188,7 +188,7 @@ def run(self, value: int, another: str):
return {"value": 1}

@component.output_types(value=int)
async def async_run(self, another: str, value: int):
async def run_async(self, another: str, value: int):
return {"value": "1"}

with pytest.raises(ComponentError, match=err_msg):
Expand Down Expand Up @@ -323,7 +323,7 @@ def run(self, value: int):
return {"value": 1}

@component.output_types(value=str)
async def async_run(self, value: int):
async def run_async(self, value: int):
return {"value": "1"}

with pytest.raises(ComponentError, match=r"Output type specifications .* must be the same"):
Expand All @@ -337,7 +337,7 @@ class MockComponent:
def run(self, value: int):
return {"value": 1}

async def async_run(self, value: int):
async def run_async(self, value: int):
return {"value": "1"}

with pytest.raises(ComponentError, match=r"Output type specifications .* must be the same"):
Expand Down