Skip to content

Commit

Permalink
extracted branch trimming into own transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 13, 2024
1 parent 61490e5 commit a45a6f6
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 8 deletions.
40 changes: 32 additions & 8 deletions .cross_sync/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,36 @@ def visit_AsyncFor(self, node):
return self.generic_visit(node)


class StripAsyncConditionalBranches(ast.NodeTransformer):
"""
Visits all if statements in an AST, and removes branches marked with CrossSync.is_async
"""

def visit_If(self, node):
"""
remove CrossSync.is_async branches from top-level if statements
"""
kept_branch = None
# check for CrossSync.is_async
if self._is_async_check(node.test):
kept_branch = node.orelse
# check for not CrossSync.is_async
elif isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not) and self._is_async_check(node.test.operand):
kept_branch = node.body
if kept_branch is not None:
# only keep the statements in the kept branch
return [self.visit(n) for n in kept_branch]
else:
# keep the entire if statement
return self.visit(node)

def _is_async_check(self, node) -> bool:
"""
Check for CrossSync.is_async nodes
"""
return isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "CrossSync" and node.attr == "is_async"


class CrossSyncFileProcessor(ast.NodeTransformer):
"""
Visits a file, looking for __CROSS_SYNC_OUTPUT__ annotations
Expand Down Expand Up @@ -228,6 +258,8 @@ def visit_Module(self, node):
converted = self.generic_visit(node)
# strip out CrossSync.rm_aio calls
converted = RmAioFunctions().visit(converted)
# strip out CrossSync.is_async branches
converted = StripAsyncConditionalBranches().visit(converted)
# replace CrossSync statements
converted = SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}).visit(converted)
return converted
Expand All @@ -251,14 +283,6 @@ def visit_ClassDef(self, node):
continue
return self.generic_visit(node) if node else None

def visit_If(self, node):
"""
remove CrossSync.is_async branches from top-level if statements
"""
if isinstance(node.test, ast.Attribute) and isinstance(node.test.value, ast.Name) and node.test.value.id == "CrossSync" and node.test.attr == "is_async":
return [self.generic_visit(n) for n in node.orelse]
return self.generic_visit(node)

def visit_Assign(self, node):
"""
strip out __CROSS_SYNC_OUTPUT__ assignments
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
tests:
- description: "top level conditional"
before: |
if CrossSync.is_async:
print("async")
else:
print("sync")
transformers: [StripAsyncConditionalBranches]
after: |
print("sync")
- description: "nested conditional"
before: |
if CrossSync.is_async:
print("async")
else:
print("hello")
if CrossSync.is_async:
print("async")
else:
print("world")
transformers: [StripAsyncConditionalBranches]
after: |
print("hello")
print("world")
- description: "conditional within class"
before: |
class MyClass:
def my_method(self):
if CrossSync.is_async:
return "async result"
else:
return "sync result"
transformers: [StripAsyncConditionalBranches]
after: |
class MyClass:
def my_method(self):
return "sync result"
- description: "multiple branches"
before: |
if CrossSync.is_async:
print("async branch 1")
elif some_condition:
print("other condition")
elif CrossSync.is_async:
print("async branch 2")
else:
print("sync branch")
transformers: [StripAsyncConditionalBranches]
after: |
if some_condition:
print("other condition")
else:
print("sync branch")
- description: "negated conditionals"
before: |
if not CrossSync.is_async:
print("sync code")
else:
print("async code")
transformers: [StripAsyncConditionalBranches]
after: |
print("sync code")
- description: "is check"
before: |
if CrossSync.is_async is True:
print("async code")
else:
print("sync code")
transformers: [StripAsyncConditionalBranches]
after: |
print("sync code")
1 change: 1 addition & 0 deletions tests/system/cross_sync/test_cross_sync_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SymbolReplacer,
AsyncToSync,
RmAioFunctions,
StripAsyncConditionalBranches,
CrossSyncFileProcessor,
)

Expand Down

0 comments on commit a45a6f6

Please sign in to comment.