Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onboard ruff #10

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Changelog

- [0.0.4](#004)

## 0.0.4

### [Fix incorrect state mutability identification](https://github.com/OpenZeppelin/sgp/pull/9)

This PR fixed incorrect state mutability identification for the various structures, such as:

**`const` state mutability for the state variable declaration**

```solidity
contract Example {
uint256 constant x = 1;
}
```

**`payable` state mutability for the function return parameter variable declaration**

```solidity
function test() public returns(address payable) {}
```

Now, these cases are handled correctly after the bug in the version `0.0.3`.

### [CHANGELOG.md](/CHANGELOG.md) added

### [Onboard `ruff](https://github.com/OpenZeppelin/sgp/pull/10)

- Use `ruff` for the `py` files check and formatting.
- Added a `ruff_helper.sh` script
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "openzeppelin-solidity-grammar-parser"
version = "0.0.3"
version = "0.0.4"
authors = [{ name = "Georgii Plotnikov", email = "[email protected]" }]
description = "Solidity ANTLR4 grammar Python parser"
readme = "README.md"
Expand Down Expand Up @@ -49,3 +49,6 @@ exclude = '''
)/
)
'''

[tool.ruff]
exclude = ["./sgp/parser/"]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ coverage==7.3.1
simplejson==3.19.1
typing==3.7.4.3
typing_extensions==4.8.0
ruff==0.2.1
7 changes: 7 additions & 0 deletions ruff_helper.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash

source ./bin/activate
pip install ruff

ruff check .
ruff format .
2 changes: 1 addition & 1 deletion sgp/ast_node_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union


class Position:
Expand Down
4 changes: 3 additions & 1 deletion sgp/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sgp_parser


def main():
input = """
// SPDX-License-Identifier: MIT
Expand All @@ -19,5 +20,6 @@ def main():
except Exception as e:
print(e)

if __name__ == '__main__':

if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions sgp/sgp_error_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing_extensions import override
from antlr4.error.ErrorListener import ErrorListener


class SGPError:
def __init__(self, message, line, column):
self.message = message
Expand All @@ -11,6 +12,7 @@ def __init__(self, message, line, column):
def __str__(self):
return f"{self.message} ({self.line}:{self.column})"


class SGPErrorListener(ErrorListener):
def __init__(self):
super().__init__()
Expand Down
9 changes: 4 additions & 5 deletions sgp/sgp_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import simplejson
from typing import Dict

from antlr4.CommonTokenStream import CommonTokenStream
from antlr4.InputStream import InputStream as ANTLRInputStream
Expand All @@ -17,14 +16,14 @@

class ParserError(Exception):
"""
An exception raised when the parser encounters an error.
An exception raised when the parser encounters an error.
"""

def __init__(self, errors) -> None:
"""
Parameters
----------
errors : List[Dict[str, Any]] - A list of errors encountered by the parser.
errors : List[Dict[str, Any]] - A list of errors encountered by the parser.
"""
super().__init__()
error = errors[0]
Expand All @@ -50,7 +49,7 @@ def parse(

Returns
-------
SourceUnit - The root of an AST of the Solidity source string.
SourceUnit - The root of an AST of the Solidity source string.
"""

input_stream = ANTLRInputStream(input_string)
Expand All @@ -69,7 +68,7 @@ def parse(
ast_builder = SGPVisitor(options)
try:
source_unit: SourceUnit = ast_builder.visit(source_unit)
except Exception as e:
except Exception:
raise Exception("AST was not generated")
else:
if source_unit is None:
Expand Down
4 changes: 3 additions & 1 deletion sgp/sgp_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,9 @@ def visitAssemblyCall(self, ctx: SP.AssemblyCallContext) -> AssemblyCall:

return self._add_meta(node, ctx)

def visitAssemblyLiteral(self, ctx: SP.AssemblyLiteralContext) -> Union[
def visitAssemblyLiteral(
self, ctx: SP.AssemblyLiteralContext
) -> Union[
StringLiteral,
BooleanLiteral,
DecimalNumber,
Expand Down
96 changes: 55 additions & 41 deletions sgp/tokens.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Dict, Any


def rsplit(input_string: str, value: str) -> List[str]:
index = input_string.rfind(value)
return [input_string[:index], input_string[index + 1:]]
return [input_string[:index], input_string[index + 1 :]]


def normalize_token_type(value: str) -> str:
if value.endswith("'"):
Expand All @@ -11,68 +13,80 @@ def normalize_token_type(value: str) -> str:
value = value[1:]
return value


def get_token_type(value: str) -> str:
TYPE_TOKENS = [
'var',
'bool',
'address',
'string',
'Int',
'Uint',
'Byte',
'Fixed',
'UFixed',
"var",
"bool",
"address",
"string",
"Int",
"Uint",
"Byte",
"Fixed",
"UFixed",
]

if value in ['Identifier', 'from']:
return 'Identifier'
elif value in ['TrueLiteral', 'FalseLiteral']:
return 'Boolean'
elif value == 'VersionLiteral':
return 'Version'
elif value == 'StringLiteral':
return 'String'
if value in ["Identifier", "from"]:
return "Identifier"
elif value in ["TrueLiteral", "FalseLiteral"]:
return "Boolean"
elif value == "VersionLiteral":
return "Version"
elif value == "StringLiteral":
return "String"
elif value in TYPE_TOKENS:
return 'Type'
elif value == 'NumberUnit':
return 'Subdenomination'
elif value == 'DecimalNumber':
return 'Numeric'
elif value == 'HexLiteral':
return 'Hex'
elif value == 'ReservedKeyword':
return 'Reserved'
return "Type"
elif value == "NumberUnit":
return "Subdenomination"
elif value == "DecimalNumber":
return "Numeric"
elif value == "HexLiteral":
return "Hex"
elif value == "ReservedKeyword":
return "Reserved"
elif not value.isalnum():
return 'Punctuator'
return "Punctuator"
else:
return 'Keyword'
return "Keyword"


def get_token_type_map(tokens: str) -> Dict[int, str]:
lines = tokens.split('\n')
lines = tokens.split("\n")
token_map = {}

for line in lines:
value, key = rsplit(line, '=')
value, key = rsplit(line, "=")
token_map[int(key)] = normalize_token_type(value)

return token_map

#TODO: sort it out
def build_token_list(tokens_arg: List[Dict[str, Any]], options: Dict[str, Any]) -> List[Dict[str, Any]]:

# TODO: sort it out
def build_token_list(
tokens_arg: List[Dict[str, Any]], options: Dict[str, Any]
) -> List[Dict[str, Any]]:
token_types = get_token_type_map(tokens_arg)
result = []

for token in tokens_arg:
type_str = get_token_type(token_types[token['type']])
node = {'type': type_str, 'value': token['text']}
type_str = get_token_type(token_types[token["type"]])
node = {"type": type_str, "value": token["text"]}

if options.get('range', False):
node['range'] = [token['startIndex'], token['stopIndex'] + 1]
if options.get("range", False):
node["range"] = [token["startIndex"], token["stopIndex"] + 1]

if options.get('loc', False):
node['loc'] = {
'start': {'line': token['line'], 'column': token['charPositionInLine']},
'end': {'line': token['line'], 'column': token['charPositionInLine'] + len(token['text']) if token['text'] else 0}
if options.get("loc", False):
node["loc"] = {
"start": {"line": token["line"], "column": token["charPositionInLine"]},
"end": {
"line": token["line"],
"column": (
token["charPositionInLine"] + len(token["text"])
if token["text"]
else 0
),
},
}

result.append(node)
Expand Down
3 changes: 0 additions & 3 deletions sgp/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Any


def string_from_snake_to_camel_case(input: str) -> str:
"""
Convert a string from snake_case to camelCase.
Expand Down
22 changes: 1 addition & 21 deletions test/test_misc/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,7 @@

class TestMisc(unittest.TestCase):
def test_misc(self) -> None:
input = """contract Example4 {
/// @custom:storage-location erc7201:example.main
struct MainStorage {
uint256 x;
uint256 y;
}

// keccak256(abi.encode(uint256(keccak256("example.main")) - 1)) & ~bytes32(uint256(0xff));
bytes32 private constant MAIN_STORAGE_LOCATION =
0x183a6125c38840424c4a85fa12bab2ab606c4b6d0e7cc73c0c06ba5300eab500;

uint256 constant x = 1;

function _getMainStorage() private pure returns (MainStorage storage $) {
assembly {
$.slot := MAIN_STORAGE_LOCATION
}
}
}"""
input = """function add_your_solidity_code_here {}"""

ast = parse(input)
self.assertIsNotNone(ast)
self.assertTrue(ast.children[0].children[1].variables[0].is_declared_const)
self.assertTrue(ast.children[0].children[2].variables[0].is_declared_const)
16 changes: 7 additions & 9 deletions test/test_parsing/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from sgp.sgp_parser import parse
from sgp.utils import string_from_snake_to_camel_case

class TestParsing(unittest.TestCase):

class TestParsing(unittest.TestCase):
def test_parsing(self):
current_directory = pathlib.Path(__file__).parent.resolve()

Expand All @@ -23,12 +23,10 @@ def test_parsing(self):

res = parse(test_content, dump_json=True)
ast_actual = simplejson.dumps(
res,
default=lambda obj: {
string_from_snake_to_camel_case(k): v
for k, v in obj.__dict__.items()
},
)

res,
default=lambda obj: {
string_from_snake_to_camel_case(k): v for k, v in obj.__dict__.items()
},
)

self.assertEqual(ast_expected, ast_actual)

Loading