Skip to content
This repository has been archived by the owner on Nov 23, 2024. It is now read-only.

Commit

Permalink
fix: fixed a bug (references with the same name were lost due to the …
Browse files Browse the repository at this point in the history
…dict) in resolve_reference, adapted infer_purity to use this correctly
  • Loading branch information
lukarade committed Dec 20, 2023
1 parent 8e7f814 commit 0966d0d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def check_open_like_functions(func_ref: FunctionReference) -> PurityResult:
pass # TODO: [Later] for now it is good enough to deal with open() only, but we MAYBE need to deal with the other open-like functions too

Check warning on line 274 in src/library_analyzer/processing/api/purity_analysis/_infer_purity.py

View check run for this annotation

Codecov / codecov/patch

src/library_analyzer/processing/api/purity_analysis/_infer_purity.py#L274

Added line #L274 was not covered by tests


def infer_purity(references: dict[str, ReferenceNode], function_references: dict[str, Reasons], classes: dict[str, ClassScope],
def infer_purity(references: dict[str, list[ReferenceNode]], function_references: dict[str, Reasons], classes: dict[str, ClassScope],
call_graph: CallGraphForest) -> dict[astroid.FunctionDef, PurityResult]:
"""
Infer the purity of functions.
Expand All @@ -284,7 +284,7 @@ def infer_purity(references: dict[str, ReferenceNode], function_references: dict
Parameters
----------
* references: a list of all references in the module
* references: a dict of all references in the module
* function_references: a dict of function references
* classes: a dict of all classes in the module
* call_graph: the call graph of the module
Expand All @@ -303,7 +303,7 @@ def infer_purity(references: dict[str, ReferenceNode], function_references: dict
return {key: value for key, value in purity_results.items() if not isinstance(key, str)}


def process_node(reason: Reasons, references: dict[str, ReferenceNode], function_references: dict[str, Reasons],
def process_node(reason: Reasons, references: dict[str, list[ReferenceNode]], function_references: dict[str, Reasons],
classes: dict[str, ClassScope], call_graph: CallGraphForest,
purity_results: dict[astroid.FunctionDef, PurityResult]) -> PurityResult:
"""
Expand Down Expand Up @@ -435,7 +435,7 @@ def process_node(reason: Reasons, references: dict[str, ReferenceNode], function
raise KeyError(f"Function {reason.function.name} not found in function_references") from None

Check warning on line 435 in src/library_analyzer/processing/api/purity_analysis/_infer_purity.py

View check run for this annotation

Codecov / codecov/patch

src/library_analyzer/processing/api/purity_analysis/_infer_purity.py#L434-L435

Added lines #L434 - L435 were not covered by tests


def get_purity_of_child(child: CallGraphNode, reason: Reasons, references: dict[str, ReferenceNode], function_references: dict[str, Reasons],
def get_purity_of_child(child: CallGraphNode, reason: Reasons, references: dict[str, list[ReferenceNode]], function_references: dict[str, Reasons],
classes: dict[str, ClassScope], call_graph: CallGraphForest,
purity_results: dict[astroid.FunctionDef, PurityResult]) -> None:
"""
Expand Down Expand Up @@ -472,7 +472,7 @@ def get_purity_of_child(child: CallGraphNode, reason: Reasons, references: dict[


# TODO: this is not working correctly: whenever a variable is referenced, it is marked as read/written if it is not inside the current function
def transform_reasons_to_impurity_result(reasons: Reasons, references: dict[str, ReferenceNode], classes: dict[str, ClassScope]) -> PurityResult:
def transform_reasons_to_impurity_result(reasons: Reasons, references: dict[str, list[ReferenceNode]], classes: dict[str, ClassScope]) -> PurityResult:
"""
Transform the reasons for impurity to an impurity result.
Expand All @@ -497,21 +497,23 @@ def transform_reasons_to_impurity_result(reasons: Reasons, references: dict[str,
else:
if reasons.writes:
for write in reasons.writes:
write_ref = references[write.node.name]
for sym_ref in write_ref.referenced_symbols:
if isinstance(sym_ref, GlobalVariable | ClassVariable | InstanceVariable):
impurity_reasons.add(NonLocalVariableWrite(sym_ref))
else:
raise TypeError(f"Unknown symbol reference type: {sym_ref.__class__.__name__}")
write_ref_list = references[write.node.name]
for write_ref in write_ref_list:
for sym_ref in write_ref.referenced_symbols:
if isinstance(sym_ref, GlobalVariable | ClassVariable | InstanceVariable):
impurity_reasons.add(NonLocalVariableWrite(sym_ref))
else:
raise TypeError(f"Unknown symbol reference type: {sym_ref.__class__.__name__}")

Check warning on line 506 in src/library_analyzer/processing/api/purity_analysis/_infer_purity.py

View check run for this annotation

Codecov / codecov/patch

src/library_analyzer/processing/api/purity_analysis/_infer_purity.py#L506

Added line #L506 was not covered by tests

if reasons.reads:
for read in reasons.reads:
read_ref = references[read.node.name]
for sym_ref in read_ref.referenced_symbols:
if isinstance(sym_ref, GlobalVariable | ClassVariable | InstanceVariable):
impurity_reasons.add(NonLocalVariableRead(sym_ref))
else:
raise TypeError(f"Unknown symbol reference type: {sym_ref.__class__.__name__}")
read_ref_list = references[read.node.name]
for read_ref in read_ref_list:
for sym_ref in read_ref.referenced_symbols:
if isinstance(sym_ref, GlobalVariable | ClassVariable | InstanceVariable):
impurity_reasons.add(NonLocalVariableRead(sym_ref))
else:
raise TypeError(f"Unknown symbol reference type: {sym_ref.__class__.__name__}")

Check warning on line 516 in src/library_analyzer/processing/api/purity_analysis/_infer_purity.py

View check run for this annotation

Codecov / codecov/patch

src/library_analyzer/processing/api/purity_analysis/_infer_purity.py#L516

Added line #L516 was not covered by tests

if reasons.unknown_calls:
for unknown_call in reasons.unknown_calls:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _find_name_references(
classes: dict[str, ClassScope],
functions: dict[str, list[Scope]],
parameters: dict[astroid.FunctionDef, tuple[Scope | ClassScope, set[astroid.AssignName]]],
) -> list[ReferenceNode]:
) -> dict[str, list[ReferenceNode]]:
"""Create a list of references from a list of name nodes.
Parameters
Expand All @@ -41,9 +41,9 @@ def _find_name_references(
Returns
-------
* final_references: a list containing all name references (target & value references)
* final_references: a dict containing all name references (target & value references)
"""
final_references: list[ReferenceNode] = []
final_references: dict[str, list[ReferenceNode]] = {}

# TODO: is it possible to do this in a more efficient way?
# maybe we can speed up the detection of references by using a dictionary instead of a list
Expand All @@ -55,15 +55,21 @@ def _find_name_references(
for value_ref in value_references:
if isinstance(value_ref.node, astroid.Name | MemberAccessValue):
value_ref_complete = _find_references(value_ref, target_references, classes, functions, parameters)
final_references.append(value_ref_complete)
if value_ref_complete.node.name in final_references:
final_references[value_ref_complete.node.name].append(value_ref_complete)
else:
final_references[value_ref_complete.node.name] = [value_ref_complete]

# Detect all target references: references that are used as targets (e.g., target = sth)
for target_ref in target_references:
if isinstance(target_ref.node, astroid.AssignName | astroid.Name | MemberAccessTarget):
target_ref_complete = _find_references_target(target_ref, target_references, classes)
# Remove all references that are never referenced
if target_ref_complete.referenced_symbols:
final_references.append(target_ref_complete)
if target_ref_complete.node.name in final_references:
final_references[target_ref_complete.node.name].append(target_ref_complete)
else:
final_references[target_ref_complete.node.name] = [target_ref_complete]

return final_references

Expand Down Expand Up @@ -274,7 +280,7 @@ def _find_call_reference(
classes: dict[str, ClassScope],
functions: dict[str, list[Scope]],
parameters: dict[astroid.FunctionDef, tuple[Scope | ClassScope, set[astroid.AssignName]]],
) -> list[ReferenceNode]:
) -> dict[str,list[ReferenceNode]]:
"""Find all references for a function call.
Parameters
Expand All @@ -286,9 +292,16 @@ def _find_call_reference(
Returns
-------
* final_call_references: a list of all references for a function call
* final_call_references: a dict of all references for a function call
"""
final_call_references: list[ReferenceNode] = []
def add_reference() -> None:
"""Add a reference to the final_call_references dict."""
if call_references[i].node.func.name in final_call_references:
final_call_references[call_references[i].node.func.name].append(call_references[i])
else:
final_call_references[call_references[i].node.func.name] = [call_references[i]]

final_call_references: dict[str, list[ReferenceNode]] = {}
python_builtins = dir(builtins)

call_references = [ReferenceNode(call, scope, []) for call, scope in function_calls.items()]
Expand All @@ -299,16 +312,14 @@ def _find_call_reference(
function_def = functions.get(reference.node.func.name)
symbols = [func.symbol for func in function_def if function_def] # type: ignore[union-attr] # "None" is not iterable, but we check for it
call_references[i].referenced_symbols.extend(symbols)

final_call_references.append(call_references[i])
add_reference()

# Find classes that are called (initialized)
elif reference.node.func.name in classes:
symbol = classes.get(reference.node.func.name)
if symbol:
call_references[i].referenced_symbols.append(symbol.symbol)

final_call_references.append(call_references[i])
add_reference()

# Find builtins that are called
if reference.node.func.name in python_builtins:
Expand All @@ -318,7 +329,7 @@ def _find_call_reference(
reference.node.func.name,
)
call_references[i].referenced_symbols.append(builtin_call)
final_call_references.append(call_references[i])
add_reference()

# Find function parameters that are called (passed as arguments), like:
# def f(a):
Expand All @@ -331,13 +342,13 @@ def _find_call_reference(
for child in parameters.get(func_def)[0].children: # type: ignore[index] # "None" is not index-able, but we check for it
if child.symbol.node.name == param.name:
call_references[i].referenced_symbols.append(child.symbol)
final_call_references.append(call_references[i])
add_reference()
break

return final_call_references


def resolve_references(code: str) -> tuple[dict[str, ReferenceNode], dict[str, Reasons], dict[str, ClassScope], CallGraphForest]:
def resolve_references(code: str) -> tuple[dict[str, list[ReferenceNode]], dict[str, Reasons], dict[str, ClassScope], CallGraphForest]:
"""
Resolve all references in a module.
Expand All @@ -354,7 +365,7 @@ def resolve_references(code: str) -> tuple[dict[str, ReferenceNode], dict[str, R
* call_graph: a CallGraphForest object that represents the call graph of the module
"""
module_data = get_module_data(code)
resolved_references = _find_name_references(
name_references = _find_name_references(
module_data.target_nodes,
module_data.value_nodes,
module_data.classes,
Expand All @@ -363,19 +374,38 @@ def resolve_references(code: str) -> tuple[dict[str, ReferenceNode], dict[str, R
)

if module_data.function_calls:
references_call = _find_call_reference(
call_references = _find_call_reference(
module_data.function_calls,
module_data.classes,
module_data.functions,
module_data.parameters,
)
resolved_references.extend(references_call)
else:
call_references = {}

resolved_references = {
reference.node.func.name if isinstance(reference.node, astroid.Call) else reference.node.name: reference
for reference in resolved_references # TODO: MemberAccessTarget and MemberAccessValue are not handled here
}
resolved_references = merge_dicts(call_references, name_references)

call_graph = build_call_graph(module_data.functions, module_data.function_references)

return resolved_references, module_data.function_references, module_data.classes, call_graph


def merge_dicts(d1: dict[str, list[ReferenceNode]], d2: dict[str, list[ReferenceNode]]) -> dict[str, list[ReferenceNode]]:
    """Merge two dicts of lists into a new dict.

    Lists stored under a key that is present in both dicts are concatenated, with the items of d1 first.
    The merged dict owns fresh list objects: neither d1 nor d2 (nor any list stored in them) is modified,
    and later changes to the merged dict do not leak back into the inputs.

    Parameters
    ----------
    * d1: the first dict
    * d2: the second dict

    Returns
    -------
    * d3: the merged dict
    """
    # Copy every list of d1: a shallow dict copy would share the lists,
    # so extending them below would silently modify the caller's d1.
    d3 = {key: list(value) for key, value in d1.items()}
    for key, value in d2.items():
        # setdefault creates a new list for keys that only exist in d2, so d3 never aliases d2's lists either.
        d3.setdefault(key, []).extend(value)
    return d3
49 changes: 29 additions & 20 deletions tests/library_analyzer/processing/api/test_resolve_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def test_resolve_references_parameters(code: str, expected: list[ReferenceTestNo
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)
Expand Down Expand Up @@ -309,11 +309,11 @@ def test_resolve_references_local_global(code: str, expected: list[ReferenceTest
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)
assert transformed_references == expected


@pytest.mark.parametrize(
Expand Down Expand Up @@ -775,11 +775,11 @@ def test_resolve_references_member_access(code: str, expected: list[ReferenceTes
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert transformed_references == expected
assert set(transformed_references) == set(expected)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -887,8 +887,8 @@ def test_resolve_references_conditional_statements(code: str, expected: list[Ref
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)
Expand Down Expand Up @@ -970,8 +970,8 @@ def test_resolve_references_loops(code: str, expected: list[ReferenceTestNode])
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)
Expand Down Expand Up @@ -1225,8 +1225,8 @@ def test_resolve_references_miscellaneous(code: str, expected: list[ReferenceTes
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)
Expand Down Expand Up @@ -1675,8 +1675,8 @@ def test_resolve_references_calls(code: str, expected: list[ReferenceTestNode])
transformed_references: list[ReferenceTestNode] = []

# assert references == expected
for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

assert set(transformed_references) == set(expected)

Expand Down Expand Up @@ -1768,8 +1768,8 @@ def test_resolve_references_imports(code: str, expected: list[ReferenceTestNode]
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)
Expand Down Expand Up @@ -1862,13 +1862,22 @@ def test_resolve_references_dataclasses(code: str, expected: list[ReferenceTestN
references = resolve_references(code)[0]
transformed_references: list[ReferenceTestNode] = []

for node in references:
transformed_references.append(transform_reference_node(node))
for node in references.values():
transformed_references.extend(transform_reference_nodes(node))

# assert references == expected
assert set(transformed_references) == set(expected)


def transform_reference_nodes(nodes: list[ReferenceNode]) -> list[ReferenceTestNode]:
    """Transform each ReferenceNode in nodes with transform_reference_node, keeping the input order."""
    return [transform_reference_node(reference) for reference in nodes]


def transform_reference_node(node: ReferenceNode) -> ReferenceTestNode:
if isinstance(node.node, MemberAccess | MemberAccessValue | MemberAccessTarget):
expression = get_base_expression(node.node)
Expand Down

0 comments on commit 0966d0d

Please sign in to comment.