From 3864036041b39aaf28193d3e00822ed59cb6a65c Mon Sep 17 00:00:00 2001 From: Maxime Gasse Date: Fri, 10 May 2024 16:09:18 -0400 Subject: [PATCH] tuple and named arguments in high-level action parser (#10) --- core/src/browsergym/core/action/parsers.py | 22 ++++++++++++++-- core/tests/test_actions_highlevel.py | 30 ++++++++++++++++++---- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/core/src/browsergym/core/action/parsers.py b/core/src/browsergym/core/action/parsers.py index 63f606c6..f094d5c8 100644 --- a/core/src/browsergym/core/action/parsers.py +++ b/core/src/browsergym/core/action/parsers.py @@ -1,6 +1,18 @@ import ast import pyparsing as pp +from dataclasses import dataclass +from typing import Any + + +@dataclass +class NamedArgument: + name: str + value: Any + + def __repr__(self): + return f"{self.name}={repr(self.value)}" + def _build_highlevel_action_parser() -> pp.ParserElement: """ @@ -38,10 +50,14 @@ def literal_eval(toks): number = pp.pyparsing_common.number() dict = pp.Forward().set_name("dict") # will be defined later list = pp.Forward().set_name("list") # will be defined later - element = (string | number | dict | list | TRUE | FALSE | NONE).set_name("element") + _tuple = pp.Forward().set_name("tuple") # will be defined later + element = (string | number | dict | list | _tuple | TRUE | FALSE | NONE).set_name("element") list_items = pp.DelimitedList(element, allow_trailing_delim=True).set_name(None) list << pp.Group(LBRACK + pp.Optional(list_items) + RBRACK, aslist=True) + _tuple << pp.Group(LPAREN + pp.Optional(list_items) + RPAREN, aslist=True).set_parse_action( + lambda tokens: tuple(tokens[0]) + ) dict_item = pp.Group(string + COLON + element, aslist=True).set_name("dict item") dict_items = pp.DelimitedList(dict_item, allow_trailing_delim=True).set_name(None) @@ -49,7 +65,9 @@ def literal_eval(toks): arg = element list_args = pp.DelimitedList(arg, allow_trailing_delim=True).set_name(None) - named_arg = (pp.pyparsing_common.identifier() + pp.Literal("=")).suppress() + element + named_arg = (pp.pyparsing_common.identifier() + pp.Literal("=") + element).set_parse_action( + lambda tokens: NamedArgument(name=tokens[0], value=tokens[2]) + ) list_named_args = pp.DelimitedList(named_arg, allow_trailing_delim=True).set_name(None) function_call = pp.pyparsing_common.identifier() + pp.Group( LPAREN + pp.Optional(list_args) + pp.Optional(list_named_args) + RPAREN, aslist=True diff --git a/core/tests/test_actions_highlevel.py b/core/tests/test_actions_highlevel.py index 779746c2..76995b4d 100644 --- a/core/tests/test_actions_highlevel.py +++ b/core/tests/test_actions_highlevel.py @@ -17,7 +17,7 @@ from browsergym.utils.obs import flatten_dom_to_str from browsergym.core.action.highlevel import HighLevelActionSet -from browsergym.core.action.parsers import highlevel_action_parser +from browsergym.core.action.parsers import highlevel_action_parser, NamedArgument from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR @@ -56,21 +56,32 @@ def test_action_parser(): function_calls = parser.parse_string(" a ( ) b() \n \tc()", parseAll=True) assert [function_name for function_name, _ in function_calls] == ["a", "b", "c"] - function_calls = parser.parse_string('a(12, 12.2, "text")', parseAll=True) + function_calls = parser.parse_string('a(12, 12.2, "text", (1, 2, 3), ["a", 23])', parseAll=True) _, function_args = function_calls[0] - assert function_args == [12, 12.2, "text"] + assert function_args == [12, 12.2, "text", (1, 2, 3), ["a", 23]] function_calls = parser.parse_string('a(x=12, y = 12.2, other = "text")', parseAll=True) _, function_args = function_calls[0] - assert function_args == [12, 12.2, "text"] + assert function_args == [ + NamedArgument(name="x", value=12), + NamedArgument(name="y", value=12.2), + NamedArgument(name="other", value="text"), + ] function_calls = parser.parse_string('a(12, y = 12.2, other = "text")', parseAll=True) _, function_args = function_calls[0] - assert function_args == [12, 12.2, "text"] + assert function_args == [ + 12, + NamedArgument(name="y", value=12.2), + NamedArgument(name="other", value="text"), + ] with pytest.raises(ParseException): function_calls = parser.parse_string('a(x = 12, 12.2, other = "text")', parseAll=True) + with pytest.raises(ParseException): + function_calls = parser.parse_string('a(12, 12.2, 1 = "text")', parseAll=True) + with pytest.raises(ParseException): function_calls = parser.parse_string("a(1-)", parseAll=True) @@ -91,6 +102,15 @@ def test_action_parser(): assert function_name == "a" assert function_args == ["# not comment"] + function_calls = parser.parse_string('fun(12, x="val", y={"aaa": 23})', parseAll=True) + function_name, function_args = function_calls[0] + assert function_name == "fun" + assert function_args == [ + 12, + NamedArgument(name="x", value="val"), + NamedArgument(name="y", value={"aaa": 23}), + ] + def test_valid_action(): action_set = HighLevelActionSet()