Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions xarray_array_testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,15 @@ def array_strategy_fn(*, shape, dtype):
@staticmethod
def assert_equal(a, b):
npt.assert_equal(a, b)

def assert_dimension_indexers_equal(self, a, b):
assert type(a) is type(b), f"types don't match: {type(a)} vs {type(b)}"

if isinstance(a, dict):
assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}"

assert all(
self.xp.all(self.xp.equal(a[k], b[k])) for k in a
), "Differing indexers"
else:
npt.assert_equal(a, b)
65 changes: 56 additions & 9 deletions xarray_array_testing/reduction.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import itertools
from contextlib import nullcontext

import hypothesis.strategies as st
import numpy as np
import pytest
import xarray.testing.strategies as xrst
from hypothesis import given
from hypothesis import given, note

from xarray_array_testing.base import DuckArrayTestMixin

Expand Down Expand Up @@ -60,17 +61,63 @@ def test_variable_order_reduce(self, op, data):
@given(st.data())
def test_variable_order_reduce_index(self, op, data):
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
possible_dims = [..., list(variable.dims), *variable.dims] + list(
itertools.chain.from_iterable(
map(list, itertools.combinations(variable.dims, length))
for length in range(1, len(variable.dims))
)
)
dim = data.draw(st.sampled_from(possible_dims))

with self.expected_errors(op, variable=variable):
# compute using xr.Variable.<OP>()
actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()}

# compute using xp.<OP>(array)
index = getattr(self.xp, op)(variable.data)
unraveled = np.unravel_index(index, variable.shape)
expected = dict(zip(variable.dims, unraveled))

self.assert_equal(actual, expected)
actual = getattr(variable, op)(dim=dim)
if dim is ... or isinstance(dim, list):
actual_ = {dim_: var.data for dim_, var in actual.items()}
else:
actual_ = actual.data

note(f"dim: {dim}")
if dim is not ... and not isinstance(dim, list):
# compute using xp.<OP>(array)
axis = variable.get_axis_num(dim)
indices = getattr(self.xp, op)(variable.data, axis=axis)

expected = self.xp.asarray(indices)
elif dim is ... or len(dim) == len(variable.dims):
# compute using xp.<OP>(array)
index = getattr(self.xp, op)(variable.data)

unraveled = np.unravel_index(index, variable.shape)
expected = {
k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled)
}
elif len(dim) == 1:
dim_ = dim[0]
axis = variable.get_axis_num(dim_)
index = getattr(self.xp, op)(variable.data, axis=axis)

expected = {dim_: self.xp.asarray(index)}
else:
# move the relevant dims together and flatten
dim_name = object()
stacked = variable.stack({dim_name: dim})

reduce_shape = tuple(variable.sizes[d] for d in dim)
index = getattr(self.xp, op)(stacked.data, axis=-1)

unravelled = np.unravel_index(index, reduce_shape)

expected = {
d: self.xp.asarray(idx)
for d, idx in zip(dim, unravelled, strict=True)
}

note(f"original: {variable}")
note(f"actual: {repr(actual_)}")
note(f"expected: {repr(expected)}")

self.assert_dimension_indexers_equal(actual_, expected)

@pytest.mark.parametrize(
"op",
Expand Down
Loading