diff --git a/src/ess/reflectometry/tools.py b/src/ess/reflectometry/tools.py index 341fa4f0..f880e30b 100644 --- a/src/ess/reflectometry/tools.py +++ b/src/ess/reflectometry/tools.py @@ -254,19 +254,39 @@ def combine_curves( if len({c.coords['Q'].unit for c in curves}) != 1: raise ValueError('The Q-coordinates must have the same unit for each curve') - r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values - v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values + r = _interpolate_on_qgrid(map(sc.values, curves), qgrid) + v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid) - v[v == 0] = np.nan + v = sc.where(v == 0, sc.scalar(np.nan, unit=v.unit), v) inv_v = 1.0 / v - r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0) - v_avg = 1 / np.nansum(inv_v, axis=0) - return sc.DataArray( + r_avg = sc.nansum(r * inv_v, dim='curves') / sc.nansum(inv_v, dim='curves') + v_avg = 1 / sc.nansum(inv_v, dim='curves') + + out = sc.DataArray( data=sc.array( dims='Q', - values=r_avg, - variances=v_avg, + values=r_avg.values, + variances=v_avg.values, unit=next(iter(curves)).data.unit, ), coords={'Q': qgrid}, ) + if any('Q_resolution' in c.coords for c in curves): + # This might need to be revisited. The question about how to combine curves + # with different Q-resolution is not completely resolved. + # However, in practice the difference in Q-resolution between different curves + # is small so it's not likely to make a big difference. + q_res = ( + sc.DataArray( + data=c.coords.get( + 'Q_resolution', sc.full_like(c.coords['Q'], value=np.nan) + ), + coords={'Q': c.coords['Q']}, + ) + for c in curves + ) + qs = _interpolate_on_qgrid(q_res, qgrid) + out.coords['Q_resolution'] = sc.nansum(qs * inv_v, dim='curves') / sc.nansum( + sc.where(sc.isnan(qs), sc.scalar(0.0, unit=inv_v.unit), inv_v), dim='curves' + ) + return out diff --git a/tests/tools_test.py b/tests/tools_test.py index b447fcf4..48141f65 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import pytest import scipp as sc from scipp.testing import assert_allclose @@ -126,3 +127,26 @@ def test_combined_curves(): ], ), ) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered in divide") +def test_combined_curves_resolution(): + qgrid = sc.linspace('Q', 0, 1, 26) + data = sc.concat( + ( + sc.ones(dims=['Q'], shape=[10], with_variances=True), + 0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True), + ), + dim='Q', + ) + data.variances[:] = 0.1 + curves = ( + curve(data, 0, 0.3), + curve(0.5 * data, 0.2, 0.7), + curve(0.25 * data, 0.6, 1.0), + ) + curves[0].coords['Q_resolution'] = sc.midpoints(curves[0].coords['Q']) / 5 + combined = combine_curves(curves, qgrid) + assert 'Q_resolution' in combined.coords + assert combined.coords['Q_resolution'][0] == curves[0].coords['Q_resolution'][1] + assert sc.isnan(combined.coords['Q_resolution'][-1])