Skip to content

Commit

Permalink
Merge pull request #10 from OpenZeppelin/onboard-ruff
Browse files Browse the repository at this point in the history
Onboard `ruff`
  • Loading branch information
0xGeorgii authored Feb 8, 2024
2 parents 31a439a + 5f3d751 commit eb88b2c
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 83 deletions.
32 changes: 32 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Changelog

- [0.0.4](#004)

## 0.0.4

### [Fix incorrect state mutability identification](https://github.com/OpenZeppelin/sgp/pull/9)

This PR fixed incorrect state mutability identification for various structures, such as:

**`const` state mutability for the state variable declaration**

```solidity
contract Example {
uint256 constant x = 1;
}
```

**`payable` state mutability for the function return parameter variable declaration**

```solidity
function test() public returns(address payable) {}
```

Now these cases are handled correctly, fixing the bug that was present in version `0.0.3`.

### [CHANGELOG.md](/CHANGELOG.md) added

### [Onboard `ruff`](https://github.com/OpenZeppelin/sgp/pull/10)

- Added `ruff` for checking and formatting the `py` files.
- Added a `ruff_helper.sh` script.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "openzeppelin-solidity-grammar-parser"
version = "0.0.3"
version = "0.0.4"
authors = [{ name = "Georgii Plotnikov", email = "[email protected]" }]
description = "Solidity ANTLR4 grammar Python parser"
readme = "README.md"
Expand Down Expand Up @@ -49,3 +49,6 @@ exclude = '''
)/
)
'''

[tool.ruff]
exclude = ["./sgp/parser/"]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ coverage==7.3.1
simplejson==3.19.1
typing==3.7.4.3
typing_extensions==4.8.0
ruff==0.2.1
7 changes: 7 additions & 0 deletions ruff_helper.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash

source ./bin/activate
pip install ruff

ruff check .
ruff format .
2 changes: 1 addition & 1 deletion sgp/ast_node_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union


class Position:
Expand Down
4 changes: 3 additions & 1 deletion sgp/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sgp_parser


def main():
input = """
// SPDX-License-Identifier: MIT
Expand All @@ -19,5 +20,6 @@ def main():
except Exception as e:
print(e)

if __name__ == '__main__':

if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions sgp/sgp_error_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing_extensions import override
from antlr4.error.ErrorListener import ErrorListener


class SGPError:
def __init__(self, message, line, column):
self.message = message
Expand All @@ -11,6 +12,7 @@ def __init__(self, message, line, column):
def __str__(self):
return f"{self.message} ({self.line}:{self.column})"


class SGPErrorListener(ErrorListener):
def __init__(self):
super().__init__()
Expand Down
9 changes: 4 additions & 5 deletions sgp/sgp_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import simplejson
from typing import Dict

from antlr4.CommonTokenStream import CommonTokenStream
from antlr4.InputStream import InputStream as ANTLRInputStream
Expand All @@ -17,14 +16,14 @@

class ParserError(Exception):
"""
An exception raised when the parser encounters an error.
An exception raised when the parser encounters an error.
"""

def __init__(self, errors) -> None:
"""
Parameters
----------
errors : List[Dict[str, Any]] - A list of errors encountered by the parser.
errors : List[Dict[str, Any]] - A list of errors encountered by the parser.
"""
super().__init__()
error = errors[0]
Expand All @@ -50,7 +49,7 @@ def parse(
Returns
-------
SourceUnit - The root of an AST of the Solidity source string.
SourceUnit - The root of an AST of the Solidity source string.
"""

input_stream = ANTLRInputStream(input_string)
Expand All @@ -69,7 +68,7 @@ def parse(
ast_builder = SGPVisitor(options)
try:
source_unit: SourceUnit = ast_builder.visit(source_unit)
except Exception as e:
except Exception:
raise Exception("AST was not generated")
else:
if source_unit is None:
Expand Down
4 changes: 3 additions & 1 deletion sgp/sgp_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,9 @@ def visitAssemblyCall(self, ctx: SP.AssemblyCallContext) -> AssemblyCall:

return self._add_meta(node, ctx)

def visitAssemblyLiteral(self, ctx: SP.AssemblyLiteralContext) -> Union[
def visitAssemblyLiteral(
self, ctx: SP.AssemblyLiteralContext
) -> Union[
StringLiteral,
BooleanLiteral,
DecimalNumber,
Expand Down
96 changes: 55 additions & 41 deletions sgp/tokens.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Dict, Any


def rsplit(input_string: str, value: str) -> List[str]:
index = input_string.rfind(value)
return [input_string[:index], input_string[index + 1:]]
return [input_string[:index], input_string[index + 1 :]]


def normalize_token_type(value: str) -> str:
if value.endswith("'"):
Expand All @@ -11,68 +13,80 @@ def normalize_token_type(value: str) -> str:
value = value[1:]
return value


def get_token_type(value: str) -> str:
TYPE_TOKENS = [
'var',
'bool',
'address',
'string',
'Int',
'Uint',
'Byte',
'Fixed',
'UFixed',
"var",
"bool",
"address",
"string",
"Int",
"Uint",
"Byte",
"Fixed",
"UFixed",
]

if value in ['Identifier', 'from']:
return 'Identifier'
elif value in ['TrueLiteral', 'FalseLiteral']:
return 'Boolean'
elif value == 'VersionLiteral':
return 'Version'
elif value == 'StringLiteral':
return 'String'
if value in ["Identifier", "from"]:
return "Identifier"
elif value in ["TrueLiteral", "FalseLiteral"]:
return "Boolean"
elif value == "VersionLiteral":
return "Version"
elif value == "StringLiteral":
return "String"
elif value in TYPE_TOKENS:
return 'Type'
elif value == 'NumberUnit':
return 'Subdenomination'
elif value == 'DecimalNumber':
return 'Numeric'
elif value == 'HexLiteral':
return 'Hex'
elif value == 'ReservedKeyword':
return 'Reserved'
return "Type"
elif value == "NumberUnit":
return "Subdenomination"
elif value == "DecimalNumber":
return "Numeric"
elif value == "HexLiteral":
return "Hex"
elif value == "ReservedKeyword":
return "Reserved"
elif not value.isalnum():
return 'Punctuator'
return "Punctuator"
else:
return 'Keyword'
return "Keyword"


def get_token_type_map(tokens: str) -> Dict[int, str]:
lines = tokens.split('\n')
lines = tokens.split("\n")
token_map = {}

for line in lines:
value, key = rsplit(line, '=')
value, key = rsplit(line, "=")
token_map[int(key)] = normalize_token_type(value)

return token_map

#TODO: sort it out
def build_token_list(tokens_arg: List[Dict[str, Any]], options: Dict[str, Any]) -> List[Dict[str, Any]]:

# TODO: sort it out
def build_token_list(
tokens_arg: List[Dict[str, Any]], options: Dict[str, Any]
) -> List[Dict[str, Any]]:
token_types = get_token_type_map(tokens_arg)
result = []

for token in tokens_arg:
type_str = get_token_type(token_types[token['type']])
node = {'type': type_str, 'value': token['text']}
type_str = get_token_type(token_types[token["type"]])
node = {"type": type_str, "value": token["text"]}

if options.get('range', False):
node['range'] = [token['startIndex'], token['stopIndex'] + 1]
if options.get("range", False):
node["range"] = [token["startIndex"], token["stopIndex"] + 1]

if options.get('loc', False):
node['loc'] = {
'start': {'line': token['line'], 'column': token['charPositionInLine']},
'end': {'line': token['line'], 'column': token['charPositionInLine'] + len(token['text']) if token['text'] else 0}
if options.get("loc", False):
node["loc"] = {
"start": {"line": token["line"], "column": token["charPositionInLine"]},
"end": {
"line": token["line"],
"column": (
token["charPositionInLine"] + len(token["text"])
if token["text"]
else 0
),
},
}

result.append(node)
Expand Down
3 changes: 0 additions & 3 deletions sgp/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Any


def string_from_snake_to_camel_case(input: str) -> str:
"""
Convert a string from snake_case to camelCase.
Expand Down
22 changes: 1 addition & 21 deletions test/test_misc/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,7 @@

class TestMisc(unittest.TestCase):
def test_misc(self) -> None:
input = """contract Example4 {
/// @custom:storage-location erc7201:example.main
struct MainStorage {
uint256 x;
uint256 y;
}
// keccak256(abi.encode(uint256(keccak256("example.main")) - 1)) & ~bytes32(uint256(0xff));
bytes32 private constant MAIN_STORAGE_LOCATION =
0x183a6125c38840424c4a85fa12bab2ab606c4b6d0e7cc73c0c06ba5300eab500;
uint256 constant x = 1;
function _getMainStorage() private pure returns (MainStorage storage $) {
assembly {
$.slot := MAIN_STORAGE_LOCATION
}
}
}"""
input = """function add_your_solidity_code_here {}"""

ast = parse(input)
self.assertIsNotNone(ast)
self.assertTrue(ast.children[0].children[1].variables[0].is_declared_const)
self.assertTrue(ast.children[0].children[2].variables[0].is_declared_const)
16 changes: 7 additions & 9 deletions test/test_parsing/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from sgp.sgp_parser import parse
from sgp.utils import string_from_snake_to_camel_case

class TestParsing(unittest.TestCase):

class TestParsing(unittest.TestCase):
def test_parsing(self):
current_directory = pathlib.Path(__file__).parent.resolve()

Expand All @@ -23,12 +23,10 @@ def test_parsing(self):

res = parse(test_content, dump_json=True)
ast_actual = simplejson.dumps(
res,
default=lambda obj: {
string_from_snake_to_camel_case(k): v
for k, v in obj.__dict__.items()
},
)

res,
default=lambda obj: {
string_from_snake_to_camel_case(k): v for k, v in obj.__dict__.items()
},
)

self.assertEqual(ast_expected, ast_actual)

0 comments on commit eb88b2c

Please sign in to comment.