From f2e2185f777dbe8f01eb0516d2364e8359d456ca Mon Sep 17 00:00:00 2001 From: Maxime Gasse Date: Fri, 10 May 2024 15:24:13 -0400 Subject: [PATCH] black format --- core/src/browsergym/core/action/parsers.py | 10 ++++++++-- core/tests/test_actions_highlevel.py | 18 +++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/core/src/browsergym/core/action/parsers.py b/core/src/browsergym/core/action/parsers.py index 607ac790..f094d5c8 100644 --- a/core/src/browsergym/core/action/parsers.py +++ b/core/src/browsergym/core/action/parsers.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any + @dataclass class NamedArgument: name: str @@ -12,6 +13,7 @@ class NamedArgument: def __repr__(self): return f"{self.name}={repr(self.value)}" + def _build_highlevel_action_parser() -> pp.ParserElement: """ Returns: @@ -53,7 +55,9 @@ def literal_eval(toks): list_items = pp.DelimitedList(element, allow_trailing_delim=True).set_name(None) list << pp.Group(LBRACK + pp.Optional(list_items) + RBRACK, aslist=True) - _tuple << pp.Group(LPAREN + pp.Optional(list_items) + RPAREN, aslist=True).set_parse_action(lambda tokens: tuple(tokens[0])) + _tuple << pp.Group(LPAREN + pp.Optional(list_items) + RPAREN, aslist=True).set_parse_action( + lambda tokens: tuple(tokens[0]) + ) dict_item = pp.Group(string + COLON + element, aslist=True).set_name("dict item") dict_items = pp.DelimitedList(dict_item, allow_trailing_delim=True).set_name(None) @@ -61,7 +65,9 @@ def literal_eval(toks): arg = element list_args = pp.DelimitedList(arg, allow_trailing_delim=True).set_name(None) - named_arg = (pp.pyparsing_common.identifier() + pp.Literal("=") + element).set_parse_action(lambda tokens: NamedArgument(name=tokens[0], value=tokens[2])) + named_arg = (pp.pyparsing_common.identifier() + pp.Literal("=") + element).set_parse_action( + lambda tokens: NamedArgument(name=tokens[0], value=tokens[2]) + ) list_named_args = pp.DelimitedList(named_arg, allow_trailing_delim=True).set_name(None) function_call = pp.pyparsing_common.identifier() + pp.Group( LPAREN + pp.Optional(list_args) + pp.Optional(list_named_args) + RPAREN, aslist=True diff --git a/core/tests/test_actions_highlevel.py b/core/tests/test_actions_highlevel.py index 8ba5943c..76995b4d 100644 --- a/core/tests/test_actions_highlevel.py +++ b/core/tests/test_actions_highlevel.py @@ -62,11 +62,19 @@ def test_action_parser(): function_calls = parser.parse_string('a(x=12, y = 12.2, other = "text")', parseAll=True) _, function_args = function_calls[0] - assert function_args == [NamedArgument(name='x', value=12), NamedArgument(name="y", value=12.2), NamedArgument(name="other", value="text")] + assert function_args == [ + NamedArgument(name="x", value=12), + NamedArgument(name="y", value=12.2), + NamedArgument(name="other", value="text"), + ] function_calls = parser.parse_string('a(12, y = 12.2, other = "text")', parseAll=True) _, function_args = function_calls[0] - assert function_args == [12, NamedArgument(name="y", value=12.2), NamedArgument(name="other", value="text")] + assert function_args == [ + 12, + NamedArgument(name="y", value=12.2), + NamedArgument(name="other", value="text"), + ] with pytest.raises(ParseException): function_calls = parser.parse_string('a(x = 12, 12.2, other = "text")', parseAll=True) @@ -97,7 +105,11 @@ def test_action_parser(): function_calls = parser.parse_string('fun(12, x="val", y={"aaa": 23})', parseAll=True) function_name, function_args = function_calls[0] assert function_name == "fun" - assert function_args == [12, NamedArgument(name="x", value="val"), NamedArgument(name="y", value={"aaa": 23})] + assert function_args == [ + 12, + NamedArgument(name="x", value="val"), + NamedArgument(name="y", value={"aaa": 23}), + ] def test_valid_action():