Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 32 additions & 66 deletions xarray_array_testing/indexing.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,36 @@
from contextlib import nullcontext

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import xarray as xr
import xarray.testing.strategies as xrst
from hypothesis import given

from xarray_array_testing.base import DuckArrayTestMixin
from xarray_array_testing.strategies import orthogonal_indexers, vectorized_indexers


def scalar_indexer(size):
return st.integers(min_value=-size, max_value=size - 1)
def broadcast_orthogonal_indexers(indexers, sizes, *, xp):
def _broadcasting_shape(index, total):
return tuple(1 if i != index else -1 for i in range(total))

def _as_array(indexer, size):
if isinstance(indexer, slice):
return xp.asarray(range(*indexer.indices(size)), dtype="int64")
elif isinstance(indexer, int):
return xp.asarray(indexer, dtype="int64")
else:
return indexer

def integer_array_indexer(size):
dtypes = npst.integer_dtypes()

return npst.arrays(
dtypes, size, elements={"min_value": -size, "max_value": size - 1}
)


def indexers(size, indexer_types):
indexer_strategy_fns = {
"scalars": scalar_indexer,
"slices": st.slices,
"integer_arrays": integer_array_indexer,
indexer_arrays = {
dim: _as_array(indexer, sizes[dim]) for dim, indexer in indexers.items()
}

bad_types = set(indexer_types) - indexer_strategy_fns.keys()
if bad_types:
raise ValueError(f"unknown indexer strategies: {sorted(bad_types)}")

# use the order of definition to prefer simpler strategies over more complex
# ones
indexer_strategies = [
strategy_fn(size)
for name, strategy_fn in indexer_strategy_fns.items()
if name in indexer_types
]
return st.one_of(*indexer_strategies)


@st.composite
def orthogonal_indexers(draw, sizes, indexer_types):
# TODO: make use of `flatmap` and `builds` instead of `composite`
possible_indexers = {
dim: indexers(size, indexer_types) for dim, size in sizes.items()
}
concrete_indexers = draw(xrst.unique_subset_of(possible_indexers))
return {dim: draw(indexer) for dim, indexer in concrete_indexers.items()}


@st.composite
def vectorized_indexers(draw, sizes):
max_size = max(sizes.values())
shape = draw(st.integers(min_value=1, max_value=max_size))
dtypes = npst.integer_dtypes()

indexers = {
dim: npst.arrays(
dtypes, shape, elements={"min_value": -size, "max_value": size - 1}
broadcasted = xp.broadcast_arrays(
*(
xp.reshape(indexer, _broadcasting_shape(index, total=len(indexers)))
for index, indexer in enumerate(indexer_arrays.values())
)
for dim, size in sizes.items()
}
)

return {
dim: xr.Variable("points", draw(indexer)) for dim, indexer in indexers.items()
}
return dict(zip(indexer_arrays.keys(), broadcasted))


class IndexingTests(DuckArrayTestMixin):
Expand All @@ -81,19 +44,22 @@ def expected_errors(op, **parameters):

@given(st.data())
def test_variable_isel_orthogonal(self, data):
indexer_types = data.draw(
st.lists(self.orthogonal_indexer_types, min_size=1, unique=True)
)
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
idx = data.draw(orthogonal_indexers(variable.sizes, indexer_types))
idx = data.draw(orthogonal_indexers(sizes=variable.sizes, min_dims=1))

with self.expected_errors(
"isel_orthogonal", variable=variable, indexer_types=indexer_types
):
with self.expected_errors("isel_orthogonal", variable=variable, indexers=idx):
actual = variable.isel(idx).data

raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
expected = variable.data[*raw_indexers.values()]
sorted_dims = sorted(idx.keys(), key=variable.dims.index, reverse=True)
expected = variable.data
for dim in sorted_dims:
indexer = idx[dim]
axis = variable.get_axis_num(dim)
if isinstance(indexer, slice):
indexer = self.xp.asarray(
range(*indexer.indices(variable.sizes[dim])), dtype="int64"
)
expected = self.xp.take(expected, indexer, axis=axis)

assert isinstance(
actual, self.array_type("orthogonal_indexing")
Expand All @@ -103,7 +69,7 @@ def test_variable_isel_orthogonal(self, data):
@given(st.data())
def test_variable_isel_vectorized(self, data):
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
idx = data.draw(vectorized_indexers(variable.sizes))
idx = data.draw(vectorized_indexers(sizes=variable.sizes, min_dims=1))

with self.expected_errors("isel_vectorized", variable=variable):
actual = variable.isel(idx).data
Expand Down
264 changes: 264 additions & 0 deletions xarray_array_testing/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from collections.abc import Hashable
from itertools import compress

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
import xarray as xr
from xarray.testing.strategies import unique_subset_of


def _basic_indexers(size):
    """Strategy for a single basic indexer along a dimension of length *size*.

    Draws either an in-bounds integer (negative indices allowed) or a slice.
    """
    integers = st.integers(min_value=-size, max_value=size - 1)
    slices = st.slices(size)
    return integers | slices


def _outer_array_indexers(size, max_size):
    """Strategy for 1-D int64 index arrays along a dimension of length *size*.

    Array length is between 1 and ``min(size, max_size)``; elements are
    in-bounds indices (negative indices allowed).
    """
    lengths = st.integers(min_value=1, max_value=min(size, max_size))
    elements = st.integers(min_value=-size, max_value=size - 1)
    return npst.arrays(dtype=np.int64, shape=lengths, elements=elements)


# vendored from `xarray`, should be included in `xarray>=2026.01.0`
@st.composite
def basic_indexers(
    draw,
    /,
    *,
    sizes: dict[Hashable, int],
    min_dims: int = 1,
    max_dims: int | None = None,
) -> dict[Hashable, int | slice]:
    """Generate basic indexers (integers and slices) for a subset of dimensions.

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (automatically provided by @st.composite).
    sizes : dict[Hashable, int]
        Dictionary mapping dimension names to their sizes.
    min_dims : int, optional
        Minimum number of dimensions to index.
    max_dims : int or None, optional
        Maximum number of dimensions to index.

    Returns
    -------
    indexers : mapping of hashable to int or slice
        Indexers as a dict with keys randomly selected from ``sizes.keys()``.
        Integer values are always in bounds for the dimension's size
        (negative indices are allowed).

    See Also
    --------
    hypothesis.strategies.slices
    """
    # Pick a random subset of the dimensions (non-empty by default).
    selected_dims = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))

    # Delegate per-dimension generation to the shared helper so the
    # integer/slice strategies stay consistent with `orthogonal_indexers`.
    return {dim: draw(_basic_indexers(size)) for dim, size in selected_dims.items()}


@st.composite
def outer_array_indexers(
    draw,
    /,
    *,
    sizes: dict[Hashable, int],
    min_dims: int = 0,
    max_dims: int | None = None,
    max_size: int = 10,
) -> dict[Hashable, np.ndarray]:
    """Generate outer array indexers (orthogonal indexing with 1-D arrays).

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (automatically provided by @st.composite).
    sizes : dict[Hashable, int]
        Dictionary mapping dimension names to their sizes.
    min_dims : int, optional
        Minimum number of dimensions to index.
    max_dims : int or None, optional
        Maximum number of dimensions to index.
    max_size : int, optional
        Maximum length of each generated index array.

    Returns
    -------
    indexers : mapping of hashable to np.ndarray
        Indexers as a dict with keys randomly selected from ``sizes.keys()``.
        Values are 1D numpy arrays of integer indices for each dimension.

    See Also
    --------
    hypothesis.extra.numpy.arrays
    """
    # Pick a random (possibly empty) subset of the dimensions to index.
    selected_dims = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))

    # Delegate per-dimension array generation to the shared helper so it
    # stays consistent with `orthogonal_indexers`.
    return {
        dim: draw(_outer_array_indexers(size, max_size))
        for dim, size in selected_dims.items()
    }


