From 9ae6c2c711a4d112406f26a5c6888b5349aec314 Mon Sep 17 00:00:00 2001 From: Dan Buch Date: Tue, 25 Mar 2025 21:15:58 -0400 Subject: [PATCH 1/3] Introduce `include` function and supporting plumbing that was previously implemented in a branch of `replicate/cog`. --- replicate/__init__.py | 3 + replicate/include.py | 164 +++++++++++++++++++++++ tests/test_include.py | 293 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 460 insertions(+) create mode 100644 replicate/include.py create mode 100644 tests/test_include.py diff --git a/replicate/__init__.py b/replicate/__init__.py index 0e6838d7..2a511740 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -1,4 +1,5 @@ from replicate.client import Client +from replicate.include import include as _include from replicate.pagination import async_paginate as _async_paginate from replicate.pagination import paginate as _paginate @@ -21,3 +22,5 @@ predictions = default_client.predictions trainings = default_client.trainings webhooks = default_client.webhooks + +include = _include diff --git a/replicate/include.py b/replicate/include.py new file mode 100644 index 00000000..8619af86 --- /dev/null +++ b/replicate/include.py @@ -0,0 +1,164 @@ +import os +import sys +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +import replicate + +from .exceptions import ModelError +from .model import Model +from .prediction import Prediction +from .run import _has_output_iterator_array_type +from .version import Version + +__all__ = ["include"] + + +_RUN_STATE: ContextVar[Literal["load", "setup", "run"] | None] = ContextVar( + "run_state", + default=None, +) +_RUN_TOKEN: ContextVar[str | None] = ContextVar("run_token", default=None) + + +@contextmanager +def run_state(state: Literal["load", "setup", "run"]) -> Any: + """ + Internal context manager for execution state. + """ + s = _RUN_STATE.set(state) + try: + yield + finally: + _RUN_STATE.reset(s) + + +@contextmanager +def run_token(token: str) -> Any: + """ + Sets the API token for the current context. + """ + t = _RUN_TOKEN.set(token) + try: + yield + finally: + _RUN_TOKEN.reset(t) + + +def _find_api_token() -> str: + token = os.environ.get("REPLICATE_API_TOKEN") + if token: + print("Using Replicate API token from environment", file=sys.stderr) + return token + + token = _RUN_TOKEN.get() + + if not token: + raise ValueError("No run token found") + + return token + + +@dataclass +class Run: + """ + Represents a running prediction with access to its version. + """ + + prediction: Prediction + version: Version + + def wait(self) -> Any: + """ + Wait for the prediction to complete and return its output. + """ + self.prediction.wait() + + if self.prediction.status == "failed": + raise ModelError(self.prediction) + + if _has_output_iterator_array_type(self.version): + return "".join(self.prediction.output) + + return self.prediction.output + + def logs(self) -> Optional[str]: + """ + Fetch and return the logs from the prediction. + """ + self.prediction.reload() + + return self.prediction.logs + + +@dataclass +class Function: + """ + A wrapper for a Replicate model that can be called as a function. + """ + + function_ref: str + + def _client(self) -> replicate.Client: + return replicate.Client(api_token=_find_api_token()) + + def _split_function_ref(self) -> Tuple[str, str, Optional[str]]: + owner, name = self.function_ref.split("/") + name, version = name.split(":") if ":" in name else (name, None) + return owner, name, version + + def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._split_function_ref() + return client.models.get(f"{model_owner}/{model_name}") + + def _version(self) -> Version: + client = self._client() + model_owner, model_name, model_version = self._split_function_ref() + model = client.models.get(f"{model_owner}/{model_name}") + version = ( + model.versions.get(model_version) if model_version else model.latest_version + ) + return version + + def __call__(self, **inputs: Dict[str, Any]) -> Any: + run = self.start(**inputs) + return run.wait() + + def start(self, **inputs: Dict[str, Any]) -> Run: + """ + Start a prediction with the specified inputs. + """ + version = self._version() + prediction = self._client().predictions.create(version=version, input=inputs) + print(f"Running {self.function_ref}: https://replicate.com/p/{prediction.id}") + + return Run(prediction, version) + + @property + def default_example(self) -> Optional[Prediction]: + """ + Get the default example for this model. + """ + return self._model().default_example + + @property + def openapi_schema(self) -> dict[Any, Any]: + """ + Get the OpenAPI schema for this model version. + """ + return self._version().openapi_schema + + +def include(function_ref: str) -> Callable[..., Any]: + """ + Include a Replicate model as a function. + + This function can only be called at the top level. + """ + if _RUN_STATE.get() != "load": + raise RuntimeError("You may only call replicate.include at the top level.") + + return Function(function_ref) diff --git a/tests/test_include.py b/tests/test_include.py new file mode 100644 index 00000000..dd2baa7f --- /dev/null +++ b/tests/test_include.py @@ -0,0 +1,293 @@ +import os +import unittest.mock as mock + +import pytest + +from replicate.exceptions import ModelError +from replicate.include import ( + Function, + Run, + include, + run_state, + run_token, +) + + +@pytest.fixture +def client(): + with mock.patch("replicate.Client") as client_class: + client_instance = mock.MagicMock() + client_class.return_value = client_instance + yield client_class, client_instance + + +@pytest.fixture +def model(): + model_obj = mock.MagicMock() + yield model_obj + + +@pytest.fixture +def version(): + version_obj = mock.MagicMock() + version_obj.openapi_schema = { + "components": {"schemas": {"Output": {"type": "string"}}} + } + version_obj.cog_version = "0.4.0" + yield version_obj + + +@pytest.fixture +def prediction(): + pred = mock.MagicMock() + pred.status = "succeeded" + pred.output = "test output" + pred.id = "pred123" + yield pred + + +@pytest.fixture +def iterator_version(): + iter_version = mock.MagicMock() + iter_version.openapi_schema = { + "components": { + "schemas": {"Output": {"type": "array", "x-cog-array-type": "iterator"}} + } + } + iter_version.cog_version = "0.4.0" + yield iter_version + + +def test_run_state_context_manager(): + with pytest.raises(RuntimeError): + include("owner/model:version") + + with run_state("load"): + include("owner/model:version") + + with run_state("load"): + include("owner/model:version") + with run_state("setup"): + with pytest.raises(RuntimeError): + include("owner/model:version") + + +def test_run_token_context_manager(client): + client_class, _ = client + + fn = Function("owner/model:version") + + with mock.patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="No run token found"): + fn._client() + + with run_token("test-token"): + fn._client() + client_class.assert_called_with(api_token="test-token") + + with run_token("another-token"): + fn._client() + client_class.assert_called_with(api_token="another-token") + + +def test_find_api_token_from_env(monkeypatch, client): + client_class, _ = client + monkeypatch.setenv("REPLICATE_API_TOKEN", "env-token") + with mock.patch("sys.stderr"): + fn = Function("owner/model:version") + fn._client() + client_class.assert_called_with(api_token="env-token") + + +def test_find_api_token_from_context(client): + client_class, _ = client + with run_token("context-token"): + fn = Function("owner/model:version") + fn._client() + client_class.assert_called_with(api_token="context-token") + + +def test_find_api_token_raises_error(): + with mock.patch.dict(os.environ, {}, clear=True): + fn = Function("owner/model:version") + with pytest.raises(ValueError, match="No run token found"): + fn._client() + + +def test_include_outside_load_state(): + with pytest.raises(RuntimeError, match="You may only call .* at the top level"): + include("owner/model:version") + + +def test_include_in_load_state(): + with run_state("load"): + fn = include("owner/model:version") + assert isinstance(fn, Function) + assert fn.function_ref == "owner/model:version" + + +def test_function_split_function_ref(): + fn = Function("owner/model:version") + assert fn._split_function_ref() == ("owner", "model", "version") + + fn = Function("owner/model") + assert fn._split_function_ref() == ("owner", "model", None) + + +def test_function_client(client): + client_class, client_instance = client + + with run_token("test-token"): + fn = Function("owner/model:version") + client = fn._client() + + client_class.assert_called_once_with(api_token="test-token") + assert client == client_instance + + +def test_function_model(client, model): + _, client_instance = client + client_instance.models.get.return_value = model + + with run_token("test-token"): + fn = Function("owner/model:version") + result = fn._model() + + client_instance.models.get.assert_called_once_with("owner/model") + assert result == model + + +def test_function_version_with_version_id(client, model, version): + _, client_instance = client + client_instance.models.get.return_value = model + model.versions.get.return_value = version + + with run_token("test-token"): + fn = Function("owner/model:version") + result = fn._version() + + client_instance.models.get.assert_called_once_with("owner/model") + model.versions.get.assert_called_once_with("version") + assert result == version + + +def test_function_version_with_latest(client, model, version): + _, client_instance = client + client_instance.models.get.return_value = model + model.latest_version = version + + with run_token("test-token"): + fn = Function("owner/model") + result = fn._version() + + client_instance.models.get.assert_called_once_with("owner/model") + assert result == version + + +@mock.patch.object(Function, "start") +@mock.patch.object(Function, "_version") +def test_function_call(version_patch, start_patch): + run_obj = mock.MagicMock() + start_patch.return_value = run_obj + + with run_token("test-token"): + fn = Function("owner/model:version") + fn(prompt="Hello", temperature=0.7) + + start_patch.assert_called_once_with(prompt="Hello", temperature=0.7) + run_obj.wait.assert_called_once() + + +def test_function_start(client, model, version, prediction, capsys): + _, client_instance = client + + client_instance.models.get.return_value = model + model.versions.get.return_value = version + client_instance.predictions.create.return_value = prediction + + with run_token("test-token"): + fn = Function("owner/model:version") + run = fn.start(prompt="Hello", temperature=0.7) + + client_instance.predictions.create.assert_called_once_with( + version=version, input={"prompt": "Hello", "temperature": 0.7} + ) + + assert isinstance(run, Run) + assert run.prediction == prediction + assert run.version == version + + captured = capsys.readouterr() + assert "https://replicate.com/p/pred123" in captured.out + + +def test_function_default_example(client, model): + _, client_instance = client + example_obj = mock.MagicMock() + client_instance.models.get.return_value = model + model.default_example = example_obj + + with run_token("test-token"): + fn = Function("owner/model:version") + example = fn.default_example + + assert example == example_obj + + +def test_function_openapi_schema(client, model, version): + _, client_instance = client + client_instance.models.get.return_value = model + model.versions.get.return_value = version + + with run_token("test-token"): + fn = Function("owner/model:version") + schema = fn.openapi_schema + + assert schema == version.openapi_schema + + +def test_run_wait_success(prediction, version): + with mock.patch( + "replicate.include._has_output_iterator_array_type", return_value=False + ): + run = Run(prediction=prediction, version=version) + result = run.wait() + + prediction.wait.assert_called_once() + assert result == "test output" + + +def test_run_wait_failure(version): + failed_prediction = mock.MagicMock() + failed_prediction.status = "failed" + + run = Run(prediction=failed_prediction, version=version) + with pytest.raises(ModelError): + run.wait() + + failed_prediction.wait.assert_called_once() + + +def test_run_wait_iterator_output(iterator_version): + iter_prediction = mock.MagicMock() + iter_prediction.status = "succeeded" + iter_prediction.output = ["Hello", " ", "world"] + + with mock.patch( + "replicate.include._has_output_iterator_array_type", return_value=True + ): + run = Run(prediction=iter_prediction, version=iterator_version) + result = run.wait() + + iter_prediction.wait.assert_called_once() + assert result == "Hello world" + + +def test_run_logs(prediction, version): + prediction.logs = "log content" + + run = Run(prediction=prediction, version=version) + logs = run.logs() + + prediction.reload.assert_called_once() + assert logs == "log content" From da11ccd6a41f4eab2a422653a17423392a5ab9d4 Mon Sep 17 00:00:00 2001 From: Dan Buch Date: Tue, 25 Mar 2025 22:29:20 -0400 Subject: [PATCH 2/3] Allow threaded code within `include` context managers --- replicate/include.py | 72 ++++++++++++++++++++++++++++++++----------- tests/test_include.py | 48 +++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 18 deletions(-) diff --git a/replicate/include.py b/replicate/include.py index 8619af86..d5fdbb75 100644 --- a/replicate/include.py +++ b/replicate/include.py @@ -1,7 +1,7 @@ import os import sys +import threading from contextlib import contextmanager -from contextvars import ContextVar from dataclasses import dataclass from typing import Any, Callable, Dict, Literal, Optional, Tuple @@ -13,38 +13,75 @@ from .run import _has_output_iterator_array_type from .version import Version -__all__ = ["include"] +__all__ = ["get_run_state", "get_run_token", "include", "run_state", "run_token"] -_RUN_STATE: ContextVar[Literal["load", "setup", "run"] | None] = ContextVar( - "run_state", - default=None, -) -_RUN_TOKEN: ContextVar[str | None] = ContextVar("run_token", default=None) +_run_state: Optional[Literal["load", "setup", "run"]] = None +_run_token: Optional[str] = None + +_state_stack = [] +_token_stack = [] + +_state_lock = threading.RLock() +_token_lock = threading.RLock() + + +def get_run_state() -> Optional[Literal["load", "setup", "run"]]: + """ + Get the current run state. + """ + return _run_state + + +def get_run_token() -> Optional[str]: + """ + Get the current API token. + """ + return _run_token @contextmanager def run_state(state: Literal["load", "setup", "run"]) -> Any: """ - Internal context manager for execution state. + Context manager for setting the current run state. """ - s = _RUN_STATE.set(state) + global _run_state + + if threading.current_thread() is not threading.main_thread(): + raise RuntimeError("Only the main thread can modify run state") + + with _state_lock: + _state_stack.append(_run_state) + + _run_state = state + try: yield finally: - _RUN_STATE.reset(s) + with _state_lock: + _run_state = _state_stack.pop() @contextmanager def run_token(token: str) -> Any: """ - Sets the API token for the current context. + Context manager for setting the current API token. """ - t = _RUN_TOKEN.set(token) + global _run_token + + if threading.current_thread() is not threading.main_thread(): + raise RuntimeError("Only the main thread can modify API token") + + with _token_lock: + _token_stack.append(_run_token) + + _run_token = token + try: yield finally: - _RUN_TOKEN.reset(t) + with _token_lock: + _run_token = _token_stack.pop() def _find_api_token() -> str: @@ -53,12 +90,11 @@ def _find_api_token() -> str: print("Using Replicate API token from environment", file=sys.stderr) return token - token = _RUN_TOKEN.get() - - if not token: + current_token = get_run_token() + if current_token is None: raise ValueError("No run token found") - return token + return current_token @dataclass @@ -158,7 +194,7 @@ def include(function_ref: str) -> Callable[..., Any]: This function can only be called at the top level. """ - if _RUN_STATE.get() != "load": + if get_run_state() != "load": raise RuntimeError("You may only call replicate.include at the top level.") return Function(function_ref) diff --git a/tests/test_include.py b/tests/test_include.py index dd2baa7f..8b788716 100644 --- a/tests/test_include.py +++ b/tests/test_include.py @@ -1,4 +1,5 @@ import os +import threading import unittest.mock as mock import pytest @@ -7,6 +8,8 @@ from replicate.include import ( Function, Run, + get_run_state, + get_run_token, include, run_state, run_token, @@ -291,3 +294,48 @@ def test_run_logs(prediction, version): prediction.reload.assert_called_once() assert logs == "log content" + + +def test_thread_safety_concepts(): + with run_state("load"), run_token("test-token"): + assert get_run_state() == "load" + assert get_run_token() == "test-token" + + results = [] + + def worker_thread_fn(): + thread_sees_state = get_run_state() == "load" + thread_sees_token = get_run_token() == "test-token" + + can_modify = True + try: + with run_state("setup"): + pass + can_modify = True + except RuntimeError: + can_modify = False + + results.append( + { + "reads_state": thread_sees_state, + "reads_token": thread_sees_token, + "can_modify": can_modify, + } + ) + + threads = [] + for _ in range(3): + t = threading.Thread(target=worker_thread_fn) + threads.append(t) + t.start() + + for t in threads: + t.join() + + for result in results: + assert result["reads_state"] + assert result["reads_token"] + assert not result["can_modify"] + + assert get_run_state() is None + assert get_run_token() is None From f9e21f4eb377fb2e268718eb7bb3d4ce40689c2a Mon Sep 17 00:00:00 2001 From: Dan Buch Date: Tue, 25 Mar 2025 22:36:13 -0400 Subject: [PATCH 3/3] Automatically clear `REPLICATE_API_TOKEN` from test env --- tests/test_include.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/test_include.py b/tests/test_include.py index 8b788716..1beaf532 100644 --- a/tests/test_include.py +++ b/tests/test_include.py @@ -16,6 +16,22 @@ ) +@pytest.fixture(autouse=True) +def no_api_token_in_env(): + """ + Remove REPLICATE_API_TOKEN from environment during tests and restore it after. + """ + original_token = os.environ.get("REPLICATE_API_TOKEN") + + if "REPLICATE_API_TOKEN" in os.environ: + del os.environ["REPLICATE_API_TOKEN"] + + yield + + if original_token is not None: + os.environ["REPLICATE_API_TOKEN"] = original_token + + @pytest.fixture def client(): with mock.patch("replicate.Client") as client_class: @@ -111,10 +127,11 @@ def test_find_api_token_from_context(client): def test_find_api_token_raises_error(): - with mock.patch.dict(os.environ, {}, clear=True): - fn = Function("owner/model:version") - with pytest.raises(ValueError, match="No run token found"): - fn._client() + assert "REPLICATE_API_TOKEN" not in os.environ + + fn = Function("owner/model:version") + with pytest.raises(ValueError, match="No run token found"): + fn._client() def test_include_outside_load_state():