Skip to content

Commit

Permalink
Merge pull request #72 from nikkirad/nikrad
Browse files Browse the repository at this point in the history
Recognize where clauses
  • Loading branch information
nikkirad authored Jun 20, 2023
2 parents 0c164bd + 851c35b commit 1d6b2ce
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 60 deletions.
24 changes: 19 additions & 5 deletions sphinxcontrib/chapeldomain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@
([\w$.]*\.)? # class name(s)
([\w\+\-/\*$\<\=\>\!]+) \s* # function or method name
(?:\((.*?)\))? # opt: arguments
(\s+(?:const\s)? \w+| # or return intent
\s* : \s* [^:]+| # or return type
\s+(?:const\s)? \w+\s* : \s* [^:]+ # or return intent and type
(\s+(?:const\s)? (?:\w+?)| # or return intent
\s* : \s* (?:[^:]+?)| # or return type
\s+(?:const\s)? \w+\s* : \s* (?:[^:]+?) # or return intent and type
)?
(\s+where\s+.* # Where clause
)?
$""", re.VERBOSE)

Expand Down Expand Up @@ -221,7 +223,7 @@ def _get_proc_like_prefix(self, sig):
if sig_match is None:
return ChapelObject.get_signature_prefix(self, sig)

prefixes, _, _, _, _ = sig_match.groups()
prefixes, _, _, _, _, _ = sig_match.groups()
if prefixes:
return prefixes.strip() + ' '
elif self.objtype.startswith('iter'):
Expand Down Expand Up @@ -285,14 +287,20 @@ def handle_signature(self, sig, signode):
raise ValueError('Signature does not parse: {0}'.format(sig))
func_prefix, name_prefix, name, retann = sig_match.groups()
arglist = None
where_clause = None
else:
sig_match = chpl_sig_pattern.match(sig)
if sig_match is None:
raise ValueError('Signature does not parse: {0}'.format(sig))

func_prefix, name_prefix, name, arglist, retann = \
func_prefix, name_prefix, name, arglist, retann, where_clause = \
sig_match.groups()

# check if where clause is valid
if where_clause is not None and not self._is_proc_like():
raise ValueError('A where clause has been used on'
' a non-proc-like directive.')

modname = self.options.get(
'module', self.env.temp_data.get('chpl:module'))
classname = self.env.temp_data.get('chpl:class')
Expand Down Expand Up @@ -346,13 +354,19 @@ def handle_signature(self, sig, signode):
signode += addnodes.desc_type(retann, retann)
if anno:
signode += addnodes.desc_annotation(' ' + anno, ' ' + anno)
if where_clause:
signode += addnodes.desc_annotation(' ' + where_clause,
' ' + where_clause)
return fullname, name_prefix

self._pseudo_parse_arglist(signode, arglist)
if retann:
signode += addnodes.desc_type(retann, retann)
if anno:
signode += addnodes.desc_annotation(' ' + anno, ' ' + anno)
if where_clause:
signode += addnodes.desc_annotation(' ' + where_clause,
' ' + where_clause)
return fullname, name_prefix

def get_index_text(self, modname, name):
Expand Down
131 changes: 76 additions & 55 deletions test/test_chapeldomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,19 +643,20 @@ class SigPatternTests(PatternTestCase):
longMessage = True
pattern = chpl_sig_pattern

def check_sig(self, sig, func_prefix, name_prefix, name, arglist, retann):
def check_sig(self, sig, func_prefix, name_prefix, name, arglist, retann, where_clause):
"""Verify signature results in appropriate matches."""
fail_msg = 'sig: {0}'.format(sig)

match = self.pattern.match(sig)
self.assertIsNotNone(match, msg=fail_msg)

(actual_func_prefix, actual_name_prefix, actual_name, actual_arglist, actual_retann) = match.groups()
(actual_func_prefix, actual_name_prefix, actual_name, actual_arglist, actual_retann, actual_where_clause) = match.groups()
self.assertEqual(func_prefix, actual_func_prefix, msg=fail_msg)
self.assertEqual(name_prefix, actual_name_prefix, msg=fail_msg)
self.assertEqual(name, actual_name, msg=fail_msg)
self.assertEqual(arglist, actual_arglist, msg=fail_msg)
self.assertEqual(retann, actual_retann, msg=fail_msg)
self.assertEqual(where_clause, actual_where_clause, msg=fail_msg)

def test_does_not_match(self):
"""Verify various signatures that should not match."""
Expand Down Expand Up @@ -685,7 +686,7 @@ def test_no_parens(self):
'**',
]
for sig in test_cases:
self.check_sig(sig, None, None, sig, None, None)
self.check_sig(sig, None, None, sig, None, None, None)

def test_no_args(self):
"""Verify various functions with no args parse correctly."""
Expand All @@ -704,7 +705,7 @@ def test_no_args(self):
('x ()', 'x'),
]
for sig, name in test_cases:
self.check_sig(sig, None, None, name, '', None)
self.check_sig(sig, None, None, name, '', None, None)

def test_with_args(self):
"""Verify function signatures with arguments parse correctly."""
Expand All @@ -720,26 +721,26 @@ def test_with_args(self):
('++++++++++++++++++++ ( +++ )', '++++++++++++++++++++', ' +++ '),
]
for sig, name, arglist in test_cases:
self.check_sig(sig, None, None, name, arglist, None)
self.check_sig(sig, None, None, name, arglist, None, None)

def test_with_return_type(self):
"""Verify function signatures with return types parse correctly."""
test_cases = [
('x(): int', 'x', '', ': int'),
('x(): MyMod.MyClass', 'x', '', ': MyMod.MyClass'),
('x(): int(32)', 'x', '', ': int(32)'),
('x():int(32)', 'x', '', ':int(32)'),
('x(y:int(64)):int(32)', 'x', 'y:int(64)', ':int(32)'),
('x(y:int(64), d: domain(r=2, i=int, s=true)): [{1..5}] real', 'x', 'y:int(64), d: domain(r=2, i=int, s=true)', ': [{1..5}] real'),
('x(): domain(1)', 'x', '', ': domain(1)'),
('x(): [{1..n}] BigNum', 'x', '', ': [{1..n}] BigNum'),
('x(): nil', 'x', '', ': nil'),
('x() ref', 'x', '', ' ref'),
('x() const', 'x', '', ' const'),
('x(ref x:int(32)) const', 'x', 'ref x:int(32)', ' const'),
('x(): int', 'x', '', ': int', None),
('x(): MyMod.MyClass', 'x', '', ': MyMod.MyClass', None),
('x(): int(32)', 'x', '', ': int(32)', None),
('x():int(32)', 'x', '', ':int(32)', None),
('x(y:int(64)):int(32)', 'x', 'y:int(64)', ':int(32)', None),
('x(y:int(64), d: domain(r=2, i=int, s=true)): [{1..5}] real', 'x', 'y:int(64), d: domain(r=2, i=int, s=true)', ': [{1..5}] real', None),
('x(): domain(1)', 'x', '', ': domain(1)', None),
('x(): [{1..n}] BigNum', 'x', '', ': [{1..n}] BigNum', None),
('x(): nil', 'x', '', ': nil', None),
('x() ref', 'x', '', ' ref', None),
('x() const', 'x', '', ' const', None),
('x(ref x:int(32)) const', 'x', 'ref x:int(32)', ' const', None),
]
for sig, name, arglist, retann in test_cases:
self.check_sig(sig, None, None, name, arglist, retann)
for sig, name, arglist, retann, where_clause in test_cases:
self.check_sig(sig, None, None, name, arglist, retann, where_clause)

def test_with_class_names(self):
"""Verify function signatures with class names parse correctly."""
Expand All @@ -754,7 +755,7 @@ def test_with_class_names(self):
('MyMod.MyClass.foo()', 'MyMod.MyClass.', 'foo', ''),
]
for sig, class_name, name, arglist in test_cases:
self.check_sig(sig, None, class_name, name, arglist, None)
self.check_sig(sig, None, class_name, name, arglist, None, None)

def test_with_prefixes(self):
"""Verify functions with prefixes parse correctly."""
Expand All @@ -766,51 +767,71 @@ def test_with_prefixes(self):
('inline operator +', 'inline operator ', '+', None),
]
for sig, prefix, name, arglist in test_cases:
self.check_sig(sig, prefix, None, name, arglist, None)
self.check_sig(sig, prefix, None, name, arglist, None, None)

def test_with_where_clause(self):
"""Verify functions with where clauses parse correctly."""
test_cases = [
('proc processArr(arr: [1..n] int, f: proc (int) int) where n > 0', 'proc ', 'processArr', 'arr: [1..n] int, f: proc (int) int', None, ' where n > 0'),
('proc processArr(arr: []) where arr.elemType == int', 'proc ', 'processArr', 'arr: []', None, ' where arr.elemType == int'),
('proc processDom(dom: domain) where dom.rank == 2', 'proc ', 'processDom', 'dom: domain', None, ' where dom.rank == 2'),
('proc processRec(r: MyRecord) where r.x > 0', 'proc ', 'processRec', 'r: MyRecord', None, ' where r.x > 0'),
('proc processRange(r: [1..n] int) where n > 0', 'proc ', 'processRange', 'r: [1..n] int', None, ' where n > 0'),
('proc processRange(r: range) where r.low > 1', 'proc ', 'processRange', 'r: range', None, ' where r.low > 1'),
('operator + (a: int, b: int) where a > 0', 'operator ', '+', 'a: int, b: int', None, ' where a > 0'),
]
for sig, prefix, name, arglist, retann, where_clause in test_cases:
self.check_sig(sig, prefix, None, name, arglist, retann, where_clause)

def test_with_all(self):
"""Verify fully specified signatures parse correctly."""
test_cases = [
('proc foo() ref', 'proc ', None, 'foo', '', ' ref'),
('iter foo() ref', 'iter ', None, 'foo', '', ' ref'),
('inline proc Vector.pop() ref', 'inline proc ', 'Vector.', 'pop', '', ' ref'),
('inline proc range.first', 'inline proc ', 'range.', 'first', None, None),
('iter Math.fib(n: int(64)): GMP.BigInt', 'iter ', 'Math.', 'fib', 'n: int(64)', ': GMP.BigInt'),
('proc My.Mod.With.Deep.NameSpace.1.2.3.432.foo()', 'proc ', 'My.Mod.With.Deep.NameSpace.1.2.3.432.', 'foo', '', None),
('these() ref', None, None, 'these', '', ' ref'),
('size', None, None, 'size', None, None),
('proc Util.toVector(type eltType, cap=4, offset=0): Containers.Vector', 'proc ', 'Util.', 'toVector', 'type eltType, cap=4, offset=0', ': Containers.Vector'),
('proc MyClass$.lock$(combo$): sync bool', 'proc ', 'MyClass$.', 'lock$', 'combo$', ': sync bool'),
('proc MyClass$.lock$(combo$): sync myBool$', 'proc ', 'MyClass$.', 'lock$', 'combo$', ': sync myBool$'),
('proc type currentTime(): int(64)', 'proc type ', None, 'currentTime', '', ': int(64)'),
('proc param int.someNum(): int(64)', 'proc param ', 'int.', 'someNum', '', ': int(64)'),
('proc MyRs(seed: int(64)): int(64)', 'proc ', None, 'MyRs', 'seed: int(64)', ': int(64)'),
('proc foo where a > b', 'proc ', None, 'foo', None, None, ' where a > b'),
('proc foo() where a > b', 'proc ', None, 'foo', '', None, ' where a > b'),
('proc foo:int where a > b', 'proc ', None, 'foo', None, ':int', ' where a > b'),
('proc foo():int where a > b', 'proc ', None, 'foo', '', ':int', ' where a > b'),
('proc foo ref where a > b', 'proc ', None, 'foo', None, ' ref', ' where a > b'),
('proc foo() ref where a > b', 'proc ', None, 'foo', '', ' ref', ' where a > b'),
('proc foo ref: int where a > b', 'proc ', None, 'foo', None, ' ref: int', ' where a > b'),
('proc foo() ref: int where a > b', 'proc ', None, 'foo', '', ' ref: int', ' where a > b'),
('proc foo() ref', 'proc ', None, 'foo', '', ' ref', None),
('iter foo() ref', 'iter ', None, 'foo', '', ' ref', None),
('inline proc Vector.pop() ref', 'inline proc ', 'Vector.', 'pop', '', ' ref', None),
('inline proc range.first', 'inline proc ', 'range.', 'first', None, None, None),
('iter Math.fib(n: int(64)): GMP.BigInt', 'iter ', 'Math.', 'fib', 'n: int(64)', ': GMP.BigInt', None),
('proc My.Mod.With.Deep.NameSpace.1.2.3.432.foo()', 'proc ', 'My.Mod.With.Deep.NameSpace.1.2.3.432.', 'foo', '', None, None),
('these() ref', None, None, 'these', '', ' ref', None),
('size', None, None, 'size', None, None, None),
('proc Util.toVector(type eltType, cap=4, offset=0): Containers.Vector', 'proc ', 'Util.', 'toVector', 'type eltType, cap=4, offset=0', ': Containers.Vector', None),
('proc MyClass$.lock$(combo$): sync bool', 'proc ', 'MyClass$.', 'lock$', 'combo$', ': sync bool', None),
('proc MyClass$.lock$(combo$): sync myBool$', 'proc ', 'MyClass$.', 'lock$', 'combo$', ': sync myBool$', None),
('proc type currentTime(): int(64)', 'proc type ', None, 'currentTime', '', ': int(64)', None),
('proc param int.someNum(): int(64)', 'proc param ', 'int.', 'someNum', '', ': int(64)', None),
('proc MyRs(seed: int(64)): int(64)', 'proc ', None, 'MyRs', 'seed: int(64)', ': int(64)', None),
('proc RandomStream(seed: int(64) = SeedGenerator.currentTime, param parSafe: bool = true)',
'proc ', None, 'RandomStream', 'seed: int(64) = SeedGenerator.currentTime, param parSafe: bool = true', None),
('class X', 'class ', None, 'X', None, None),
('class MyClass:YourClass', 'class ', None, 'MyClass', None, ':YourClass'),
('class M.C : A, B, C', 'class ', 'M.', 'C', None, ': A, B, C'),
('record R', 'record ', None, 'R', None, None),
('record MyRec:SuRec', 'record ', None, 'MyRec', None, ':SuRec'),
('record N.R : X, Y, Z', 'record ', 'N.', 'R', None, ': X, Y, Z'),
'proc ', None, 'RandomStream', 'seed: int(64) = SeedGenerator.currentTime, param parSafe: bool = true', None, None),
('class X', 'class ', None, 'X', None, None, None),
('class MyClass:YourClass', 'class ', None, 'MyClass', None, ':YourClass', None),
('class M.C : A, B, C', 'class ', 'M.', 'C', None, ': A, B, C', None),
('record R', 'record ', None, 'R', None, None, None),
('record MyRec:SuRec', 'record ', None, 'MyRec', None, ':SuRec', None),
('record N.R : X, Y, Z', 'record ', 'N.', 'R', None, ': X, Y, Z', None),
('proc rcRemote(replicatedVar: [?D] ?MYTYPE, remoteLoc: locale) ref: MYTYPE',
'proc ', None, 'rcRemote', 'replicatedVar: [?D] ?MYTYPE, remoteLoc: locale', ' ref: MYTYPE'),
'proc ', None, 'rcRemote', 'replicatedVar: [?D] ?MYTYPE, remoteLoc: locale', ' ref: MYTYPE', None),
('proc rcLocal(replicatedVar: [?D] ?MYTYPE) ref: MYTYPE',
'proc ', None, 'rcLocal', 'replicatedVar: [?D] ?MYTYPE', ' ref: MYTYPE'),
('proc specialArg(const ref x: int)', 'proc ', None, 'specialArg', 'const ref x: int', None),
('proc specialReturn() const ref', 'proc ', None, 'specialReturn', '', ' const ref'),
('proc constRefArgAndReturn(const ref x: int) const ref', 'proc ', None, 'constRefArgAndReturn', 'const ref x: int', ' const ref'),
('operator string.+(s0: string, s1: string) : string', 'operator ', 'string.', '+', 's0: string, s1: string', ' : string'),
('operator *(s: string, n: integral) : string', 'operator ', None, '*', 's: string, n: integral', ' : string'),
('inline operator string.==(param s0: string, param s1: string) param', 'inline operator ', 'string.', '==', 'param s0: string, param s1: string', ' param'),
('operator bytes.=(ref lhs: bytes, rhs: bytes) : void ', 'operator ', 'bytes.', '=', 'ref lhs: bytes, rhs: bytes', ' : void '),
'proc ', None, 'rcLocal', 'replicatedVar: [?D] ?MYTYPE', ' ref: MYTYPE', None),
('proc specialArg(const ref x: int)', 'proc ', None, 'specialArg', 'const ref x: int', None, None),
('proc specialReturn() const ref', 'proc ', None, 'specialReturn', '', ' const ref', None),
('proc constRefArgAndReturn(const ref x: int) const ref', 'proc ', None, 'constRefArgAndReturn', 'const ref x: int', ' const ref', None),
('operator string.+(s0: string, s1: string) : string', 'operator ', 'string.', '+', 's0: string, s1: string', ' : string', None),
('operator *(s: string, n: integral) : string', 'operator ', None, '*', 's: string, n: integral', ' : string', None),
('inline operator string.==(param s0: string, param s1: string) param', 'inline operator ', 'string.', '==', 'param s0: string, param s1: string', ' param', None),
('operator bytes.=(ref lhs: bytes, rhs: bytes) : void ', 'operator ', 'bytes.', '=', 'ref lhs: bytes, rhs: bytes', ' : void ', None),
# can't handle this pattern, ":" is set as punctuation, and casts don't seem to be doc'd anyway
# ('operator :(x: bytes)', 'operator ', None, ':', 'x: bytes', None),

]
for sig, prefix, class_name, name, arglist, retann in test_cases:
self.check_sig(sig, prefix, class_name, name, arglist, retann)

for sig, prefix, class_name, name, arglist, retann, where_clause in test_cases:
self.check_sig(sig, prefix, class_name, name, arglist, retann, where_clause)

class AttrSigPatternTests(PatternTestCase):
"""Verify chpl_attr_sig_pattern regex."""
Expand Down

0 comments on commit 1d6b2ce

Please sign in to comment.