diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py new file mode 100644 index 000000000000..89b890f9a744 --- /dev/null +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -0,0 +1,142 @@ +import sqlite3 +import sys + + +class WindowSumInt: + def __init__(self) -> None: + self.count = 0 + + def step(self, param: int) -> None: + self.count += param + + def value(self) -> int: + return self.count + + def inverse(self, param: int) -> None: + self.count -= param + + def finalize(self) -> int: + return self.count + + +con = sqlite3.connect(":memory:") +cur = con.execute("CREATE TABLE test(x, y)") +values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)] +cur.executemany("INSERT INTO test VALUES(?, ?)", values) + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumInt) + +con.create_aggregate("sumint", 1, WindowSumInt) +cur.execute( + """ + SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM test ORDER BY x +""" +) +con.close() + + +def _create_window_function() -> WindowSumInt: + return WindowSumInt() + + +# A callable should work as well. +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, _create_window_function) + con.create_aggregate("sumint", 1, _create_window_function) + +# With num_args set to 1, the callable should not be called with more than one. + + +class WindowSumIntMultiArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: int) -> None: + self.count += sum(args) + + def value(self) -> int: + return self.count + + def inverse(self, *args: int) -> None: + self.count -= sum(args) + + def finalize(self) -> int: + return self.count + + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumIntMultiArgs) + con.create_window_function("sumint", 2, WindowSumIntMultiArgs) + +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) +con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) + + +# Test case: Fixed parameter aggregates (the common case in practice) +class FixedTwoParamAggregate: + def __init__(self) -> None: + self.total = 0 + + def step(self, a: int, b: int) -> None: + self.total += a + b + + def finalize(self) -> int: + return self.total + + +con.create_aggregate("sum2", 2, FixedTwoParamAggregate) + + +class FixedThreeParamWindowAggregate: + def __init__(self) -> None: + self.total = 0 + + def step(self, a: int, b: int, c: int) -> None: + self.total += a + b + c + + def inverse(self, a: int, b: int, c: int) -> None: + self.total -= a + b + c + + def value(self) -> int: + return self.total + + def finalize(self) -> int: + return self.total + + +if sys.version_info >= (3, 11): + con.create_window_function("sum3", 3, FixedThreeParamWindowAggregate) + + +# What do protocols still catch? + + +# Missing required method +class MissingStep: + def __init__(self) -> None: + self.total = 0 + + def finalize(self) -> int: + return self.total + + +con.create_aggregate("bad", 2, MissingStep) # type: ignore[arg-type] # missing step method + + +# Invalid return type from finalize (not a valid SQLite type) +class BadFinalizeReturn: + def __init__(self) -> None: + self.items: list[int] = [] + + def step(self, x: int) -> None: + self.items.append(x) + + def finalize(self) -> list[int]: # list is not a valid SQLite type + return self.items + + +con.create_aggregate("bad2", 1, BadFinalizeReturn) # type: ignore[arg-type] # bad return type diff --git a/stdlib/@tests/test_cases/check_sqlite3.py b/stdlib/@tests/test_cases/sqlite3/check_connection.py similarity index 100% rename from stdlib/@tests/test_cases/check_sqlite3.py rename to stdlib/@tests/test_cases/sqlite3/check_connection.py diff --git a/stdlib/sqlite3/__init__.pyi b/stdlib/sqlite3/__init__.pyi index 04b978b1b54c..0a848a63a779 100644 --- a/stdlib/sqlite3/__init__.pyi +++ b/stdlib/sqlite3/__init__.pyi @@ -216,6 +216,8 @@ if sys.version_info < (3, 10): _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None +_SQLType_contra = TypeVar("_SQLType_contra", bound=_SqliteData, contravariant=True) + # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. @@ -225,28 +227,29 @@ _IsolationLevel: TypeAlias = Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | Non _RowFactoryOptions: TypeAlias = type[Row] | Callable[[Cursor, Row], object] | None @type_check_only -class _AnyParamWindowAggregateClass(Protocol): - def step(self, *args: Any) -> object: ... - def inverse(self, *args: Any) -> object: ... - def value(self) -> _SqliteData: ... +class _SingleParamAggregateProtocol(Protocol): + def step(self, param: _SqliteData, /) -> object: ... def finalize(self) -> _SqliteData: ... @type_check_only -class _WindowAggregateClass(Protocol): - step: Callable[..., object] - inverse: Callable[..., object] - def value(self) -> _SqliteData: ... +class _AnyParamAggregateProtocol(Protocol): + @property + def step(self) -> Callable[..., object]: ... def finalize(self) -> _SqliteData: ... @type_check_only -class _AggregateProtocol(Protocol): - def step(self, value: int, /) -> object: ... - def finalize(self) -> int: ... +class _SingleParamWindowAggregateClass(Protocol[_SQLType_contra]): + def step(self, param: _SQLType_contra, /) -> object: ... + def inverse(self, param: _SQLType_contra, /) -> object: ... + def value(self) -> _SqliteData: ... + def finalize(self) -> _SqliteData: ... @type_check_only -class _SingleParamWindowAggregateClass(Protocol): - def step(self, param: Any, /) -> object: ... - def inverse(self, param: Any, /) -> object: ... +class _AnyParamWindowAggregateClass(Protocol): + @property + def step(self) -> Callable[..., object]: ... + @property + def inverse(self) -> Callable[..., object]: ... def value(self) -> _SqliteData: ... def finalize(self) -> _SqliteData: ... @@ -334,22 +337,27 @@ class Connection: def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ... def commit(self) -> None: ... - def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... + @overload + def create_aggregate( + self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol] + ) -> None: ... + @overload + def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ... + if sys.version_info >= (3, 11): # num_params determines how many params will be passed to the aggregate class. We provide an overload # for the case where num_params = 1, which is expected to be the common case. @overload def create_window_function( - self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, / - ) -> None: ... - # And for num_params = -1, which means the aggregate must accept any number of parameters. - @overload - def create_window_function( - self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / + self, + name: str, + num_params: Literal[1], + aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType_contra]] | None, + /, ) -> None: ... @overload def create_window_function( - self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, / + self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / ) -> None: ... def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ...