Skip to content

Commit

Permalink
Update from_string method
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Jan 9, 2025
1 parent 94f9fd1 commit a825868
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from warnings import warn

from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment, NativeTemplate
from jinja2.nativetypes import NativeEnvironment, NativeTemplate, Template
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_from_dict, default_to_dict, logging
Expand All @@ -17,7 +17,7 @@

logger = logging.getLogger(__name__)

haystack_dataclass_types = (ByteStream, Document, ChatMessage, Answer, SparseEmbedding, StreamingChunk)
haystack_dataclass_types = (ByteStream, ChatMessage, Document, Answer, SparseEmbedding, StreamingChunk)


class NoRouteSelectedException(Exception):
Expand All @@ -28,7 +28,7 @@ class RouteConditionException(Exception):
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""


class NativeSandboxedTemplate(NativeTemplate):
class NativeSandboxedTemplate(NativeTemplate, Template):
"""
A template class that returns native Python objects and also respects the sandbox security checks.
"""
Expand All @@ -42,25 +42,25 @@ class NativeSandboxedEnvironment(SandboxedEnvironment, NativeEnvironment):
"""

# We tell the environment to use our custom template class by default.
template_class = NativeSandboxedTemplate

def from_string(self, source, template_class=None):
def from_string(self, source):
"""
Override from_string to ensure the sandbox logic + native logic are used together.
"""
if template_class is None:
template_class = self.template_class
template_class = NativeSandboxedTemplate

return SandboxedEnvironment.from_string(self, source, template_class=template_class)

def is_safe_attribute(self, obj):
def is_safe_attribute(self, obj, attr="", value=""):
"""
Whitelist Haystack dataclasses so the sandbox won't block them.
"""
if isinstance(obj, haystack_dataclass_types):
return True

if not isinstance(obj, haystack_dataclass_types):
return False

# Otherwise, fallback to the default sandbox behavior
return super().is_safe_attribute(obj)
return SandboxedEnvironment.is_safe_attribute(self, obj, attr, value)


@component
Expand Down Expand Up @@ -236,14 +236,6 @@ def __init__( # pylint: disable=too-many-positional-arguments
self._custom_env = NativeSandboxedEnvironment()
self._env.filters.update(self.custom_filters)

# Add custom types to the custom environment
self._custom_env.globals["Document"] = Document
self._custom_env.globals["ChatMessage"] = ChatMessage
self._custom_env.globals["ByteStream"] = ByteStream
self._custom_env.globals["Answer"] = Answer
self._custom_env.globals["SparseEmbedding"] = SparseEmbedding
self._custom_env.globals["StreamingChunk"] = StreamingChunk

self._validate_routes(routes)
# Inspect the routes to determine input and output types.
input_types: Set[str] = set() # let's just store the name, type will always be Any
Expand Down Expand Up @@ -359,13 +351,21 @@ def run(self, **kwargs):
t_output = self._custom_env.from_string(route["output"])
output = t_output.render(**kwargs)

# Check if output is a list/sequence and validate accordingly
if isinstance(output, (list, tuple)):
if all(self._custom_env.is_safe_attribute(item) for item in output):
pass
elif self._custom_env.is_safe_attribute(output):
pass

# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
if not self._unsafe and isinstance(output, str):
output = ast.literal_eval(output)
else:
with contextlib.suppress(Exception):
if not self._unsafe and isinstance(output, str):
output = ast.literal_eval(output)
except Exception as e:
msg = f"Error evaluating condition for route '{route}': {e}"
raise RouteConditionException(msg) from e
Expand Down

0 comments on commit a825868

Please sign in to comment.