diff --git a/xarray_array_testing/base.py b/xarray_array_testing/base.py index 945e7e3..fc69ba0 100644 --- a/xarray_array_testing/base.py +++ b/xarray_array_testing/base.py @@ -24,3 +24,15 @@ def array_strategy_fn(*, shape, dtype): @staticmethod def assert_equal(a, b): npt.assert_equal(a, b) + + def assert_dimension_indexers_equal(self, a, b): + assert type(a) is type(b), f"types don't match: {type(a)} vs {type(b)}" + + if isinstance(a, dict): + assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}" + + assert all( + self.xp.all(self.xp.equal(a[k], b[k])) for k in a + ), "Differing indexers" + else: + npt.assert_equal(a, b) diff --git a/xarray_array_testing/reduction.py b/xarray_array_testing/reduction.py index 1fa0f23..6833c58 100644 --- a/xarray_array_testing/reduction.py +++ b/xarray_array_testing/reduction.py @@ -1,10 +1,11 @@ +import itertools from contextlib import nullcontext import hypothesis.strategies as st import numpy as np import pytest import xarray.testing.strategies as xrst -from hypothesis import given +from hypothesis import given, note from xarray_array_testing.base import DuckArrayTestMixin @@ -60,17 +61,63 @@ def test_variable_order_reduce(self, op, data): @given(st.data()) def test_variable_order_reduce_index(self, op, data): variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) + possible_dims = [..., list(variable.dims), *variable.dims] + list( + itertools.chain.from_iterable( + map(list, itertools.combinations(variable.dims, length)) + for length in range(1, len(variable.dims)) + ) + ) + dim = data.draw(st.sampled_from(possible_dims)) with self.expected_errors(op, variable=variable): # compute using xr.Variable.() - actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()} - - # compute using xp.(array) - index = getattr(self.xp, op)(variable.data) - unraveled = np.unravel_index(index, variable.shape) - expected = dict(zip(variable.dims, unraveled)) - - self.assert_equal(actual, expected) + actual = getattr(variable, op)(dim=dim) + if dim is ... or isinstance(dim, list): + actual_ = {dim_: var.data for dim_, var in actual.items()} + else: + actual_ = actual.data + + note(f"dim: {dim}") + if dim is not ... and not isinstance(dim, list): + # compute using xp.(array) + axis = variable.get_axis_num(dim) + indices = getattr(self.xp, op)(variable.data, axis=axis) + + expected = self.xp.asarray(indices) + elif dim is ... or len(dim) == len(variable.dims): + # compute using xp.(array) + index = getattr(self.xp, op)(variable.data) + + unraveled = np.unravel_index(index, variable.shape) + expected = { + k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled) + } + elif len(dim) == 1: + dim_ = dim[0] + axis = variable.get_axis_num(dim_) + index = getattr(self.xp, op)(variable.data, axis=axis) + + expected = {dim_: self.xp.asarray(index)} + else: + # move the relevant dims together and flatten + dim_name = object() + stacked = variable.stack({dim_name: dim}) + + reduce_shape = tuple(variable.sizes[d] for d in dim) + index = getattr(self.xp, op)(stacked.data, axis=-1) + + unravelled = np.unravel_index(index, reduce_shape) + + expected = { + d: self.xp.asarray(idx) + for d, idx in zip(dim, unravelled, strict=True) + } + + note(f"original: {variable}") + note(f"actual: {repr(actual_)}") + note(f"expected: {repr(expected)}") + + self.assert_dimension_indexers_equal(actual_, expected) @pytest.mark.parametrize( "op",