Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions stdlib/@tests/test_cases/sqlite3/check_aggregations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import sqlite3
import sys


class WindowSumInt:
def __init__(self) -> None:
self.count = 0

def step(self, param: int) -> None:
self.count += param

def value(self) -> int:
return self.count

def inverse(self, param: int) -> None:
self.count -= param

def finalize(self) -> int:
return self.count


con = sqlite3.connect(":memory:")
cur = con.execute("CREATE TABLE test(x, y)")
values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)]
cur.executemany("INSERT INTO test VALUES(?, ?)", values)

if sys.version_info >= (3, 11):
con.create_window_function("sumint", 1, WindowSumInt)

con.create_aggregate("sumint", 1, WindowSumInt)
cur.execute(
"""
SELECT x, sumint(y) OVER (
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS sum_y
FROM test ORDER BY x
"""
)
con.close()


def _create_window_function() -> WindowSumInt:
return WindowSumInt()


# A callable should work as well.
if sys.version_info >= (3, 11):
con.create_window_function("sumint", 1, _create_window_function)
con.create_aggregate("sumint", 1, _create_window_function)

# With num_args set to 1, the callable should not be called with more than one.


class WindowSumIntMultiArgs:
def __init__(self) -> None:
self.count = 0

def step(self, *args: int) -> None:
self.count += sum(args)

def value(self) -> int:
return self.count

def inverse(self, *args: int) -> None:
self.count -= sum(args)

def finalize(self) -> int:
return self.count


if sys.version_info >= (3, 11):
con.create_window_function("sumint", 1, WindowSumIntMultiArgs)
con.create_window_function("sumint", 2, WindowSumIntMultiArgs)

con.create_aggregate("sumint", 1, WindowSumIntMultiArgs)
con.create_aggregate("sumint", 2, WindowSumIntMultiArgs)


# Test case: Fixed parameter aggregates (the common case in practice)
class FixedTwoParamAggregate:
def __init__(self) -> None:
self.total = 0

def step(self, a: int, b: int) -> None:
self.total += a + b

def finalize(self) -> int:
return self.total


con.create_aggregate("sum2", 2, FixedTwoParamAggregate)


class FixedThreeParamWindowAggregate:
def __init__(self) -> None:
self.total = 0

def step(self, a: int, b: int, c: int) -> None:
self.total += a + b + c

def inverse(self, a: int, b: int, c: int) -> None:
self.total -= a + b + c

def value(self) -> int:
return self.total

def finalize(self) -> int:
return self.total


if sys.version_info >= (3, 11):
con.create_window_function("sum3", 3, FixedThreeParamWindowAggregate)


# What do protocols still catch?


# Missing required method
class MissingStep:
def __init__(self) -> None:
self.total = 0

def finalize(self) -> int:
return self.total


con.create_aggregate("bad", 2, MissingStep) # type: ignore[arg-type] # missing step method


# Invalid return type from finalize (not a valid SQLite type)
class BadFinalizeReturn:
def __init__(self) -> None:
self.items: list[int] = []

def step(self, x: int) -> None:
self.items.append(x)

def finalize(self) -> list[int]: # list is not a valid SQLite type
return self.items


con.create_aggregate("bad2", 1, BadFinalizeReturn) # type: ignore[arg-type] # bad return type
52 changes: 30 additions & 22 deletions stdlib/sqlite3/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ if sys.version_info < (3, 10):

_CursorT = TypeVar("_CursorT", bound=Cursor)
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
_SQLType_contra = TypeVar("_SQLType_contra", bound=_SqliteData, contravariant=True)

# Data that is passed through adapters can be of any type accepted by an adapter.
_AdaptedInputData: TypeAlias = _SqliteData | Any
# The Mapping must really be a dict, but making it invariant is too annoying.
Expand All @@ -225,28 +227,29 @@ _IsolationLevel: TypeAlias = Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | Non
_RowFactoryOptions: TypeAlias = type[Row] | Callable[[Cursor, Row], object] | None

@type_check_only
class _AnyParamWindowAggregateClass(Protocol):
def step(self, *args: Any) -> object: ...
def inverse(self, *args: Any) -> object: ...
def value(self) -> _SqliteData: ...
class _SingleParamAggregateProtocol(Protocol):
def step(self, param: _SqliteData, /) -> object: ...
def finalize(self) -> _SqliteData: ...

@type_check_only
class _WindowAggregateClass(Protocol):
step: Callable[..., object]
inverse: Callable[..., object]
def value(self) -> _SqliteData: ...
class _AnyParamAggregateProtocol(Protocol):
@property
def step(self) -> Callable[..., object]: ...
def finalize(self) -> _SqliteData: ...

@type_check_only
class _AggregateProtocol(Protocol):
def step(self, value: int, /) -> object: ...
def finalize(self) -> int: ...
class _SingleParamWindowAggregateClass(Protocol[_SQLType_contra]):
def step(self, param: _SQLType_contra, /) -> object: ...
def inverse(self, param: _SQLType_contra, /) -> object: ...
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...

@type_check_only
class _SingleParamWindowAggregateClass(Protocol):
def step(self, param: Any, /) -> object: ...
def inverse(self, param: Any, /) -> object: ...
class _AnyParamWindowAggregateClass(Protocol):
@property
def step(self) -> Callable[..., object]: ...
@property
def inverse(self) -> Callable[..., object]: ...
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...

Expand Down Expand Up @@ -334,22 +337,27 @@ class Connection:
def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ...

def commit(self) -> None: ...
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ...
@overload
def create_aggregate(
self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol]
) -> None: ...
@overload
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ...

if sys.version_info >= (3, 11):
# num_params determines how many params will be passed to the aggregate class. We provide an overload
# for the case where num_params = 1, which is expected to be the common case.
@overload
def create_window_function(
self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, /
) -> None: ...
# And for num_params = -1, which means the aggregate must accept any number of parameters.
@overload
def create_window_function(
self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, /
Copy link
Contributor Author

@max-muoto max-muoto Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't necessarily enforce usage of *args in the any params case.

self,
name: str,
num_params: Literal[1],
aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType_contra]] | None,
/,
) -> None: ...
@overload
def create_window_function(
self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, /
self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, /
) -> None: ...

def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ...
Expand Down