@st.composite
def orthogonal_indexers(
    draw,
    /,
    *,
    sizes: dict[Hashable, int],
    min_dims: int = 2,
    max_dims: int | None = None,
    max_size: int = 10,
) -> dict[Hashable, int | slice | np.ndarray]:
    """Generate orthogonal indexers: a mix of integers, slices, and 1-D arrays.

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (automatically provided by @st.composite).
    sizes : dict[Hashable, int]
        Dictionary mapping dimension names to their sizes.
    min_dims : int, optional
        Minimum number of dimensions to index.
    max_dims : int or None, optional
        Maximum number of dimensions to index.
    max_size : int, optional
        Maximum size of array indexers.

    Returns
    -------
    indexers : mapping of hashable to indexer
        Indexers as a dict with keys randomly selected from ``sizes.keys()``.
        Values are integers, slices, or 1D numpy arrays of integer indices
        for each dimension.

    See Also
    --------
    hypothesis.extra.numpy.arrays
    """
    selected = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))

    # For each chosen dimension, draw either a basic indexer (int / slice)
    # or a 1-D integer array indexer.
    result: dict[Hashable, int | slice | np.ndarray] = {}
    for dim, size in selected.items():
        per_dim_strategy = st.one_of(
            _basic_indexers(size),
            _outer_array_indexers(size, max_size),
        )
        result[dim] = draw(per_dim_strategy)
    return result


