diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..c39f42e
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,32 @@
+# Changelog
+
+- [0.0.4](#004)
+
+## 0.0.4
+
+### [Fix incorrect state mutability identification](https://github.com/OpenZeppelin/sgp/pull/9)
+
+This PR fixes incorrect state mutability identification for various constructs, such as:
+
+**`const` state mutability for state variable declarations**
+
+```solidity
+contract Example {
+    uint256 constant x = 1;
+}
+```
+
+**`payable` state mutability for function return parameter declarations**
+
+```solidity
+function test() public returns(address payable) {}
+```
+
+These cases, which were broken in version `0.0.3`, are now handled correctly.
+
+### [CHANGELOG.md](/CHANGELOG.md) added
+
+### [Onboard `ruff`](https://github.com/OpenZeppelin/sgp/pull/10)
+
+- Use `ruff` to check and format the `.py` files.
+- Add a `ruff_helper.sh` script.
diff --git a/pyproject.toml b/pyproject.toml
index 9c65a70..edd34dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "openzeppelin-solidity-grammar-parser"
-version = "0.0.3"
+version = "0.0.4"
 authors = [{ name = "Georgii Plotnikov", email = "accembler@gmail.com" }]
 description = "Solidity ANTLR4 grammar Python parser"
 readme = "README.md"
@@ -49,3 +49,6 @@ exclude = '''
   )/
 )
 '''
+
+[tool.ruff]
+exclude = ["./sgp/parser/"]
diff --git a/requirements.txt b/requirements.txt
index 614c4c0..e0ac19c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ coverage==7.3.1
 simplejson==3.19.1
 typing==3.7.4.3
 typing_extensions==4.8.0
+ruff==0.2.1
diff --git a/ruff_helper.sh b/ruff_helper.sh
new file mode 100755
index 0000000..c5fcf2b
--- /dev/null
+++ b/ruff_helper.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+source ./bin/activate
+pip install ruff
+
+ruff check .
+ruff format .
diff --git a/sgp/ast_node_types.py b/sgp/ast_node_types.py
index c2ff46a..3178b41 100644
--- a/sgp/ast_node_types.py
+++ b/sgp/ast_node_types.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Tuple, Union, Dict
+from typing import List, Optional, Tuple, Union
 
 
 class Position:
diff --git a/sgp/main.py b/sgp/main.py
index 732a1b3..e9b25c3 100644
--- a/sgp/main.py
+++ b/sgp/main.py
@@ -1,5 +1,6 @@
 import sgp_parser
 
+
 def main():
     input = """
 // SPDX-License-Identifier: MIT
@@ -19,5 +20,6 @@ def main():
     except Exception as e:
         print(e)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
diff --git a/sgp/sgp_error_listener.py b/sgp/sgp_error_listener.py
index eaf7b11..2d9f669 100644
--- a/sgp/sgp_error_listener.py
+++ b/sgp/sgp_error_listener.py
@@ -2,6 +2,7 @@
 from typing_extensions import override
 from antlr4.error.ErrorListener import ErrorListener
 
+
 class SGPError:
     def __init__(self, message, line, column):
         self.message = message
@@ -11,6 +12,7 @@ def __init__(self, message, line, column):
     def __str__(self):
         return f"{self.message} ({self.line}:{self.column})"
 
+
 class SGPErrorListener(ErrorListener):
     def __init__(self):
         super().__init__()
diff --git a/sgp/sgp_parser.py b/sgp/sgp_parser.py
index b46c6ae..e7d7bb7 100644
--- a/sgp/sgp_parser.py
+++ b/sgp/sgp_parser.py
@@ -1,6 +1,5 @@
 import os
 import simplejson
-from typing import Dict
 
 from antlr4.CommonTokenStream import CommonTokenStream
 from antlr4.InputStream import InputStream as ANTLRInputStream
@@ -17,14 +16,14 @@ class ParserError(Exception):
     """
-    An exception raised when the parser encounters an error. 
+    An exception raised when the parser encounters an error.
     """
 
     def __init__(self, errors) -> None:
         """
         Parameters
         ----------
-        errors : List[Dict[str, Any]] - A list of errors encountered by the parser. 
+        errors : List[Dict[str, Any]] - A list of errors encountered by the parser.
         """
         super().__init__()
         error = errors[0]
@@ -50,7 +49,7 @@ def parse(
 
     Returns
     -------
-    SourceUnit - The root of an AST of the Solidity source string. 
+    SourceUnit - The root of an AST of the Solidity source string.
     """
 
     input_stream = ANTLRInputStream(input_string)
@@ -69,7 +68,7 @@ def parse(
     ast_builder = SGPVisitor(options)
     try:
         source_unit: SourceUnit = ast_builder.visit(source_unit)
-    except Exception as e:
+    except Exception:
         raise Exception("AST was not generated")
     else:
         if source_unit is None:
diff --git a/sgp/sgp_visitor.py b/sgp/sgp_visitor.py
index b805cde..a271a94 100644
--- a/sgp/sgp_visitor.py
+++ b/sgp/sgp_visitor.py
@@ -1392,7 +1392,9 @@ def visitAssemblyCall(self, ctx: SP.AssemblyCallContext) -> AssemblyCall:
 
         return self._add_meta(node, ctx)
 
-    def visitAssemblyLiteral(self, ctx: SP.AssemblyLiteralContext) -> Union[
+    def visitAssemblyLiteral(
+        self, ctx: SP.AssemblyLiteralContext
+    ) -> Union[
         StringLiteral,
         BooleanLiteral,
         DecimalNumber,
diff --git a/sgp/tokens.py b/sgp/tokens.py
index 059eeb9..11eb58e 100644
--- a/sgp/tokens.py
+++ b/sgp/tokens.py
@@ -1,8 +1,10 @@
 from typing import List, Dict, Any
 
+
 def rsplit(input_string: str, value: str) -> List[str]:
     index = input_string.rfind(value)
-    return [input_string[:index], input_string[index + 1:]]
+    return [input_string[:index], input_string[index + 1 :]]
+
 
 def normalize_token_type(value: str) -> str:
     if value.endswith("'"):
@@ -11,68 +13,80 @@ def normalize_token_type(value: str) -> str:
         value = value[1:]
     return value
 
+
 def get_token_type(value: str) -> str:
     TYPE_TOKENS = [
-        'var',
-        'bool',
-        'address',
-        'string',
-        'Int',
-        'Uint',
-        'Byte',
-        'Fixed',
-        'UFixed',
+        "var",
+        "bool",
+        "address",
+        "string",
+        "Int",
+        "Uint",
+        "Byte",
+        "Fixed",
+        "UFixed",
     ]
-    if value in ['Identifier', 'from']:
-        return 'Identifier'
-    elif value in ['TrueLiteral', 'FalseLiteral']:
-        return 'Boolean'
-    elif value == 'VersionLiteral':
-        return 'Version'
-    elif value == 'StringLiteral':
-        return 'String'
+    if value in ["Identifier", "from"]:
+        return "Identifier"
+    elif value in ["TrueLiteral", "FalseLiteral"]:
+        return "Boolean"
+    elif value == "VersionLiteral":
+        return "Version"
+    elif value == "StringLiteral":
+        return "String"
     elif value in TYPE_TOKENS:
-        return 'Type'
-    elif value == 'NumberUnit':
-        return 'Subdenomination'
-    elif value == 'DecimalNumber':
-        return 'Numeric'
-    elif value == 'HexLiteral':
-        return 'Hex'
-    elif value == 'ReservedKeyword':
-        return 'Reserved'
+        return "Type"
+    elif value == "NumberUnit":
+        return "Subdenomination"
+    elif value == "DecimalNumber":
+        return "Numeric"
+    elif value == "HexLiteral":
+        return "Hex"
+    elif value == "ReservedKeyword":
+        return "Reserved"
     elif not value.isalnum():
-        return 'Punctuator'
+        return "Punctuator"
     else:
-        return 'Keyword'
+        return "Keyword"
+
 
 def get_token_type_map(tokens: str) -> Dict[int, str]:
-    lines = tokens.split('\n')
+    lines = tokens.split("\n")
     token_map = {}
     for line in lines:
-        value, key = rsplit(line, '=')
+        value, key = rsplit(line, "=")
         token_map[int(key)] = normalize_token_type(value)
     return token_map
 
-#TODO: sort it out
-def build_token_list(tokens_arg: List[Dict[str, Any]], options: Dict[str, Any]) -> List[Dict[str, Any]]:
+
+# TODO: sort it out
+def build_token_list(
+    tokens_arg: List[Dict[str, Any]], options: Dict[str, Any]
+) -> List[Dict[str, Any]]:
     token_types = get_token_type_map(tokens_arg)
     result = []
     for token in tokens_arg:
-        type_str = get_token_type(token_types[token['type']])
-        node = {'type': type_str, 'value': token['text']}
+        type_str = get_token_type(token_types[token["type"]])
+        node = {"type": type_str, "value": token["text"]}
 
-        if options.get('range', False):
-            node['range'] = [token['startIndex'], token['stopIndex'] + 1]
+        if options.get("range", False):
+            node["range"] = [token["startIndex"], token["stopIndex"] + 1]
 
-        if options.get('loc', False):
-            node['loc'] = {
-                'start': {'line': token['line'], 'column': token['charPositionInLine']},
-                'end': {'line': token['line'], 'column': token['charPositionInLine'] + len(token['text']) if token['text'] else 0}
+        if options.get("loc", False):
+            node["loc"] = {
+                "start": {"line": token["line"], "column": token["charPositionInLine"]},
+                "end": {
+                    "line": token["line"],
+                    "column": (
+                        token["charPositionInLine"] + len(token["text"])
+                        if token["text"]
+                        else 0
+                    ),
+                },
             }
 
         result.append(node)
diff --git a/sgp/utils.py b/sgp/utils.py
index 5da9129..c4cb94c 100644
--- a/sgp/utils.py
+++ b/sgp/utils.py
@@ -1,6 +1,3 @@
-from typing import Any
-
-
 def string_from_snake_to_camel_case(input: str) -> str:
     """
     Convert a string from snake_case to camelCase.
diff --git a/test/test_misc/test_misc.py b/test/test_misc/test_misc.py
index 819fd38..7711eca 100644
--- a/test/test_misc/test_misc.py
+++ b/test/test_misc/test_misc.py
@@ -5,27 +5,7 @@
 
 class TestMisc(unittest.TestCase):
     def test_misc(self) -> None:
-        input = """contract Example4 {
-    /// @custom:storage-location erc7201:example.main
-    struct MainStorage {
-        uint256 x;
-        uint256 y;
-    }
-
-    // keccak256(abi.encode(uint256(keccak256("example.main")) - 1)) & ~bytes32(uint256(0xff));
-    bytes32 private constant MAIN_STORAGE_LOCATION =
-        0x183a6125c38840424c4a85fa12bab2ab606c4b6d0e7cc73c0c06ba5300eab500;
-
-    uint256 constant x = 1;
-
-    function _getMainStorage() private pure returns (MainStorage storage $) {
-        assembly {
-            $.slot := MAIN_STORAGE_LOCATION
-        }
-    }
-}"""
+        input = """function add_your_solidity_code_here {}"""
 
         ast = parse(input)
         self.assertIsNotNone(ast)
-        self.assertTrue(ast.children[0].children[1].variables[0].is_declared_const)
-        self.assertTrue(ast.children[0].children[2].variables[0].is_declared_const)
diff --git a/test/test_parsing/test_parsing.py b/test/test_parsing/test_parsing.py
index 3da1618..d35264a 100644
--- a/test/test_parsing/test_parsing.py
+++ b/test/test_parsing/test_parsing.py
@@ -6,8 +6,8 @@
 from sgp.sgp_parser import parse
 from sgp.utils import string_from_snake_to_camel_case
 
-class TestParsing(unittest.TestCase):
 
+class TestParsing(unittest.TestCase):
     def test_parsing(self):
         current_directory = pathlib.Path(__file__).parent.resolve()
@@ -23,12 +23,10 @@ def test_parsing(self):
         res = parse(test_content, dump_json=True)
 
         ast_actual = simplejson.dumps(
-            res,
-            default=lambda obj: {
-                string_from_snake_to_camel_case(k): v
-                for k, v in obj.__dict__.items()
-            },
-        )
-
+            res,
+            default=lambda obj: {
+                string_from_snake_to_camel_case(k): v for k, v in obj.__dict__.items()
+            },
+        )
+
         self.assertEqual(ast_expected, ast_actual)
-
\ No newline at end of file
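
For reference, a minimal sketch of how the `0.0.4` fix can be exercised from Python. It assumes the `parse` entry point used in the tests above (`from sgp.sgp_parser import parse`) and mirrors the AST navigation from the removed `test_misc.py` assertions (`ast.children[...].variables[0].is_declared_const`); the exact child indices depend on the contract's contents and are illustrative, not guaranteed.

```python
from sgp.sgp_parser import parse

# Solidity snippet from the 0.0.4 changelog entry: a state variable
# declared `constant`.
source = """
contract Example {
    uint256 constant x = 1;
}
"""

# Parse the source into an AST; `parse` returns the root SourceUnit node.
ast = parse(source)

# Assumed layout for this single-member contract: the contract definition
# is the first child of the source unit, and the state variable declaration
# is its first child. In 0.0.3 the `const` state mutability flag was
# reported incorrectly; in 0.0.4 it is expected to be True.
state_var = ast.children[0].children[0].variables[0]
print(state_var.is_declared_const)
```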