diff --git a/modules/pymol/cmd.py b/modules/pymol/cmd.py index 6c89f72cd..09ae92f9e 100644 --- a/modules/pymol/cmd.py +++ b/modules/pymol/cmd.py @@ -202,7 +202,7 @@ def as_pathstr(path): # for extending the language - from .commanding import extend, extendaa, alias + from .commanding import declare_command, extend, extendaa, alias # for documentation etc diff --git a/modules/pymol/commanding.py b/modules/pymol/commanding.py index 4f52321fd..470994f33 100644 --- a/modules/pymol/commanding.py +++ b/modules/pymol/commanding.py @@ -20,6 +20,15 @@ import urllib.request as urllib2 from io import FileIO as file + import inspect + import glob + import shlex + from enum import Enum + from functools import wraps + from pathlib import Path + from textwrap import dedent + from typing import List + import re import os import time @@ -529,6 +538,97 @@ def delete(name, *, _self=cmd): if _self._raising(r,_self): raise pymol.CmdException return r + + class Selection(str): + pass + + + def _parse_bool(value: str): + if isinstance(value, str): + if value.lower() in ["yes", "1", "true", "on", "y"]: + return True + elif value.lower() in ["no", "0", "false", "off", "n"]: + return False + else: + raise Exception("Invalid boolean value: %s" % value) + elif isinstance(value, bool): + return value + else: + raise Exception(f"Unsuported boolean flag {value}") + + def _parse_list_str(value): + return shlex.split(value) + + def _parse_list_int(value): + return list(map(int, shlex.split(value))) + + def _parse_list_float(value): + return list(map(float, shlex.split(value))) + + def declare_command(name, function=None, _self=cmd): + if function is None: + name, function = name.__name__, name + + # new style commands should have annotations + annotations = [a for a in function.__annotations__ if a != "return"] + if function.__code__.co_argcount != len(annotations): + raise Exception("Messy annotations") + + # docstring text, if present, should be dedented + if function.__doc__ is not None: + function.__doc__ = dedent(function.__doc__).strip() + + + # Analysing arguments + spec = inspect.getfullargspec(function) + kwargs_ = {} + args_ = spec.args[:] + defaults = list(spec.defaults or []) + + args2_ = args_[:] + while args_ and defaults: + kwargs_[args_.pop(-1)] = defaults.pop(-1) + + funcs = {} + for idx, (var, func) in enumerate(spec.annotations.items()): + funcs[var] = func + + # Inner function that will be callable every time the command is executed + @wraps(function) + def inner(*args, **kwargs): + frame = traceback.format_stack()[-2] + caller = frame.split("\"", maxsplit=2)[1] + + # It was called from command line or pml script, so parse arguments + if caller.endswith("pymol/parser.py"): + kwargs = {**kwargs_, **kwargs, **dict(zip(args2_, args))} + kwargs.pop("_self", None) + for arg in kwargs.copy(): + if funcs[arg] == bool: + funcs[arg] = _parse_bool + elif funcs[arg] == List[str]: + funcs[arg] = _parse_list_str + elif funcs[arg] == List[int]: + funcs[arg] = _parse_list_int + elif funcs[arg] == List[float]: + funcs[arg] = _parse_list_float + else: + # Assume it's a literal supported type + pass + # Convert the argument to the correct type + kwargs[arg] = funcs[arg](kwargs[arg]) + return function(**kwargs) + + # It was called from Python, so pass the arguments as is + else: + return function(*args, **kwargs) + + name = function.__name__ + _self.keyword[name] = [inner, 0, 0, ",", parsing.STRICT] + _self.kwhash.append(name) + _self.help_sc.append(name) + return inner + def extend(name, function=None, _self=cmd): ''' diff --git a/testing/tests/api/commanding.py b/testing/tests/api/commanding.py index e09ec789e..e9d671224 100644 --- a/testing/tests/api/commanding.py +++ b/testing/tests/api/commanding.py @@ -1,10 +1,17 @@ from __future__ import print_function import sys +import pytest + import pymol import __main__ from pymol import cmd, testing, stored +from typing import List + + + + class TestCommanding(testing.PyMOLTestCase): def testAlias(self): @@ -171,3 +178,100 @@ def testRun(self, namespace, mod, rw): self.assertTrue(stored.tmp) if mod: self.assertEqual(rw, hasattr(sys.modules[mod], varname)) + +def test_declare_command_casting(): + from pathlib import Path + + @cmd.declare_command + def func(a: int, b: Path): + assert isinstance(a, int) and a == 1 + assert isinstance(b, (Path, str)) and "/tmp" == str(b) + func(1, "/tmp") + cmd.do('func 1, /tmp') + + +def test_declare_command_default(capsys): + from pymol.commanding import Selection + @cmd.declare_command + def func(a: Selection = "sele"): + assert a == "sele" + func() + cmd.do("func") + out, err = capsys.readouterr() + assert out == '' + +def test_declare_command_docstring(): + @cmd.declare_command + def func(): + """docstring""" + assert func.__doc__ == "docstring" + + @cmd.declare_command + def func(): + """ + docstring + Test: + --foo + """ + assert func.__doc__ == "docstring\nTest:\n --foo" + + +def test_declare_command_type_return(capsys): + @cmd.declare_command + def func() -> int: + return 1 + + assert func() == 1 + out, err = capsys.readouterr() + assert out == '' + + @cmd.declare_command + def func(): + return 1 + assert func() == 1 + +def test_declare_command_list_str(capsys): + @cmd.declare_command + def func(a: List[str]): + print(a[-1]) + + func(["a", "b", "c"]) + cmd.do('func a b c') + out, err = capsys.readouterr() + assert out == 'c\nc\n' + +def test_declare_command_list_int(capsys): + @cmd.declare_command + def func(a: List[int]): + print(a[-1] ** 2) + return a[-1] ** 2 + + assert func([1, 2, 3]) == 9 + cmd.do('func 1 2 3') + out, err = capsys.readouterr() + assert out == '9\n9\n' + + +def test_declare_command_list_float(capsys): + @cmd.declare_command + def func(a: List[float]): + print(a[-1]**2) + return a[-1]**2 + + assert func([1.1, 2.0, 3.0]) == 9.0 + cmd.do('func 1 2 3') + out, err = capsys.readouterr() + assert out == '9.0\n9.0\n' + + +def test_declare_command_bool(capsys): + @cmd.declare_command + def func(a: bool, b: bool): + assert a + assert not b + + func(True, False) + + cmd.do("func yes, no") + out, err = capsys.readouterr() + assert out == '' and err == '' \ No newline at end of file