@st.composite
def vectorized_indexers(
    draw,
    /,
    *,
    sizes: dict[Hashable, int],
    min_dims: int = 2,
    max_dims: int | None = None,
    min_ndim: int = 1,
    max_ndim: int = 3,
    min_size: int = 1,
    max_size: int = 5,
) -> dict[Hashable, xr.Variable]:
    """Generate vectorized (fancy) indexers where all arrays are broadcastable.

    In vectorized indexing, all array indexers must have compatible shapes
    that can be broadcast together, and the result shape is determined by
    broadcasting the indexer arrays.

    Parameters
    ----------
    draw : callable
        The Hypothesis draw function (automatically provided by @st.composite).
    sizes : dict[Hashable, int]
        Dictionary mapping dimension names to their sizes.
    min_dims : int, optional
        Minimum number of dimensions to index. Default is 2, so that we always have a "trajectory".
        Use ``outer_array_indexers`` for the ``min_dims==1`` case.
    max_dims : int or None, optional
        Maximum number of dimensions to index.
    min_ndim : int, optional
        Minimum number of dimensions for the result arrays.
    max_ndim : int, optional
        Maximum number of dimensions for the result arrays.
    min_size : int, optional
        Minimum size for each dimension in the result arrays.
    max_size : int, optional
        Maximum size for each dimension in the result arrays.

    Returns
    -------
    indexers : mapping of hashable to Variable
        Indexers as a dict with keys randomly selected from ``sizes.keys()``.
        Values are ``xr.Variable`` objects of integer indices whose dims are
        drawn from a shared set of ``vec_*`` names, so xarray broadcasts them
        to a common shape when indexing.

    See Also
    --------
    hypothesis.extra.numpy.arrays
    """
    # Choose which of the variable's dimensions to index (at least
    # `min_dims` of them).
    selected_dims = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))

    # Generate a common broadcast shape for all arrays.
    # Use min_ndim to max_ndim dimensions for the result shape.
    result_shape = draw(
        st.lists(
            st.integers(min_value=min_size, max_value=max_size),
            min_size=min_ndim,
            max_size=max_ndim,
        )
    )
    result_ndim = len(result_shape)

    # Create dimension names for the vectorized result; sharing these names
    # across indexers is what makes xarray align/broadcast them together.
    vec_dims = tuple(f"vec_{i}" for i in range(result_ndim))

    # Generate array indexers for each selected dimension.
    # All arrays must be broadcastable to the same result_shape.
    idxr = {}
    for dim, size in selected_dims.items():
        # A numpy-broadcastable shape: right-aligned with result_shape,
        # each entry either matching or 1 (or the dimension absent).
        array_shape = draw(
            npst.broadcastable_shapes(
                shape=tuple(result_shape),
                min_dims=min_ndim,
                max_dims=result_ndim,
            )
        )

        # For xarray broadcasting, drop dimensions where size differs from result_shape
        # (numpy broadcasts size-1, but xarray requires matching sizes or missing dims).
        # Right-align array_shape with result_shape for comparison; the
        # `if array_shape` guards handle the zero-dimensional case, where
        # `[-0:]` would otherwise select the full sequence.
        aligned_dims = vec_dims[-len(array_shape) :] if array_shape else ()
        aligned_result = result_shape[-len(array_shape) :] if array_shape else []
        # Keep only the axes whose size matches the target exactly; size-1
        # axes that would need numpy-style stretching are dropped instead.
        keep_mask = [s == r for s, r in zip(array_shape, aligned_result, strict=True)]
        filtered_shape = tuple(compress(array_shape, keep_mask))
        filtered_dims = tuple(compress(aligned_dims, keep_mask))

        # Generate an array of valid (possibly negative) indices for this
        # dimension, with the filtered shape and matching vec_* dim names.
        indices = draw(
            npst.arrays(
                dtype=np.int64,
                shape=filtered_shape,
                elements=st.integers(min_value=-size, max_value=size - 1),
            )
        )
        idxr[dim] = xr.Variable(data=indices, dims=filtered_dims)
    return idxr
Loading