Skip to content

Commit

Permalink
Merge pull request #67 from pyiron/converter
Browse files Browse the repository at this point in the history
Update units check
  • Loading branch information
samwaseda authored Feb 26, 2025
2 parents 52bff21 + 6d4614e commit 47ebd67
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
14 changes: 12 additions & 2 deletions semantikon/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _get_converter(func):


def _get_ret_units(output, ureg, names):
if output is None:
if output == {}:
return None
ret = _to_units_container(output.get("units", None), ureg)
names = {key: 1.0 * value.units for key, value in names.items()}
Expand All @@ -130,6 +130,16 @@ def _get_output_units(output, ureg, names):
return _get_ret_units(output, ureg, names)


def _is_dimensionless(output):
if output is None:
return True
if isinstance(output, tuple):
return all([_is_dimensionless(oo) for oo in output])
if output.to_base_units().magnitude == 1.0 and output.dimensionless:
return True
return False


def units(func):
"""
Decorator to convert the output of a function to a Quantity object with
Expand All @@ -155,7 +165,7 @@ def wrapper(*args, **kwargs):
output_units = _get_output_units(parse_output_args(func), ureg, names)
except AttributeError:
output_units = None
if output_units is None:
if _is_dimensionless(output_units):
return func(*args, **kwargs)
elif isinstance(output_units, tuple):
return tuple(
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def get_speed_relative(
return distance / time


@units
def return_dict(
distance: u(float, units="meter"), time: u(float, units="second")
) -> dict:
return {"distance": distance, "time": time}


class TestUnits(unittest.TestCase):
def test_relative(self):
self.assertEqual(get_speed_relative(1, 1), 1)
Expand Down Expand Up @@ -170,6 +177,11 @@ def get_speed_use_list(
0.001 * ureg.meter / ureg.second,
)

def test_return_dict(self):
self.assertEqual(return_dict(1, 1), {"distance": 1, "time": 1})
ureg = UnitRegistry()
self.assertIsInstance(return_dict(1 * ureg.meter, 1 * ureg.second), dict)


if __name__ == "__main__":
unittest.main()

0 comments on commit 47ebd67

Please sign in to comment.