Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Simplify convert_math_to_symbolic #5

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions src/pycalphad_xml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@
from pathlib import Path
this_dir = Path(__file__).parent

def _stringify_node_text(node):
return ''.join(node.xpath('./text()')).replace('\n', '').replace(' ', '').strip()

def convert_math_to_symbolic(math_nodes):

def convert_math_to_symbolic(math_node):
result = 0.0
interval_nodes = [x for x in math_nodes if (not isinstance(x, str)) and x.tag == 'Interval']
string_nodes = [x for x in math_nodes if isinstance(x, str)]
for math_node in string_nodes:
# +0 is a hack, for how the function works
result += _sympify_string(math_node+'+0')
result += convert_intervals_to_piecewise(interval_nodes)
interval_nodes = [x for x in math_node if x.tag == 'Interval']
# +0 is a hack, for how the function works
result += _sympify_string(_stringify_node_text(math_node)+'+0')
result += _convert_intervals_to_piecewise(interval_nodes)
result = result.xreplace({Symbol('T'): v.T, Symbol('P'): v.P})
return result


def convert_intervals_to_piecewise(interval_nodes):
def _convert_intervals_to_piecewise(interval_nodes):
exprs = []
conds = []
for interval_node in interval_nodes:
Expand All @@ -33,7 +34,7 @@ def convert_intervals_to_piecewise(interval_nodes):
variable = interval_node.attrib['in']
lower = float(interval_node.attrib.get('lower', '-inf'))
upper = float(interval_node.attrib.get('upper', 'inf'))
math_expr = convert_math_to_symbolic([''.join(interval_node.itertext()).replace('\n', '').replace(' ', '').strip()])
math_expr = _sympify_string(_stringify_node_text(interval_node)+'+0')
if upper != float('inf'):
cond = And(lower <= getattr(v, variable, Symbol(variable)), upper > getattr(v, variable))
else:
Expand Down Expand Up @@ -181,9 +182,7 @@ def parse_model(dbf, phase_name, model_node, parameters):
constituent_array = [[str(c) for c in sorted(lx)] for lx in constituent_array]

# Parameter value
# Interval _and_ text (if any) to be able to handle intervals or scalar expressions
param_nodes = param_node.xpath('./Interval') + [''.join(param_node.xpath('./text()')).strip()]
function_obj = convert_math_to_symbolic(param_nodes)
function_obj = convert_math_to_symbolic(param_node)

# TODO: Reference

Expand Down Expand Up @@ -254,9 +253,7 @@ def read_xml(dbf, fd):
dbf.species.add(v.Species(species, constituent_dict, charge=species_charge))
elif child.tag == 'Expr':
function_name = str(child.attrib['id'])
# Interval _and_ text (if any) to be able to handle intervals or scalar expressions
expr_nodes = child.xpath('./Interval') + [''.join(child.xpath('./text()')).strip()]
function_obj = convert_math_to_symbolic(expr_nodes)
function_obj = convert_math_to_symbolic(child)
_setitem_raise_duplicates(dbf.symbols, function_name, function_obj)
elif child.tag == 'Phase':
model_nodes = child.xpath('./Model')
Expand Down