diff --git a/nodebook/nodebookcore.py b/nodebook/nodebookcore.py index 98b5cdc..ddae374 100644 --- a/nodebook/nodebookcore.py +++ b/nodebook/nodebookcore.py @@ -46,6 +46,9 @@ def visit_FunctionDef(self, node): self.locals.add(node.name) self.generic_visit(node) + def visit_arg(self, node): + self.locals.add(node.arg) + def visit_AugAssign(self, node): target = node.target while (type(target) is ast.Subscript): diff --git a/setup.py b/setup.py index 11d6a77..adc7229 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='nodebook', - version='0.2.0', + version='0.2.1', author='Kevin Zielnicki', author_email='kzielnicki@stitchfix.com', license='Stitch Fix 2017', diff --git a/tests/test_nodebookcore.py b/tests/test_nodebookcore.py index cfd9646..470121c 100644 --- a/tests/test_nodebookcore.py +++ b/tests/test_nodebookcore.py @@ -42,6 +42,16 @@ def test_multiline(self, rf): assert rf.locals == {'pd', 'y'} assert rf.imports == {'pandas'} + def test_function(self, rf): + code_tree = ast.parse( + "def add(x,y):\n" + " return x+y\n" + ) + rf.visit(code_tree) + assert rf.inputs == set() + assert rf.locals == {'add', 'x', 'y'} + assert rf.imports == set() + class TestNodebook(object): @pytest.fixture() diff --git a/tests/test_pickledict.py b/tests/test_pickledict.py index ef35fdd..90b7088 100644 --- a/tests/test_pickledict.py +++ b/tests/test_pickledict.py @@ -33,6 +33,12 @@ def test_df(self, mydict): df = pd.DataFrame({'a': [0, 1, 2], 'b': ['foo', 'bar', 'baz']}) mydict['test_df'] = df assert mydict['test_df'].equals(df) + + def test_func(self, mydict): + def add(a, b): + return a + b + mydict['test_func'] = add + assert mydict['test_func'](3,5) == 8 def test_immutability(self, mydict): l = [1, 2, 3]