From ef3dc7dc8ec1c9d8753acdeb5f3661ad48067908 Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Mon, 29 Dec 2025 19:24:12 -0600 Subject: [PATCH 1/4] Improve sqlite aggregation protocols --- .../test_cases/sqlite3/check_aggregations.py | 142 ++++++++++++++++++ .../check_connection.py} | 0 stdlib/sqlite3/__init__.pyi | 53 ++++--- 3 files changed, 172 insertions(+), 23 deletions(-) create mode 100644 stdlib/@tests/test_cases/sqlite3/check_aggregations.py rename stdlib/@tests/test_cases/{check_sqlite3.py => sqlite3/check_connection.py} (100%) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py new file mode 100644 index 000000000000..89b890f9a744 --- /dev/null +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -0,0 +1,142 @@ +import sqlite3 +import sys + + +class WindowSumInt: + def __init__(self) -> None: + self.count = 0 + + def step(self, param: int) -> None: + self.count += param + + def value(self) -> int: + return self.count + + def inverse(self, param: int) -> None: + self.count -= param + + def finalize(self) -> int: + return self.count + + +con = sqlite3.connect(":memory:") +cur = con.execute("CREATE TABLE test(x, y)") +values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)] +cur.executemany("INSERT INTO test VALUES(?, ?)", values) + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumInt) + +con.create_aggregate("sumint", 1, WindowSumInt) +cur.execute( + """ + SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM test ORDER BY x +""" +) +con.close() + + +def _create_window_function() -> WindowSumInt: + return WindowSumInt() + + +# A callable should work as well. 
+if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, _create_window_function) + con.create_aggregate("sumint", 1, _create_window_function) + +# With num_args set to 1, the callable should not be called with more than one. + + +class WindowSumIntMultiArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: int) -> None: + self.count += sum(args) + + def value(self) -> int: + return self.count + + def inverse(self, *args: int) -> None: + self.count -= sum(args) + + def finalize(self) -> int: + return self.count + + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumIntMultiArgs) + con.create_window_function("sumint", 2, WindowSumIntMultiArgs) + +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) +con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) + + +# Test case: Fixed parameter aggregates (the common case in practice) +class FixedTwoParamAggregate: + def __init__(self) -> None: + self.total = 0 + + def step(self, a: int, b: int) -> None: + self.total += a + b + + def finalize(self) -> int: + return self.total + + +con.create_aggregate("sum2", 2, FixedTwoParamAggregate) + + +class FixedThreeParamWindowAggregate: + def __init__(self) -> None: + self.total = 0 + + def step(self, a: int, b: int, c: int) -> None: + self.total += a + b + c + + def inverse(self, a: int, b: int, c: int) -> None: + self.total -= a + b + c + + def value(self) -> int: + return self.total + + def finalize(self) -> int: + return self.total + + +if sys.version_info >= (3, 11): + con.create_window_function("sum3", 3, FixedThreeParamWindowAggregate) + + +# What do protocols still catch? 
+ + +# Missing required method +class MissingStep: + def __init__(self) -> None: + self.total = 0 + + def finalize(self) -> int: + return self.total + + +con.create_aggregate("bad", 2, MissingStep) # type: ignore[arg-type] # missing step method + + +# Invalid return type from finalize (not a valid SQLite type) +class BadFinalizeReturn: + def __init__(self) -> None: + self.items: list[int] = [] + + def step(self, x: int) -> None: + self.items.append(x) + + def finalize(self) -> list[int]: # list is not a valid SQLite type + return self.items + + +con.create_aggregate("bad2", 1, BadFinalizeReturn) # type: ignore[arg-type] # bad return type diff --git a/stdlib/@tests/test_cases/check_sqlite3.py b/stdlib/@tests/test_cases/sqlite3/check_connection.py similarity index 100% rename from stdlib/@tests/test_cases/check_sqlite3.py rename to stdlib/@tests/test_cases/sqlite3/check_connection.py diff --git a/stdlib/sqlite3/__init__.pyi b/stdlib/sqlite3/__init__.pyi index 04b978b1b54c..a3f22f0de06d 100644 --- a/stdlib/sqlite3/__init__.pyi +++ b/stdlib/sqlite3/__init__.pyi @@ -216,6 +216,7 @@ if sys.version_info < (3, 10): _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None +_SQLType = TypeVar("_SQLType", bound=_SqliteData) # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. @@ -225,28 +226,29 @@ _IsolationLevel: TypeAlias = Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | Non _RowFactoryOptions: TypeAlias = type[Row] | Callable[[Cursor, Row], object] | None @type_check_only -class _AnyParamWindowAggregateClass(Protocol): - def step(self, *args: Any) -> object: ... - def inverse(self, *args: Any) -> object: ... - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... 
+class _SingleParamAggregateProtocol(Protocol[_SQLType]): + def step(self, param: _SQLType, /) -> object: ... + def finalize(self) -> _SQLType: ... @type_check_only -class _WindowAggregateClass(Protocol): - step: Callable[..., object] - inverse: Callable[..., object] - def value(self) -> _SqliteData: ... +class _AnyParamAggregateProtocol(Protocol): + @property + def step(self) -> Callable[..., object]: ... def finalize(self) -> _SqliteData: ... @type_check_only -class _AggregateProtocol(Protocol): - def step(self, value: int, /) -> object: ... - def finalize(self) -> int: ... +class _SingleParamWindowAggregateClass(Protocol[_SQLType]): + def step(self, param: _SQLType, /) -> object: ... + def inverse(self, param: _SQLType, /) -> object: ... + def value(self) -> _SQLType: ... + def finalize(self) -> _SQLType: ... @type_check_only -class _SingleParamWindowAggregateClass(Protocol): - def step(self, param: Any, /) -> object: ... - def inverse(self, param: Any, /) -> object: ... +class _AnyParamWindowAggregateClass(Protocol): + @property + def step(self) -> Callable[..., object]: ... + @property + def inverse(self) -> Callable[..., object]: ... def value(self) -> _SqliteData: ... def finalize(self) -> _SqliteData: ... @@ -334,22 +336,27 @@ class Connection: def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ... def commit(self) -> None: ... - def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... + @overload + def create_aggregate( + self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]] + ) -> None: ... + @overload + def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ... + if sys.version_info >= (3, 11): # num_params determines how many params will be passed to the aggregate class. 
We provide an overload # for the case where num_params = 1, which is expected to be the common case. @overload def create_window_function( - self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, / - ) -> None: ... - # And for num_params = -1, which means the aggregate must accept any number of parameters. - @overload - def create_window_function( - self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / + self, + name: str, + num_params: Literal[1], + aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, + /, ) -> None: ... @overload def create_window_function( - self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, / + self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / ) -> None: ... def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ... 
From 9fd27e4b58d03a26975ae069bfa7b60ecd0f35ed Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Mon, 29 Dec 2025 19:38:58 -0600 Subject: [PATCH 2/4] unnecessarily strict return types --- .../test_cases/sqlite3/check_aggregations.py | 42 +++++++++++++++++ stdlib/sqlite3/__init__.pyi | 46 +++++++++++++++---- 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 89b890f9a744..4f63c77ec06a 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -75,6 +75,48 @@ def finalize(self) -> int: con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) +# n_arg=-1 requires *args to handle any number of arguments +if sys.version_info >= (3, 11): + con.create_window_function("sumint_varargs", -1, WindowSumIntMultiArgs) + +con.create_aggregate("sumint_varargs", -1, WindowSumIntMultiArgs) + + +# n_arg=-1 should reject fixed-arity methods +class FixedArityAggregate: + def __init__(self) -> None: + self.total = 0 + + def step(self, a: int, b: int) -> None: + self.total += a + b + + def finalize(self) -> int: + return self.total + + +con.create_aggregate("bad_varargs", -1, FixedArityAggregate) # type: ignore[arg-type] + + +class FixedArityWindowAggregate: + def __init__(self) -> None: + self.total = 0 + + def step(self, a: int, b: int) -> None: + self.total += a + b + + def inverse(self, a: int, b: int) -> None: + self.total -= a + b + + def value(self) -> int: + return self.total + + def finalize(self) -> int: + return self.total + + +if sys.version_info >= (3, 11): + con.create_window_function("bad_varargs", -1, FixedArityWindowAggregate) # type: ignore[arg-type] + # Test case: Fixed parameter aggregates (the common case in practice) class FixedTwoParamAggregate: diff --git a/stdlib/sqlite3/__init__.pyi 
b/stdlib/sqlite3/__init__.pyi index a3f22f0de06d..a4d52987aec5 100644 --- a/stdlib/sqlite3/__init__.pyi +++ b/stdlib/sqlite3/__init__.pyi @@ -217,6 +217,8 @@ if sys.version_info < (3, 10): _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None _SQLType = TypeVar("_SQLType", bound=_SqliteData) +_SQLType_contra = TypeVar("_SQLType_contra", bound=_SqliteData, contravariant=True) + # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. @@ -226,9 +228,9 @@ _IsolationLevel: TypeAlias = Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | Non _RowFactoryOptions: TypeAlias = type[Row] | Callable[[Cursor, Row], object] | None @type_check_only -class _SingleParamAggregateProtocol(Protocol[_SQLType]): - def step(self, param: _SQLType, /) -> object: ... - def finalize(self) -> _SQLType: ... +class _SingleParamAggregateProtocol(Protocol): + def step(self, param: _SqliteData, /) -> object: ... + def finalize(self) -> _SqliteData: ... @type_check_only class _AnyParamAggregateProtocol(Protocol): @@ -237,11 +239,16 @@ class _AnyParamAggregateProtocol(Protocol): def finalize(self) -> _SqliteData: ... @type_check_only -class _SingleParamWindowAggregateClass(Protocol[_SQLType]): - def step(self, param: _SQLType, /) -> object: ... - def inverse(self, param: _SQLType, /) -> object: ... - def value(self) -> _SQLType: ... - def finalize(self) -> _SQLType: ... +class _AnyArgsAggregateProtocol(Protocol): + def step(self, *args: _SqliteData) -> object: ... + def finalize(self) -> _SqliteData: ... + +@type_check_only +class _SingleParamWindowAggregateClass(Protocol[_SQLType_contra]): + def step(self, param: _SQLType_contra, /) -> object: ... + def inverse(self, param: _SQLType_contra, /) -> object: ... + def value(self) -> _SqliteData: ... + def finalize(self) -> _SqliteData: ... 
@type_check_only class _AnyParamWindowAggregateClass(Protocol): @@ -252,6 +259,13 @@ class _AnyParamWindowAggregateClass(Protocol): def value(self) -> _SqliteData: ... def finalize(self) -> _SqliteData: ... +@type_check_only +class _AnyArgsWindowAggregateClass(Protocol): + def step(self, *args: _SqliteData) -> object: ... + def inverse(self, *args: _SqliteData) -> object: ... + def value(self) -> _SqliteData: ... + def finalize(self) -> _SqliteData: ... + # These classes are implemented in the C module _sqlite3. At runtime, they're imported # from there into sqlite3.dbapi2 and from that module to here. However, they # consider themselves to live in the sqlite3.* namespace, so we'll define them here. @@ -338,7 +352,11 @@ class Connection: def commit(self) -> None: ... @overload def create_aggregate( - self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]] + self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol] + ) -> None: ... + @overload + def create_aggregate( + self, name: str, n_arg: Literal[-1], aggregate_class: Callable[[], _AnyArgsAggregateProtocol] ) -> None: ... @overload def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ... @@ -351,7 +369,15 @@ class Connection: self, name: str, num_params: Literal[1], - aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, + aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType_contra]] | None, + /, + ) -> None: ... + @overload + def create_window_function( + self, + name: str, + num_params: Literal[-1], + aggregate_class: Callable[[], _AnyArgsWindowAggregateClass] | None, /, ) -> None: ... 
@overload From 604b5410d5190e565262f016449ba5339d8a3cf7 Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Mon, 29 Dec 2025 19:39:41 -0600 Subject: [PATCH 3/4] remove unused typevar --- stdlib/sqlite3/__init__.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/stdlib/sqlite3/__init__.pyi b/stdlib/sqlite3/__init__.pyi index a4d52987aec5..877617e610d4 100644 --- a/stdlib/sqlite3/__init__.pyi +++ b/stdlib/sqlite3/__init__.pyi @@ -216,7 +216,6 @@ if sys.version_info < (3, 10): _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None -_SQLType = TypeVar("_SQLType", bound=_SqliteData) _SQLType_contra = TypeVar("_SQLType_contra", bound=_SqliteData, contravariant=True) # Data that is passed through adapters can be of any type accepted by an adapter. From 8adf710b66bcf48445fa28ed64cb4433da3d05c1 Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Mon, 29 Dec 2025 19:58:31 -0600 Subject: [PATCH 4/4] remove unmatched -1 protocols --- .../test_cases/sqlite3/check_aggregations.py | 42 ------------------- stdlib/sqlite3/__init__.pyi | 24 ----------- 2 files changed, 66 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 4f63c77ec06a..89b890f9a744 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -75,48 +75,6 @@ def finalize(self) -> int: con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) -# n_arg=-1 requires *args to handle any number of arguments -if sys.version_info >= (3, 11): - con.create_window_function("sumint_varargs", -1, WindowSumIntMultiArgs) - -con.create_aggregate("sumint_varargs", -1, WindowSumIntMultiArgs) - - -# n_arg=-1 should reject fixed-arity methods -class FixedArityAggregate: - def __init__(self) -> None: - self.total = 0 - - def step(self, a: int, b: int) -> None: - self.total 
+= a + b - - def finalize(self) -> int: - return self.total - - -con.create_aggregate("bad_varargs", -1, FixedArityAggregate) # type: ignore[arg-type] - - -class FixedArityWindowAggregate: - def __init__(self) -> None: - self.total = 0 - - def step(self, a: int, b: int) -> None: - self.total += a + b - - def inverse(self, a: int, b: int) -> None: - self.total -= a + b - - def value(self) -> int: - return self.total - - def finalize(self) -> int: - return self.total - - -if sys.version_info >= (3, 11): - con.create_window_function("bad_varargs", -1, FixedArityWindowAggregate) # type: ignore[arg-type] - # Test case: Fixed parameter aggregates (the common case in practice) class FixedTwoParamAggregate: diff --git a/stdlib/sqlite3/__init__.pyi b/stdlib/sqlite3/__init__.pyi index 877617e610d4..0a848a63a779 100644 --- a/stdlib/sqlite3/__init__.pyi +++ b/stdlib/sqlite3/__init__.pyi @@ -237,11 +237,6 @@ class _AnyParamAggregateProtocol(Protocol): def step(self) -> Callable[..., object]: ... def finalize(self) -> _SqliteData: ... -@type_check_only -class _AnyArgsAggregateProtocol(Protocol): - def step(self, *args: _SqliteData) -> object: ... - def finalize(self) -> _SqliteData: ... - @type_check_only class _SingleParamWindowAggregateClass(Protocol[_SQLType_contra]): def step(self, param: _SQLType_contra, /) -> object: ... @@ -258,13 +253,6 @@ class _AnyParamWindowAggregateClass(Protocol): def value(self) -> _SqliteData: ... def finalize(self) -> _SqliteData: ... -@type_check_only -class _AnyArgsWindowAggregateClass(Protocol): - def step(self, *args: _SqliteData) -> object: ... - def inverse(self, *args: _SqliteData) -> object: ... - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... - # These classes are implemented in the C module _sqlite3. At runtime, they're imported # from there into sqlite3.dbapi2 and from that module to here. However, they # consider themselves to live in the sqlite3.* namespace, so we'll define them here. 
@@ -354,10 +342,6 @@ class Connection: self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol] ) -> None: ... @overload - def create_aggregate( - self, name: str, n_arg: Literal[-1], aggregate_class: Callable[[], _AnyArgsAggregateProtocol] - ) -> None: ... - @overload def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ... if sys.version_info >= (3, 11): @@ -372,14 +356,6 @@ class Connection: /, ) -> None: ... @overload - def create_window_function( - self, - name: str, - num_params: Literal[-1], - aggregate_class: Callable[[], _AnyArgsWindowAggregateClass] | None, - /, - ) -> None: ... - @overload def create_window_function( self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / ) -> None: ...