Skip to content

Commit

Permalink
Make parameters fully pickleable (#158)
Browse files Browse the repository at this point in the history
* Make parameters fully pickleable

* Make Arguments use slots all the way, and test it

* Address comment

* Add a comment ased on comments

* Move argument_indices back to the class

* Address comment
  • Loading branch information
pckroon authored and tBuLi committed Jul 6, 2018
1 parent ca2ce1b commit 999ef0b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
33 changes: 24 additions & 9 deletions symfit/core/argument.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections import defaultdict
import numbers
import warnings

from sympy.core.symbol import Symbol


class Argument(Symbol):
"""
Base class for ``symfit`` symbols. This helps make ``symfit`` symbols distinguishable from ``sympy`` symbols.
Base class for :mod:`symfit` symbols. This helps make :mod:`symfit` symbols
distinguishable from :mod:`sympy` symbols.
If no name is explicitly provided a name will be generated.
Expand All @@ -19,6 +22,11 @@ class Argument(Symbol):
print(y.name)
>> 'y'
"""
__slots__ = ['_argument_index', '_argument_name']
# TODO: Make sure this also survives a pickle/unpickle to a fresh(!)
# interpreter.
_argument_indices = defaultdict(int)

def __new__(cls, name=None, *args, **assumptions):
assumptions['real'] = True
# Generate a dummy name
Expand All @@ -31,13 +39,13 @@ def __new__(cls, name=None, *args, **assumptions):
DeprecationWarning, stacklevel=2
)

name = '{}_{}'.format(cls._argument_name, cls._argument_index)
name = '{}_{}'.format(cls._argument_name, cls._argument_indices[cls])
instance = super(Argument, cls).__new__(cls, name, **assumptions)
instance._argument_index = cls._argument_index
cls._argument_index += 1
return instance
else:
return super(Argument, cls).__new__(cls, name, **assumptions)
instance = super(Argument, cls).__new__(cls, name, **assumptions)
instance._argument_index = cls._argument_indices[cls]
cls._argument_indices[cls] += 1
return instance

def __init__(self, name=None, *args, **assumptions):
# TODO: A more careful look at Symbol.__init__ is needed! However, it
Expand All @@ -46,6 +54,12 @@ def __init__(self, name=None, *args, **assumptions):
self.name = name
super(Argument, self).__init__()

def __getstate__(self):
state = super(Argument, self).__getstate__()
state.update({slot: getattr(self, slot) for slot in self.__slots__
if hasattr(self, slot)})
return state


class Parameter(Argument):
"""
Expand All @@ -56,7 +70,8 @@ class Parameter(Argument):
be generated.
"""
# Parameter index to be assigned to generated nameless parameters
_argument_index = 0
__slots__ = ['min', 'max', 'fixed', 'value']

_argument_name = 'par'

def __new__(cls, name=None, *args, **kwargs):
Expand Down Expand Up @@ -95,5 +110,5 @@ def __init__(self, name=None, value=1.0, min=None, max=None, fixed=False, **assu
class Variable(Argument):
""" Variable type."""
# Variable index to be assigned to generated nameless variables
_argument_index = 0
_argument_name = 'var'
_argument_name = 'var'
__slots__ = ()
36 changes: 36 additions & 0 deletions tests/test_argument.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import division, print_function
import pickle
import unittest
import sys
import sympy
Expand Down Expand Up @@ -75,6 +76,41 @@ def test_symbol_add(self):
new = x + y
self.assertIsInstance(new, sympy.Add)

def test_pickle(self):
"""
Make sure attributes are preserved when pickling
"""
A = Parameter('A', min=0., max=1e3, fixed=True)
new_A = pickle.loads(pickle.dumps(A))
self.assertEqual((A.min, A.value, A.max, A.fixed, A.name),
(new_A.min, new_A.value, new_A.max, new_A.fixed, new_A.name))

A = Parameter(min=0., max=1e3, fixed=True)
new_A = pickle.loads(pickle.dumps(A))
self.assertEqual((A.min, A.value, A.max, A.fixed, A.name),
(new_A.min, new_A.value, new_A.max, new_A.fixed, new_A.name))

def test_slots(self):
"""
Make sure Parameters and Variables don't have a __dict__
"""
P = Parameter('P')

# If you only have __slots__ you can't set arbitrary attributes, but
# you *should* be able to set those that are in your __slots__
try:
P.min = 0
except AttributeError:
self.fail()

with self.assertRaises(AttributeError):
P.foo = None

V = Variable('V')
with self.assertRaises(AttributeError):
V.bar = None


if __name__ == '__main__':
try:
unittest.main(warnings='ignore')
Expand Down

0 comments on commit 999ef0b

Please sign in to comment.