Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utilities for auto-generating param Python types #636

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import re
import datetime as dt
import collections
import numbers
import typing

from .parameterized import (
Parameterized, Parameter, String, ParameterizedFunction, ParamOverrides,
Expand Down Expand Up @@ -688,7 +690,6 @@ def _force(self,obj,objtype=None):
return gen


import numbers
def _is_number(obj):
if isinstance(obj, numbers.Number): return True
# The extra check is for classes that behave like numbers, such as those
Expand Down Expand Up @@ -794,6 +795,10 @@ def __init__(self, default=0.0, bounds=None, softbounds=None,
self.step = step
self._validate(default)

@property
def pytype(self):
return typing.Union[numbers.Number, None] if self.allow_None else numbers.Number

def __get__(self, obj, objtype):
"""
Same as the superclass's __get__, but if the value was
Expand Down Expand Up @@ -923,6 +928,10 @@ class Integer(Number):
def __init__(self, default=0, **params):
Number.__init__(self, default=default, **params)

@property
def pytype(self):
return typing.Union[int, None] if self.allow_None else int

def _validate_value(self, val, allow_None):
if callable(val):
return
Expand Down Expand Up @@ -960,6 +969,10 @@ def __init__(self, default=False, bounds=(0,1), **params):
self.bounds = bounds
super(Boolean, self).__init__(default=default, **params)

@property
def pytype(self):
return typing.Union[bool, None] if self.allow_None else bool

def _validate_value(self, val, allow_None):
if allow_None:
if not isinstance(val, bool) and val is not None:
Expand Down Expand Up @@ -994,6 +1007,14 @@ def __init__(self, default=(0,0), length=None, **params):
self.length = length
self._validate(default)

@property
def pytype(self):
if self.length:
pytype = typing.Tuple[(typing.Any,)*self.length]
else:
ptype = typing.Tuple[typing.Any, ...]
return typing.Union[pytype, None] if self.allow_None else pytype

def _validate_value(self, val, allow_None):
if val is None and allow_None:
return
Expand Down Expand Up @@ -1031,6 +1052,14 @@ def deserialize(cls, value):
class NumericTuple(Tuple):
"""A numeric tuple Parameter (e.g. (4.5,7.6,3)) with a fixed tuple length."""

@property
def pytype(self):
if self.length:
pytype = typing.Tuple[(numbers.Number,)*self.length]
else:
ptype = typing.Tuple[numbers.Number, ...]
return typing.Union[pytype, None] if self.allow_None else pytype

def _validate_value(self, val, allow_None):
super(NumericTuple, self)._validate_value(val, allow_None)
if allow_None and val is None:
Expand All @@ -1048,6 +1077,11 @@ class XYCoordinates(NumericTuple):
def __init__(self, default=(0.0, 0.0), **params):
super(XYCoordinates,self).__init__(default=default, length=2, **params)

@property
def pytype(self):
pytype = typing.Tuple[numbers.Number, numbers.Number]
return typing.Union[pytype, None] if self.allow_None else pytype


class Callable(Parameter):
"""
Expand Down Expand Up @@ -1393,6 +1427,17 @@ def __init__(self, default=[], class_=None, item_type=None,
**params)
self._validate(default)

@property
def pytype(self):
if isinstance(self.item_type, tuple):
item_type = typing.Union[self.item_type]
elif self.item_type is not None:
item_type = self.item_type
else:
item_type = typing.Any
list_type = typing.List[item_type]
return typing.Union[list_type, None] if self.allow_None else list_type

def _validate(self, val):
"""
Checks that the value is numeric and that it is within the hard
Expand Down Expand Up @@ -1466,16 +1511,27 @@ class Dict(ClassSelector):
def __init__(self, default=None, **params):
super(Dict, self).__init__(dict, default=default, **params)

@property
def pytype(self):
dict_type = typing.Dict[typing.Hashable, typing.Any]
return typing.Union[dict_type, None] if self.allow_None else dict_type



class Array(ClassSelector):
"""
Parameter whose value is a numpy array.
"""

def __init__(self, default=None, **params):
from numpy import ndarray

super(Array, self).__init__(ndarray, allow_None=True, default=default, **params)

@property
def pytype(self):
from numpy import ndarray
return ndarray

@classmethod
def serialize(cls, value):
if value is None:
Expand Down Expand Up @@ -1519,6 +1575,11 @@ def __init__(self, default=None, rows=None, columns=None, ordered=None, **params
super(DataFrame,self).__init__(pdDFrame, default=default, **params)
self._validate(self.default)

@property
def pytype(self):
from pandas import DataFrame
return DataFrame

def _length_bounds_check(self, bounds, length, name):
message = '{name} length {length} does not match declared bounds of {bounds}'
if not isinstance(bounds, tuple):
Expand Down Expand Up @@ -1595,6 +1656,11 @@ def __init__(self, default=None, rows=None, allow_None=False, **params):
**params)
self._validate(self.default)

@property
def pytype(self):
from pandas import Series
return Series

def _length_bounds_check(self, bounds, length, name):
message = '{name} length {length} does not match declared bounds of {bounds}'
if not isinstance(bounds, tuple):
Expand Down Expand Up @@ -2165,6 +2231,7 @@ def deserialize(cls, value):
# As JSON has no tuple representation
return tuple(deserialized)


class CalendarDateRange(Range):
"""
A date range specified as (start_date, end_date).
Expand Down Expand Up @@ -2241,6 +2308,10 @@ def __init__(self,default=False,bounds=(0,1),**params):
# back to False while triggered callbacks are executing
super(Event, self).__init__(default=default,**params)

@property
def pytype(self):
return bool

def _reset_event(self, obj, val):
val = False
if obj is None:
Expand Down
9 changes: 9 additions & 0 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import random
import numbers
import operator
import typing

# Allow this file to be used standalone if desired, albeit without JSON serialization
try:
Expand Down Expand Up @@ -1067,6 +1068,10 @@ class hierarchy (see ParameterizedMetaclass).
self.watchers = {}
self.per_instance = per_instance

@property
def pytype(self):
return typing.Any

@classmethod
def serialize(cls, value):
"Given the parameter value, return a Python value suitable for serialization"
Expand Down Expand Up @@ -1327,6 +1332,10 @@ def __init__(self, default="", regex=None, allow_None=False, **kwargs):
self.allow_None = (default is None or allow_None)
self._validate(default)

@property
def pytype(self):
return str

def _validate_regex(self, val, regex):
if (val is None and self.allow_None):
return
Expand Down
99 changes: 99 additions & 0 deletions param/typer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import ast
import inspect

import param


def walk_modules(pkg):
modules = [pkg]
for name, submodule in inspect.getmembers(pkg, inspect.ismodule):
if submodule.__package__.startswith(pkg.__name__):
modules.extend(walk_modules(submodule))
return modules

def walk_parameterized(module):
parameterizeds = []
for name, cls in inspect.getmembers(module, inspect.isclass):
if issubclass(cls, param.Parameterized) and cls.__module__ == module.__name__ and walk_parameters(cls):
parameterizeds.append(cls)
return parameterizeds

def walk_parameters(parameterized):
return [obj for obj in parameterized.param.objects().values() if obj.owner is parameterized]

class ExtractClassDefs(ast.NodeTransformer):

def __init__(self, parameterizeds):
self.nodes = {p: None for p in parameterizeds}
self._lookup = {p.__name__: p for p in parameterizeds}

def visit_ClassDef(self, node):
p = self._lookup.get(node.name)
if p is not None:
self.nodes[p] = node
return node

class ExtractParameterDefs(ast.NodeTransformer):

def __init__(self, params):
self.nodes = {p: None for p in params}
self._lookup = {p.name: p for p in params}

def record(self, node):
targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
for t in targets:
if not isinstance(t, ast.Name):
continue
p = self._lookup.get(t.id)
if p is not None:
self.nodes[p] = node

def visit_AnnAssign(self, node):
self.record(node)
return node

def visit_Assign(self, node):
self.record(node)
return node

def extra_param_assigns(parameterized, cls_def):
params = walk_parameters(parameterized)
param_defs = ExtractParameterDefs(params)
param_defs.visit(cls_def)
return param_defs.nodes

TYPES = {
bool: 'bool',
int: 'int',
float: 'float',
bytes: 'bytes'
}

def format_type(pytype):
if pytype in TYPES:
return TYPES[pytype]
return str(pytype)

def add_types(module):
with open(module.__file__) as f:
code = f.read()
parsed = ast.parse(code, module.__file__)

parameterized = walk_parameterized(module)

class_defs = ExtractClassDefs(parameterized)
class_defs.visit(parsed)

parameter_definitions = {}
for pzd in parameterized:
cls_def = class_defs.nodes[pzd]
param_defs = extra_param_assigns(pzd, cls_def)
for p, pdef in param_defs.items():
if pdef is None:
continue
src_code = ast.get_source_segment(code, pdef, padded=True)
transformed = src_code.split('\n')
pytype = format_type(p.pytype)
transformed[0] = transformed[0].replace(f'{p.name} =', f'{p.name}: {pytype} =')
code = code.replace(src_code, '\n'.join(transformed))
return code