From 5883369b8f59adb08aae555bf6c0fc7a56786657 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sat, 27 Dec 2025 23:30:54 +0000 Subject: [PATCH 01/24] added flags for enabling and disabling all telemetry - test_spans currently all pass --- src/google/adk/agents/base_agent.py | 1119 +++---- src/google/adk/agents/llm_agent.py | 1203 ++++--- .../adk/flows/llm_flows/base_llm_flow.py | 1909 +++++------ src/google/adk/flows/llm_flows/functions.py | 1342 ++++---- .../models/gemini_context_cache_manager.py | 842 ++--- src/google/adk/models/google_llm.py | 910 +++--- src/google/adk/runners.py | 2794 +++++++++-------- src/google/adk/telemetry/tracing.py | 530 ++-- src/google/adk/utils/telemetry_utils.py | 63 + 9 files changed, 5427 insertions(+), 5285 deletions(-) create mode 100644 src/google/adk/utils/telemetry_utils.py diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index e15f9af981..11cab7ea04 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -44,13 +44,14 @@ from ..telemetry.tracing import tracer from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental +from ..utils.telemetry_utils import is_telemetry_enabled from .base_agent_config import BaseAgentConfig from .callback_context import CallbackContext if TYPE_CHECKING: - from .invocation_context import InvocationContext + from .invocation_context import InvocationContext -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) _SingleAgentCallback: TypeAlias = Callable[ [CallbackContext], @@ -67,32 +68,32 @@ list[_SingleAgentCallback], ] -SelfAgent = TypeVar('SelfAgent', bound='BaseAgent') +SelfAgent = TypeVar("SelfAgent", bound="BaseAgent") @experimental class BaseAgentState(BaseModel): - """Base class for all agent states.""" + """Base class for all agent states.""" - model_config = ConfigDict( - extra='forbid', - ) + model_config = ConfigDict( + extra="forbid", + ) -AgentState = TypeVar('AgentState', bound=BaseAgentState) +AgentState = TypeVar("AgentState", bound=BaseAgentState) class BaseAgent(BaseModel): - """Base class for all agents in Agent Development Kit.""" + """Base class for all agents in Agent Development Kit.""" - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra='forbid', - ) - """The pydantic model config.""" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + """The pydantic model config.""" - config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig - """The config type for this agent. + config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig + """The config type for this agent. Sub-classes should override this to specify their own config type. @@ -107,22 +108,22 @@ class MyAgent(BaseAgent): ``` """ - name: str - """The agent's name. + name: str + """The agent's name. Agent name must be a Python identifier and unique within the agent tree. Agent name cannot be "user", since it's reserved for end-user's input. """ - description: str = '' - """Description about the agent's capability. + description: str = "" + """Description about the agent's capability. The model uses this to determine whether to delegate control to the agent. One-line description is enough and preferred. """ - parent_agent: Optional[BaseAgent] = Field(default=None, init=False) - """The parent agent of this agent. + parent_agent: Optional[BaseAgent] = Field(default=None, init=False) + """The parent agent of this agent. Note that an agent can ONLY be added as sub-agent once. @@ -130,11 +131,13 @@ class MyAgent(BaseAgent): instances with identical config, but with different name and add them to the agent tree. """ - sub_agents: list[BaseAgent] = Field(default_factory=list) - """The sub-agents of this agent.""" - - before_agent_callback: Optional[BeforeAgentCallback] = None - """Callback or list of callbacks to be invoked before the agent run. + sub_agents: list[BaseAgent] = Field(default_factory=list) + """The sub-agents of this agent.""" + disable_telemetry: bool = False + """Whether to disable telemetry for this agent. + """ + before_agent_callback: Optional[BeforeAgentCallback] = None + """Callback or list of callbacks to be invoked before the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -147,8 +150,8 @@ class MyAgent(BaseAgent): When the content is present, the agent run will be skipped and the provided content will be returned to user. """ - after_agent_callback: Optional[AfterAgentCallback] = None - """Callback or list of callbacks to be invoked after the agent run. + after_agent_callback: Optional[AfterAgentCallback] = None + """Callback or list of callbacks to be invoked after the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -162,536 +165,550 @@ class MyAgent(BaseAgent): will be appended to event history as an additional agent response. """ - def _load_agent_state( - self, - ctx: InvocationContext, - state_type: Type[AgentState], - ) -> Optional[AgentState]: - """Loads the agent state from the invocation context. - - Args: - ctx: The invocation context. - state_type: The type of the agent state. - - Returns: - The current state if exists; otherwise, None. - """ - if ctx.agent_states is None or self.name not in ctx.agent_states: - return None - else: - return state_type.model_validate(ctx.agent_states.get(self.name)) - - def _create_agent_state_event( - self, - ctx: InvocationContext, - ) -> Event: - """Returns an event with current agent state set in the invocation context. - - Args: - ctx: The invocation context. - - Returns: - An event with the current agent state set in the invocation context. - """ - event_actions = EventActions() - if (agent_state := ctx.agent_states.get(self.name)) is not None: - event_actions.agent_state = agent_state - if ctx.end_of_agents.get(self.name): - event_actions.end_of_agent = True - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=event_actions, - ) - - def clone( - self: SelfAgent, update: Mapping[str, Any] | None = None - ) -> SelfAgent: - """Creates a copy of this agent instance. - - Args: - update: Optional mapping of new values for the fields of the cloned agent. - The keys of the mapping are the names of the fields to be updated, and - the values are the new values for those fields. - For example: {"name": "cloned_agent"} - - Returns: - A new agent instance with identical configuration as the original - agent except for the fields specified in the update. - """ - if update is not None and 'parent_agent' in update: - raise ValueError( - 'Cannot update `parent_agent` field in clone. Parent agent is set' - ' only when the parent agent is instantiated with the sub-agents.' - ) - - # Only allow updating fields that are defined in the agent class. - allowed_fields = set(self.__class__.model_fields) - if update is not None: - invalid_fields = set(update) - allowed_fields - if invalid_fields: - raise ValueError( - f'Cannot update nonexistent fields in {self.__class__.__name__}:' - f' {invalid_fields}' + def _load_agent_state( + self, + ctx: InvocationContext, + state_type: Type[AgentState], + ) -> Optional[AgentState]: + """Loads the agent state from the invocation context. + + Args: + ctx: The invocation context. + state_type: The type of the agent state. + + Returns: + The current state if exists; otherwise, None. + """ + if ctx.agent_states is None or self.name not in ctx.agent_states: + return None + else: + return state_type.model_validate(ctx.agent_states.get(self.name)) + + def _create_agent_state_event( + self, + ctx: InvocationContext, + ) -> Event: + """Returns an event with current agent state set in the invocation context. + + Args: + ctx: The invocation context. + + Returns: + An event with the current agent state set in the invocation context. + """ + event_actions = EventActions() + if (agent_state := ctx.agent_states.get(self.name)) is not None: + event_actions.agent_state = agent_state + if ctx.end_of_agents.get(self.name): + event_actions.end_of_agent = True + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=event_actions, ) - cloned_agent = self.model_copy(update=update) - - # If any field is stored as list and not provided in the update, need to - # shallow copy it for the cloned agent to avoid sharing the same list object - # with the original agent. - for field_name in cloned_agent.__class__.model_fields: - if field_name == 'sub_agents': - continue - if update is not None and field_name in update: - continue - field = getattr(cloned_agent, field_name) - if isinstance(field, list): - setattr(cloned_agent, field_name, field.copy()) - - if update is None or 'sub_agents' not in update: - # If `sub_agents` is not provided in the update, need to recursively clone - # the sub-agents to avoid sharing the sub-agents with the original agent. - cloned_agent.sub_agents = [] - for sub_agent in self.sub_agents: - cloned_sub_agent = sub_agent.clone() - cloned_sub_agent.parent_agent = cloned_agent - cloned_agent.sub_agents.append(cloned_sub_agent) - else: - for sub_agent in cloned_agent.sub_agents: - sub_agent.parent_agent = cloned_agent - - # Remove the parent agent from the cloned agent to avoid sharing the parent - # agent with the cloned agent. - cloned_agent.parent_agent = None - return cloned_agent - - @final - async def run_async( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via text-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: - ctx = self._create_invocation_context(parent_context) - tracing.trace_agent_invocation(span, self, ctx) - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event - - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event - - @final - async def run_live( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via video/audio-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: - ctx = self._create_invocation_context(parent_context) - tracing.trace_agent_invocation(span, self, ctx) - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event - - if event := await self._handle_after_agent_callback(ctx): - yield event - - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via text-based conversation. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f'_run_async_impl for {type(self)} is not implemented.' - ) - yield # AsyncGenerator requires having at least one yield statement + def clone(self: SelfAgent, update: Mapping[str, Any] | None = None) -> SelfAgent: + """Creates a copy of this agent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + The keys of the mapping are the names of the fields to be updated, and + the values are the new values for those fields. + For example: {"name": "cloned_agent"} + + Returns: + A new agent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + if update is not None and "parent_agent" in update: + raise ValueError( + "Cannot update `parent_agent` field in clone. Parent agent is set" + " only when the parent agent is instantiated with the sub-agents." + ) + + # Only allow updating fields that are defined in the agent class. + allowed_fields = set(self.__class__.model_fields) + if update is not None: + invalid_fields = set(update) - allowed_fields + if invalid_fields: + raise ValueError( + f"Cannot update nonexistent fields in {self.__class__.__name__}:" + f" {invalid_fields}" + ) + + cloned_agent = self.model_copy(update=update) + + # If any field is stored as list and not provided in the update, need to + # shallow copy it for the cloned agent to avoid sharing the same list object + # with the original agent. + for field_name in cloned_agent.__class__.model_fields: + if field_name == "sub_agents": + continue + if update is not None and field_name in update: + continue + field = getattr(cloned_agent, field_name) + if isinstance(field, list): + setattr(cloned_agent, field_name, field.copy()) + + if update is None or "sub_agents" not in update: + # If `sub_agents` is not provided in the update, need to recursively clone + # the sub-agents to avoid sharing the sub-agents with the original agent. + cloned_agent.sub_agents = [] + for sub_agent in self.sub_agents: + cloned_sub_agent = sub_agent.clone() + cloned_sub_agent.parent_agent = cloned_agent + cloned_agent.sub_agents.append(cloned_sub_agent) + else: + for sub_agent in cloned_agent.sub_agents: + sub_agent.parent_agent = cloned_agent + + # Remove the parent agent from the cloned agent to avoid sharing the parent + # agent with the cloned agent. + cloned_agent.parent_agent = None + return cloned_agent + + @final + async def run_async( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via text-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + if is_telemetry_enabled(self): + with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + tracing.trace_agent_invocation(span, self, ctx) + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + else: + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + + @final + async def run_live( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via video/audio-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + if is_telemetry_enabled(self): + with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + tracing.trace_agent_invocation(span, self, ctx) + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + else: + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via text-based conversation. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_async_impl for {type(self)} is not implemented." + ) + yield # AsyncGenerator requires having at least one yield statement - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via video/audio-based conversation. + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via video/audio-based conversation. - Args: - ctx: InvocationContext, the invocation context for this agent. + Args: + ctx: InvocationContext, the invocation context for this agent. - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f'_run_live_impl for {type(self)} is not implemented.' - ) - yield # AsyncGenerator requires having at least one yield statement - - @property - def root_agent(self) -> BaseAgent: - """Gets the root agent of this agent.""" - root_agent = self - while root_agent.parent_agent is not None: - root_agent = root_agent.parent_agent - return root_agent - - def find_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent and its descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - if self.name == name: - return self - return self.find_sub_agent(name) - - def find_sub_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent's descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - for sub_agent in self.sub_agents: - if result := sub_agent.find_agent(name): - return result - return None - - def _create_invocation_context( - self, parent_context: InvocationContext - ) -> InvocationContext: - """Creates a new invocation context for this agent.""" - invocation_context = parent_context.model_copy(update={'agent': self}) - return invocation_context - - @property - def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_agent_callback: - return [] - if isinstance(self.before_agent_callback, list): - return self.before_agent_callback - return [self.before_agent_callback] - - @property - def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_agent_callback: - return [] - if isinstance(self.after_agent_callback, list): - return self.after_agent_callback - return [self.after_agent_callback] - - async def _handle_before_agent_callback( - self, ctx: InvocationContext - ) -> Optional[Event]: - """Runs the before_agent_callback if it exists. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - callback_context = CallbackContext(ctx) - - # Run callbacks from the plugins. - before_agent_callback_content = ( - await ctx.plugin_manager.run_before_agent_callback( - agent=self, callback_context=callback_context + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_live_impl for {type(self)} is not implemented." ) - ) - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if ( - not before_agent_callback_content - and self.canonical_before_agent_callbacks - ): - for callback in self.canonical_before_agent_callbacks: - before_agent_callback_content = callback( - callback_context=callback_context + yield # AsyncGenerator requires having at least one yield statement + + async def _run_callbacks_and_impl( + self, ctx: InvocationContext, mode: str = "async" + ) -> AsyncGenerator[Event, None]: + """Runs the before and after agent callbacks around the core agent logic. + Args: + ctx: InvocationContext, the invocation context for this agent. + mode: str, either 'async' or 'live', indicating which core agent logic to run. + Yields: + Event: the events generated by the agent. + """ + if event := await self._handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return + if mode.lower() == "async": + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + elif mode.lower() == "live": + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + else: + raise ValueError(f"Invalid mode: {mode}. Must be either 'async' or 'live'.") + + if ctx.end_invocation: + return + + if event := await self._handle_after_agent_callback(ctx): + yield event + + @property + def root_agent(self) -> BaseAgent: + """Gets the root agent of this agent.""" + root_agent = self + while root_agent.parent_agent is not None: + root_agent = root_agent.parent_agent + return root_agent + + def find_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent and its descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + if self.name == name: + return self + return self.find_sub_agent(name) + + def find_sub_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent's descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + for sub_agent in self.sub_agents: + if result := sub_agent.find_agent(name): + return result + return None + + def _create_invocation_context( + self, parent_context: InvocationContext + ) -> InvocationContext: + """Creates a new invocation context for this agent.""" + invocation_context = parent_context.model_copy(update={"agent": self}) + return invocation_context + + @property + def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_agent_callback: + return [] + if isinstance(self.before_agent_callback, list): + return self.before_agent_callback + return [self.before_agent_callback] + + @property + def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_agent_callback: + return [] + if isinstance(self.after_agent_callback, list): + return self.after_agent_callback + return [self.after_agent_callback] + + async def _handle_before_agent_callback( + self, ctx: InvocationContext + ) -> Optional[Event]: + """Runs the before_agent_callback if it exists. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + callback_context = CallbackContext(ctx) + + # Run callbacks from the plugins. + before_agent_callback_content = ( + await ctx.plugin_manager.run_before_agent_callback( + agent=self, callback_context=callback_context + ) ) - if inspect.isawaitable(before_agent_callback_content): - before_agent_callback_content = await before_agent_callback_content + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not before_agent_callback_content and self.canonical_before_agent_callbacks: + for callback in self.canonical_before_agent_callbacks: + before_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(before_agent_callback_content): + before_agent_callback_content = await before_agent_callback_content + if before_agent_callback_content: + break + + # Process the override content if exists, and further process the state + # change if exists. if before_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. - if before_agent_callback_content: - ret_event = Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=before_agent_callback_content, - actions=callback_context._event_actions, - ) - ctx.end_invocation = True - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=callback_context._event_actions, - ) - - return None - - async def _handle_after_agent_callback( - self, invocation_context: InvocationContext - ) -> Optional[Event]: - """Runs the after_agent_callback if it exists. - - Args: - invocation_context: InvocationContext, the invocation context for this - agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - - callback_context = CallbackContext(invocation_context) - - # Run callbacks from the plugins. - after_agent_callback_content = ( - await invocation_context.plugin_manager.run_after_agent_callback( - agent=self, callback_context=callback_context + ret_event = Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=before_agent_callback_content, + actions=callback_context._event_actions, + ) + ctx.end_invocation = True + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=callback_context._event_actions, + ) + + return None + + async def _handle_after_agent_callback( + self, invocation_context: InvocationContext + ) -> Optional[Event]: + """Runs the after_agent_callback if it exists. + + Args: + invocation_context: InvocationContext, the invocation context for this + agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + + callback_context = CallbackContext(invocation_context) + + # Run callbacks from the plugins. + after_agent_callback_content = ( + await invocation_context.plugin_manager.run_after_agent_callback( + agent=self, callback_context=callback_context + ) ) - ) - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if ( - not after_agent_callback_content - and self.canonical_after_agent_callbacks - ): - for callback in self.canonical_after_agent_callbacks: - after_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(after_agent_callback_content): - after_agent_callback_content = await after_agent_callback_content + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not after_agent_callback_content and self.canonical_after_agent_callbacks: + for callback in self.canonical_after_agent_callbacks: + after_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(after_agent_callback_content): + after_agent_callback_content = await after_agent_callback_content + if after_agent_callback_content: + break + + # Process the override content if exists, and further process the state + # change if exists. if after_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. - if after_agent_callback_content: - ret_event = Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return None - - @override - def model_post_init(self, __context: Any) -> None: - self.__set_parent_agent_for_sub_agents() - - @field_validator('name', mode='after') - @classmethod - def validate_name(cls, value: str): - if not value.isidentifier(): - raise ValueError( - f'Found invalid agent name: `{value}`.' - ' Agent name must be a valid identifier. It should start with a' - ' letter (a-z, A-Z) or an underscore (_), and can only contain' - ' letters, digits (0-9), and underscores.' - ) - if value == 'user': - raise ValueError( - "Agent name cannot be `user`. `user` is reserved for end-user's" - ' input.' - ) - return value - - @field_validator('sub_agents', mode='after') - @classmethod - def validate_sub_agents_unique_names( - cls, value: list[BaseAgent] - ) -> list[BaseAgent]: - """Validates that all sub-agents have unique names. - - Args: - value: The list of sub-agents to validate. - - Returns: - The validated list of sub-agents. - - """ - if not value: - return value - - seen_names: set[str] = set() - duplicates: set[str] = set() - - for sub_agent in value: - name = sub_agent.name - if name in seen_names: - duplicates.add(name) - else: - seen_names.add(name) - - if duplicates: - duplicate_names_str = ', '.join( - f'`{name}`' for name in sorted(duplicates) - ) - logger.warning( - 'Found duplicate sub-agent names: %s. ' - 'All sub-agents must have unique names.', - duplicate_names_str, - ) - - return value - - def __set_parent_agent_for_sub_agents(self) -> BaseAgent: - for sub_agent in self.sub_agents: - if sub_agent.parent_agent is not None: - raise ValueError( - f'Agent `{sub_agent.name}` already has a parent agent, current' - f' parent: `{sub_agent.parent_agent.name}`, trying to add:' - f' `{self.name}`' - ) - sub_agent.parent_agent = self - return self - - @final - @classmethod - @experimental - def from_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - ) -> SelfAgent: - """Creates an agent from a config. - - If sub-classes uses a custom agent config, override `_from_config_kwargs` - method to return an updated kwargs for agent constructor. - - Args: - config: The config to create the agent from. - config_abs_path: The absolute path to the config file that contains the - agent config. - - Returns: - The created agent. - """ - kwargs = cls.__create_kwargs(config, config_abs_path) - kwargs = cls._parse_config(config, config_abs_path, kwargs) - return cls(**kwargs) - - @classmethod - @experimental - def _parse_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - kwargs: Dict[str, Any], - ) -> Dict[str, Any]: - """Parses the config and returns updated kwargs to construct the agent. - - Sub-classes should override this method to use a custom agent config class. - - Args: - config: The config to parse. - config_abs_path: The absolute path to the config file that contains the - agent config. - kwargs: The keyword arguments used for agent constructor. - - Returns: - The updated keyword arguments used for agent constructor. - """ - return kwargs - - @classmethod - def __create_kwargs( - cls, - config: BaseAgentConfig, - config_abs_path: str, - ) -> Dict[str, Any]: - """Creates kwargs for the fields of BaseAgent.""" - - from .config_agent_utils import resolve_agent_reference - from .config_agent_utils import resolve_callbacks - - kwargs: Dict[str, Any] = { - 'name': config.name, - 'description': config.description, - } - if config.sub_agents: - sub_agents = [] - for sub_agent_config in config.sub_agents: - sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) - sub_agents.append(sub_agent) - kwargs['sub_agents'] = sub_agents - - if config.before_agent_callbacks: - kwargs['before_agent_callback'] = resolve_callbacks( - config.before_agent_callbacks - ) - if config.after_agent_callbacks: - kwargs['after_agent_callback'] = resolve_callbacks( - config.after_agent_callbacks - ) - return kwargs + ret_event = Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return None + + @override + def model_post_init(self, __context: Any) -> None: + self.__set_parent_agent_for_sub_agents() + + @field_validator("name", mode="after") + @classmethod + def validate_name(cls, value: str): + if not value.isidentifier(): + raise ValueError( + f"Found invalid agent name: `{value}`." + " Agent name must be a valid identifier. It should start with a" + " letter (a-z, A-Z) or an underscore (_), and can only contain" + " letters, digits (0-9), and underscores." + ) + if value == "user": + raise ValueError( + "Agent name cannot be `user`. `user` is reserved for end-user's" + " input." + ) + return value + + @field_validator("sub_agents", mode="after") + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ", ".join(f"`{name}`" for name in sorted(duplicates)) + logger.warning( + "Found duplicate sub-agent names: %s. " + "All sub-agents must have unique names.", + duplicate_names_str, + ) + + return value + + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: + for sub_agent in self.sub_agents: + if sub_agent.parent_agent is not None: + raise ValueError( + f"Agent `{sub_agent.name}` already has a parent agent, current" + f" parent: `{sub_agent.parent_agent.name}`, trying to add:" + f" `{self.name}`" + ) + sub_agent.parent_agent = self + return self + + @final + @classmethod + @experimental + def from_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + ) -> SelfAgent: + """Creates an agent from a config. + + If sub-classes uses a custom agent config, override `_from_config_kwargs` + method to return an updated kwargs for agent constructor. + + Args: + config: The config to create the agent from. + config_abs_path: The absolute path to the config file that contains the + agent config. + + Returns: + The created agent. + """ + kwargs = cls.__create_kwargs(config, config_abs_path) + kwargs = cls._parse_config(config, config_abs_path, kwargs) + return cls(**kwargs) + + @classmethod + @experimental + def _parse_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Parses the config and returns updated kwargs to construct the agent. + + Sub-classes should override this method to use a custom agent config class. + + Args: + config: The config to parse. + config_abs_path: The absolute path to the config file that contains the + agent config. + kwargs: The keyword arguments used for agent constructor. + + Returns: + The updated keyword arguments used for agent constructor. + """ + return kwargs + + @classmethod + def __create_kwargs( + cls, + config: BaseAgentConfig, + config_abs_path: str, + ) -> Dict[str, Any]: + """Creates kwargs for the fields of BaseAgent.""" + + from .config_agent_utils import resolve_agent_reference + from .config_agent_utils import resolve_callbacks + + kwargs: Dict[str, Any] = { + "name": config.name, + "description": config.description, + } + if config.sub_agents: + sub_agents = [] + for sub_agent_config in config.sub_agents: + sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) + sub_agents.append(sub_agent) + kwargs["sub_agents"] = sub_agents + + if config.before_agent_callbacks: + kwargs["before_agent_callback"] = resolve_callbacks( + config.before_agent_callbacks + ) + if config.after_agent_callbacks: + kwargs["after_agent_callback"] = resolve_callbacks( + config.after_agent_callbacks + ) + return kwargs diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5abaef589f..82542b0a9d 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -55,6 +55,7 @@ from ..tools.tool_context import ToolContext from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental +from ..utils.telemetry_utils import is_telemetry_enabled from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -63,7 +64,7 @@ from .llm_agent_config import LlmAgentConfig from .readonly_context import ReadonlyContext -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) _SingleBeforeModelCallback: TypeAlias = Callable[ [CallbackContext, LlmRequest], @@ -125,9 +126,7 @@ list[_SingleOnToolErrorCallback], ] -InstructionProvider: TypeAlias = Callable[ - [ReadonlyContext], Union[str, Awaitable[str]] -] +InstructionProvider: TypeAlias = Callable[[ReadonlyContext], Union[str, Awaitable[str]]] ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset] @@ -138,62 +137,62 @@ async def _convert_tool_union_to_tools( model: Union[str, BaseLlm], multiple_tools: bool = False, ) -> list[BaseTool]: - from ..tools.google_search_tool import GoogleSearchTool - from ..tools.vertex_ai_search_tool import VertexAiSearchTool - - # Wrap google_search tool with AgentTool if there are multiple tools because - # the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, GoogleSearchTool): - from ..tools.google_search_agent_tool import create_google_search_agent - from ..tools.google_search_agent_tool import GoogleSearchAgentTool - - search_tool = cast(GoogleSearchTool, tool_union) - if search_tool.bypass_multi_tools_limit: - return [GoogleSearchAgentTool(create_google_search_agent(model))] - - # Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are - # multiple tools because the built-in tools cannot be used together with - # other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, VertexAiSearchTool): - from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool - - vais_tool = cast(VertexAiSearchTool, tool_union) - if vais_tool.bypass_multi_tools_limit: - return [ - DiscoveryEngineSearchTool( - data_store_id=vais_tool.data_store_id, - data_store_specs=vais_tool.data_store_specs, - search_engine_id=vais_tool.search_engine_id, - filter=vais_tool.filter, - max_results=vais_tool.max_results, - ) - ] + from ..tools.google_search_tool import GoogleSearchTool + from ..tools.vertex_ai_search_tool import VertexAiSearchTool - if isinstance(tool_union, BaseTool): - return [tool_union] - if callable(tool_union): - return [FunctionTool(func=tool_union)] + # Wrap google_search tool with AgentTool if there are multiple tools because + # the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + if multiple_tools and isinstance(tool_union, GoogleSearchTool): + from ..tools.google_search_agent_tool import create_google_search_agent + from ..tools.google_search_agent_tool import GoogleSearchAgentTool - # At this point, tool_union must be a BaseToolset - return await tool_union.get_tools_with_prefix(ctx) + search_tool = cast(GoogleSearchTool, tool_union) + if search_tool.bypass_multi_tools_limit: + return [GoogleSearchAgentTool(create_google_search_agent(model))] + + # Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are + # multiple tools because the built-in tools cannot be used together with + # other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + if multiple_tools and isinstance(tool_union, VertexAiSearchTool): + from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool + + vais_tool = cast(VertexAiSearchTool, tool_union) + if vais_tool.bypass_multi_tools_limit: + return [ + DiscoveryEngineSearchTool( + data_store_id=vais_tool.data_store_id, + data_store_specs=vais_tool.data_store_specs, + search_engine_id=vais_tool.search_engine_id, + filter=vais_tool.filter, + max_results=vais_tool.max_results, + ) + ] + + if isinstance(tool_union, BaseTool): + return [tool_union] + if callable(tool_union): + return [FunctionTool(func=tool_union)] + + # At this point, tool_union must be a BaseToolset + return await tool_union.get_tools_with_prefix(ctx) class LlmAgent(BaseAgent): - """LLM-based Agent.""" + """LLM-based Agent.""" - model: Union[str, BaseLlm] = '' - """The model to use for the agent. + model: Union[str, BaseLlm] = "" + """The model to use for the agent. When not set, the agent will inherit the model from its ancestor. """ - config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig - """The config type for this agent.""" + config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig + """The config type for this agent.""" - instruction: Union[str, InstructionProvider] = '' - """Dynamic instructions for the LLM model, guiding the agent's behavior. + instruction: Union[str, InstructionProvider] = "" + """Dynamic instructions for the LLM model, guiding the agent's behavior. These instructions can contain placeholders like {variable_name} that will be resolved at runtime using session state and context. @@ -206,8 +205,8 @@ class LlmAgent(BaseAgent): comes first in the prompt, followed by dynamic content (instruction). """ - global_instruction: Union[str, InstructionProvider] = '' - """Instructions for all the agents in the entire agent tree. + global_instruction: Union[str, InstructionProvider] = "" + """Instructions for all the agents in the entire agent tree. DEPRECATED: This field is deprecated and will be removed in a future version. Use GlobalInstructionPlugin instead, which provides the same functionality @@ -219,8 +218,8 @@ class LlmAgent(BaseAgent): or personality. """ - static_instruction: Optional[types.ContentUnion] = None - """Static instruction content sent literally as system instruction at the beginning. + static_instruction: Optional[types.ContentUnion] = None + """Static instruction content sent literally as system instruction at the beginning. This field is for content that never changes and doesn't contain placeholders. It's sent directly to the model without any processing or variable substitution. @@ -270,11 +269,11 @@ class LlmAgent(BaseAgent): ``` """ - tools: list[ToolUnion] = Field(default_factory=list) - """Tools available to this agent.""" + tools: list[ToolUnion] = Field(default_factory=list) + """Tools available to this agent.""" - generate_content_config: Optional[types.GenerateContentConfig] = None - """The additional content generation configurations. + generate_content_config: Optional[types.GenerateContentConfig] = None + """The additional content generation configurations. NOTE: not all fields are usable, e.g. tools must be configured via `tools`, thinking_config must be configured via `planner` in LlmAgent. @@ -283,21 +282,21 @@ class LlmAgent(BaseAgent): settings, etc. """ - # LLM-based agent transfer configs - Start - disallow_transfer_to_parent: bool = False - """Disallows LLM-controlled transferring to the parent agent. + # LLM-based agent transfer configs - Start + disallow_transfer_to_parent: bool = False + """Disallows LLM-controlled transferring to the parent agent. NOTE: Setting this as True also prevents this agent from continuing to reply to the end-user, and will transfer control back to the parent agent in the next turn. This behavior prevents one-way transfer, in which end-user may be stuck with one agent that cannot transfer to other agents in the agent tree. """ - disallow_transfer_to_peers: bool = False - """Disallows LLM-controlled transferring to the peer agents.""" - # LLM-based agent transfer configs - End + disallow_transfer_to_peers: bool = False + """Disallows LLM-controlled transferring to the peer agents.""" + # LLM-based agent transfer configs - End - include_contents: Literal['default', 'none'] = 'default' - """Controls content inclusion in model requests. + include_contents: Literal["default", "none"] = "default" + """Controls content inclusion in model requests. Options: default: Model receives relevant conversation history @@ -305,36 +304,36 @@ class LlmAgent(BaseAgent): instruction and input """ - # Controlled input/output configurations - Start - input_schema: Optional[type[BaseModel]] = None - """The input schema when agent is used as a tool.""" - output_schema: Optional[type[BaseModel]] = None - """The output schema when agent replies. + # Controlled input/output configurations - Start + input_schema: Optional[type[BaseModel]] = None + """The input schema when agent is used as a tool.""" + output_schema: Optional[type[BaseModel]] = None + """The output schema when agent replies. NOTE: When this is set, agent can ONLY reply and CANNOT use any tools, such as function tools, RAGs, agent transfer, etc. """ - output_key: Optional[str] = None - """The key in session state to store the output of the agent. + output_key: Optional[str] = None + """The key in session state to store the output of the agent. Typically use cases: - Extracts agent reply for later use, such as in tools, callbacks, etc. - Connects agents to coordinate with each other. """ - # Controlled input/output configurations - End + # Controlled input/output configurations - End - # Advance features - Start - planner: Optional[BasePlanner] = None - """Instructs the agent to make a plan and execute it step by step. + # Advance features - Start + planner: Optional[BasePlanner] = None + """Instructs the agent to make a plan and execute it step by step. NOTE: To use model's built-in thinking features, set the `thinking_config` field in `google.adk.planners.built_in_planner`. """ - code_executor: Optional[BaseCodeExecutor] = None - """Allow agent to execute code blocks from model responses using the provided + code_executor: Optional[BaseCodeExecutor] = None + """Allow agent to execute code blocks from model responses using the provided CodeExecutor. Check out available code executions in `google.adk.code_executor` package. @@ -342,11 +341,11 @@ class LlmAgent(BaseAgent): NOTE: To use model's built-in code executor, use the `BuiltInCodeExecutor`. """ - # Advance features - End + # Advance features - End - # Callbacks - Start - before_model_callback: Optional[BeforeModelCallback] = None - """Callback or list of callbacks to be called before calling the LLM. + # Callbacks - Start + before_model_callback: Optional[BeforeModelCallback] = None + """Callback or list of callbacks to be called before calling the LLM. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -360,8 +359,8 @@ class LlmAgent(BaseAgent): The content to return to the user. When present, the model call will be skipped and the provided content will be returned to user. """ - after_model_callback: Optional[AfterModelCallback] = None - """Callback or list of callbacks to be called after calling the LLM. + after_model_callback: Optional[AfterModelCallback] = None + """Callback or list of callbacks to be called after calling the LLM. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -374,8 +373,8 @@ class LlmAgent(BaseAgent): The content to return to the user. When present, the actual model response will be ignored and the provided content will be returned to user. """ - on_model_error_callback: Optional[OnModelErrorCallback] = None - """Callback or list of callbacks to be called when a model call encounters an error. + on_model_error_callback: Optional[OnModelErrorCallback] = None + """Callback or list of callbacks to be called when a model call encounters an error. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -389,8 +388,8 @@ class LlmAgent(BaseAgent): The content to return to the user. When present, the error will be ignored and the provided content will be returned to user. """ - before_tool_callback: Optional[BeforeToolCallback] = None - """Callback or list of callbacks to be called before calling the tool. + before_tool_callback: Optional[BeforeToolCallback] = None + """Callback or list of callbacks to be called before calling the tool. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -404,8 +403,8 @@ class LlmAgent(BaseAgent): The tool response. When present, the returned tool response will be used and the framework will skip calling the actual tool. """ - after_tool_callback: Optional[AfterToolCallback] = None - """Callback or list of callbacks to be called after calling the tool. + after_tool_callback: Optional[AfterToolCallback] = None + """Callback or list of callbacks to be called after calling the tool. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -419,8 +418,8 @@ class LlmAgent(BaseAgent): Returns: When present, the returned dict will be used as tool result. """ - on_tool_error_callback: Optional[OnToolErrorCallback] = None - """Callback or list of callbacks to be called when a tool call encounters an error. + on_tool_error_callback: Optional[OnToolErrorCallback] = None + """Callback or list of callbacks to be called when a tool call encounters an error. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -434,521 +433,515 @@ class LlmAgent(BaseAgent): Returns: When present, the returned dict will be used as tool result. """ - # Callbacks - End - - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - agent_state = self._load_agent_state(ctx, BaseAgentState) - - # If there is a sub-agent to resume, run it and then end the current - # agent. - if agent_state is not None and ( - agent_to_transfer := self._get_subagent_to_resume(ctx) - ): - async with Aclosing(agent_to_transfer.run_async(ctx)) as agen: - async for event in agen: - yield event - - ctx.set_agent_state(self.name, end_of_agent=True) - yield self._create_agent_state_event(ctx) - return - - should_pause = False - async with Aclosing(self._llm_flow.run_async(ctx)) as agen: - async for event in agen: - self.__maybe_save_output_to_state(event) - yield event - if ctx.should_pause_invocation(event): - # Do not pause immediately, wait until the long running tool call is - # executed. - should_pause = True - if should_pause: - return - - if ctx.is_resumable: - events = ctx._get_events(current_invocation=True, current_branch=True) - if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): - return - # Only yield an end state if the last event is no longer a long running - # tool call. - ctx.set_agent_state(self.name, end_of_agent=True) - yield self._create_agent_state_event(ctx) - - @override - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - async with Aclosing(self._llm_flow.run_live(ctx)) as agen: - async for event in agen: - self.__maybe_save_output_to_state(event) - yield event - if ctx.end_invocation: - return - - @property - def canonical_model(self) -> BaseLlm: - """The resolved self.model field as BaseLlm. - - This method is only for use by Agent Development Kit. - """ - if isinstance(self.model, BaseLlm): - return self.model - elif self.model: # model is non-empty str - return LLMRegistry.new_llm(self.model) - else: # find model from ancestors. - ancestor_agent = self.parent_agent - while ancestor_agent is not None: - if isinstance(ancestor_agent, LlmAgent): - return ancestor_agent.canonical_model - ancestor_agent = ancestor_agent.parent_agent - raise ValueError(f'No model found for {self.name}.') - - async def canonical_instruction( - self, ctx: ReadonlyContext - ) -> tuple[str, bool]: - """The resolved self.instruction field to construct instruction for this agent. - - This method is only for use by Agent Development Kit. - - Args: - ctx: The context to retrieve the session state. - - Returns: - A tuple of (instruction, bypass_state_injection). - instruction: The resolved self.instruction field. - bypass_state_injection: Whether the instruction is based on - InstructionProvider. - """ - if isinstance(self.instruction, str): - return self.instruction, False - else: - instruction = self.instruction(ctx) - if inspect.isawaitable(instruction): - instruction = await instruction - return instruction, True - - async def canonical_global_instruction( - self, ctx: ReadonlyContext - ) -> tuple[str, bool]: - """The resolved self.instruction field to construct global instruction. - - This method is only for use by Agent Development Kit. - - Args: - ctx: The context to retrieve the session state. - - Returns: - A tuple of (instruction, bypass_state_injection). - instruction: The resolved self.global_instruction field. - bypass_state_injection: Whether the instruction is based on - InstructionProvider. - """ - # Issue deprecation warning if global_instruction is being used - if self.global_instruction: - warnings.warn( - 'global_instruction field is deprecated and will be removed in a' - ' future version. Use GlobalInstructionPlugin instead for the same' - ' functionality at the App level. See migration guide for details.', - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(self.global_instruction, str): - return self.global_instruction, False - else: - global_instruction = self.global_instruction(ctx) - if inspect.isawaitable(global_instruction): - global_instruction = await global_instruction - return global_instruction, True - - async def canonical_tools( - self, ctx: ReadonlyContext = None - ) -> list[BaseTool]: - """The resolved self.tools field as a list of BaseTool based on the context. - - This method is only for use by Agent Development Kit. - """ - resolved_tools = [] - # We may need to wrap some built-in tools if there are other tools - # because the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - multiple_tools = len(self.tools) > 1 - for tool_union in self.tools: - resolved_tools.extend( - await _convert_tool_union_to_tools( - tool_union, ctx, self.model, multiple_tools - ) - ) - return resolved_tools - - @property - def canonical_before_model_callbacks( - self, - ) -> list[_SingleBeforeModelCallback]: - """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_model_callback: - return [] - if isinstance(self.before_model_callback, list): - return self.before_model_callback - return [self.before_model_callback] - - @property - def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]: - """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_model_callback: - return [] - if isinstance(self.after_model_callback, list): - return self.after_model_callback - return [self.after_model_callback] - - @property - def canonical_on_model_error_callbacks( - self, - ) -> list[_SingleOnModelErrorCallback]: - """The resolved self.on_model_error_callback field as a list of _SingleOnModelErrorCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.on_model_error_callback: - return [] - if isinstance(self.on_model_error_callback, list): - return self.on_model_error_callback - return [self.on_model_error_callback] - - @property - def canonical_before_tool_callbacks( - self, - ) -> list[BeforeToolCallback]: - """The resolved self.before_tool_callback field as a list of BeforeToolCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_tool_callback: - return [] - if isinstance(self.before_tool_callback, list): - return self.before_tool_callback - return [self.before_tool_callback] - - @property - def canonical_after_tool_callbacks( - self, - ) -> list[AfterToolCallback]: - """The resolved self.after_tool_callback field as a list of AfterToolCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_tool_callback: - return [] - if isinstance(self.after_tool_callback, list): - return self.after_tool_callback - return [self.after_tool_callback] - - @property - def canonical_on_tool_error_callbacks( - self, - ) -> list[OnToolErrorCallback]: - """The resolved self.on_tool_error_callback field as a list of OnToolErrorCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.on_tool_error_callback: - return [] - if isinstance(self.on_tool_error_callback, list): - return self.on_tool_error_callback - return [self.on_tool_error_callback] - - @property - def _llm_flow(self) -> BaseLlmFlow: - if ( - self.disallow_transfer_to_parent - and self.disallow_transfer_to_peers - and not self.sub_agents - ): - return SingleFlow() - else: - return AutoFlow() - - def _get_subagent_to_resume( - self, ctx: InvocationContext - ) -> Optional[BaseAgent]: - """Returns the sub-agent in the llm tree to resume if it exists. - - There are 2 cases where we need to transfer to and resume a sub-agent: - 1. The last event is a transfer to agent response from the current agent. - In this case, we need to return the agent specified in the response. - - 2. The last event's author isn't the current agent, or the user is - responding to another agent's tool call. - In this case, we need to return the LAST agent being transferred to - from the current agent. - """ - events = ctx._get_events(current_invocation=True, current_branch=True) - if not events: - return None - - last_event = events[-1] - if last_event.author == self.name: - # Last event is from current agent. Return transfer_to_agent in the event - # if it exists, or None. - return self.__get_transfer_to_agent_or_none(last_event, self.name) - - # Last event is from user or another agent. - if last_event.author == 'user': - function_call_event = ctx._find_matching_function_call(last_event) - if not function_call_event: - raise ValueError( - 'No agent to transfer to for resuming agent from function response' - f' {self.name}' - ) - if function_call_event.author == self.name: - # User is responding to a tool call from the current agent. - # Current agent should continue, so no sub-agent to resume. + # Callbacks - End + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + agent_state = self._load_agent_state(ctx, BaseAgentState) + + # If there is a sub-agent to resume, run it and then end the current + # agent. + if agent_state is not None and ( + agent_to_transfer := self._get_subagent_to_resume(ctx) + ): + async with Aclosing(agent_to_transfer.run_async(ctx)) as agen: + async for event in agen: + yield event + + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + return + + should_pause = False + async with Aclosing(self._llm_flow.run_async(ctx)) as agen: + async for event in agen: + self.__maybe_save_output_to_state(event) + yield event + if ctx.should_pause_invocation(event): + # Do not pause immediately, wait until the long running tool call is + # executed. + should_pause = True + if should_pause: + return + + if ctx.is_resumable: + events = ctx._get_events(current_invocation=True, current_branch=True) + if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): + return + # Only yield an end state if the last event is no longer a long running + # tool call. + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + async with Aclosing(self._llm_flow.run_live(ctx)) as agen: + async for event in agen: + self.__maybe_save_output_to_state(event) + yield event + if ctx.end_invocation: + return + + @property + def canonical_model(self) -> BaseLlm: + """The resolved self.model field as BaseLlm. + + This method is only for use by Agent Development Kit. + """ + if isinstance(self.model, BaseLlm): + return self.model + elif self.model: # model is non-empty str + return LLMRegistry.new_llm(self.model) + else: # find model from ancestors. + ancestor_agent = self.parent_agent + while ancestor_agent is not None: + if isinstance(ancestor_agent, LlmAgent): + return ancestor_agent.canonical_model + ancestor_agent = ancestor_agent.parent_agent + raise ValueError(f"No model found for {self.name}.") + + async def canonical_instruction(self, ctx: ReadonlyContext) -> tuple[str, bool]: + """The resolved self.instruction field to construct instruction for this agent. + + This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. + """ + if isinstance(self.instruction, str): + return self.instruction, False + else: + instruction = self.instruction(ctx) + if inspect.isawaitable(instruction): + instruction = await instruction + return instruction, True + + async def canonical_global_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: + """The resolved self.instruction field to construct global instruction. + + This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.global_instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. + """ + # Issue deprecation warning if global_instruction is being used + if self.global_instruction: + warnings.warn( + "global_instruction field is deprecated and will be removed in a" + " future version. Use GlobalInstructionPlugin instead for the same" + " functionality at the App level. See migration guide for details.", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(self.global_instruction, str): + return self.global_instruction, False + else: + global_instruction = self.global_instruction(ctx) + if inspect.isawaitable(global_instruction): + global_instruction = await global_instruction + return global_instruction, True + + async def canonical_tools(self, ctx: ReadonlyContext = None) -> list[BaseTool]: + """The resolved self.tools field as a list of BaseTool based on the context. + + This method is only for use by Agent Development Kit. + """ + resolved_tools = [] + # We may need to wrap some built-in tools if there are other tools + # because the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + multiple_tools = len(self.tools) > 1 + for tool_union in self.tools: + resolved_tools.extend( + await _convert_tool_union_to_tools( + tool_union, ctx, self.model, multiple_tools + ) + ) + return resolved_tools + + @property + def canonical_before_model_callbacks( + self, + ) -> list[_SingleBeforeModelCallback]: + """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_model_callback: + return [] + if isinstance(self.before_model_callback, list): + return self.before_model_callback + return [self.before_model_callback] + + @property + def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]: + """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_model_callback: + return [] + if isinstance(self.after_model_callback, list): + return self.after_model_callback + return [self.after_model_callback] + + @property + def canonical_on_model_error_callbacks( + self, + ) -> list[_SingleOnModelErrorCallback]: + """The resolved self.on_model_error_callback field as a list of _SingleOnModelErrorCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.on_model_error_callback: + return [] + if isinstance(self.on_model_error_callback, list): + return self.on_model_error_callback + return [self.on_model_error_callback] + + @property + def canonical_before_tool_callbacks( + self, + ) -> list[BeforeToolCallback]: + """The resolved self.before_tool_callback field as a list of BeforeToolCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_tool_callback: + return [] + if isinstance(self.before_tool_callback, list): + return self.before_tool_callback + return [self.before_tool_callback] + + @property + def canonical_after_tool_callbacks( + self, + ) -> list[AfterToolCallback]: + """The resolved self.after_tool_callback field as a list of AfterToolCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_tool_callback: + return [] + if isinstance(self.after_tool_callback, list): + return self.after_tool_callback + return [self.after_tool_callback] + + @property + def canonical_on_tool_error_callbacks( + self, + ) -> list[OnToolErrorCallback]: + """The resolved self.on_tool_error_callback field as a list of OnToolErrorCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.on_tool_error_callback: + return [] + if isinstance(self.on_tool_error_callback, list): + return self.on_tool_error_callback + return [self.on_tool_error_callback] + + @property + def _llm_flow(self) -> BaseLlmFlow: + if ( + self.disallow_transfer_to_parent + and self.disallow_transfer_to_peers + and not self.sub_agents + ): + return SingleFlow() + else: + return AutoFlow() + + def _get_subagent_to_resume(self, ctx: InvocationContext) -> Optional[BaseAgent]: + """Returns the sub-agent in the llm tree to resume if it exists. + + There are 2 cases where we need to transfer to and resume a sub-agent: + 1. The last event is a transfer to agent response from the current agent. + In this case, we need to return the agent specified in the response. + + 2. The last event's author isn't the current agent, or the user is + responding to another agent's tool call. + In this case, we need to return the LAST agent being transferred to + from the current agent. + """ + events = ctx._get_events(current_invocation=True, current_branch=True) + if not events: + return None + + last_event = events[-1] + if last_event.author == self.name: + # Last event is from current agent. Return transfer_to_agent in the event + # if it exists, or None. + return self.__get_transfer_to_agent_or_none(last_event, self.name) + + # Last event is from user or another agent. + if last_event.author == "user": + function_call_event = ctx._find_matching_function_call(last_event) + if not function_call_event: + raise ValueError( + "No agent to transfer to for resuming agent from function response" + f" {self.name}" + ) + if function_call_event.author == self.name: + # User is responding to a tool call from the current agent. + # Current agent should continue, so no sub-agent to resume. + return None + + # Last event is from another agent, or from user for another agent's tool + # call. We need to find the last agent we transferred to. + for event in reversed(events): + if agent := self.__get_transfer_to_agent_or_none(event, self.name): + return agent + return None - # Last event is from another agent, or from user for another agent's tool - # call. We need to find the last agent we transferred to. - for event in reversed(events): - if agent := self.__get_transfer_to_agent_or_none(event, self.name): - return agent - - return None - - def __get_agent_to_run(self, agent_name: str) -> BaseAgent: - """Find the agent to run under the root agent by name.""" - agent_to_run = self.root_agent.find_agent(agent_name) - if not agent_to_run: - available = self._get_available_agent_names() - error_msg = ( - f"Agent '{agent_name}' not found.\n" - f"Available agents: {', '.join(available)}\n\n" - 'Possible causes:\n' - ' 1. Agent not registered before being referenced\n' - ' 2. Agent name mismatch (typo or case sensitivity)\n' - ' 3. Timing issue (agent referenced before creation)\n\n' - 'Suggested fixes:\n' - ' - Verify agent is registered with root agent\n' - ' - Check agent name spelling and case\n' - ' - Ensure agents are created before being referenced' - ) - raise ValueError(error_msg) - return agent_to_run - - def _get_available_agent_names(self) -> list[str]: - """Helper to get all agent names in the tree for error reporting. - - This is a private helper method used only for error message formatting. - Traverses the agent tree starting from root_agent and collects all - agent names for display in error messages. - - Returns: - List of all agent names in the agent tree. - """ - agents = [] - - def collect_agents(agent): - agents.append(agent.name) - if hasattr(agent, 'sub_agents') and agent.sub_agents: - for sub_agent in agent.sub_agents: - collect_agents(sub_agent) - - collect_agents(self.root_agent) - return agents - - def __get_transfer_to_agent_or_none( - self, event: Event, from_agent: str - ) -> Optional[BaseAgent]: - """Returns the agent to run if the event is a transfer to agent response.""" - function_responses = event.get_function_responses() - if not function_responses: - return None - for function_response in function_responses: - if ( - function_response.name == 'transfer_to_agent' - and event.author == from_agent - and event.actions.transfer_to_agent != from_agent - ): - return self.__get_agent_to_run(event.actions.transfer_to_agent) - return None - - def __maybe_save_output_to_state(self, event: Event): - """Saves the model output to state if needed.""" - # skip if the event was authored by some other agent (e.g. current agent - # transferred to another agent) - if event.author != self.name: - logger.debug( - 'Skipping output save for agent %s: event authored by %s', - self.name, - event.author, - ) - return - if ( - self.output_key - and event.is_final_response() - and event.content - and event.content.parts - ): - - result = ''.join( - part.text - for part in event.content.parts - if part.text and not part.thought - ) - if self.output_schema: - # If the result from the final chunk is just whitespace or empty, - # it means this is an empty final chunk of a stream. - # Do not attempt to parse it as JSON. - if not result.strip(): - return - result = self.output_schema.model_validate_json(result).model_dump( - exclude_none=True - ) - event.actions.state_delta[self.output_key] = result - - @model_validator(mode='after') - def __model_validator_after(self) -> LlmAgent: - return self - - @field_validator('generate_content_config', mode='after') - @classmethod - def validate_generate_content_config( - cls, generate_content_config: Optional[types.GenerateContentConfig] - ) -> types.GenerateContentConfig: - if not generate_content_config: - return types.GenerateContentConfig() - if generate_content_config.thinking_config: - raise ValueError('Thinking config should be set via LlmAgent.planner.') - if generate_content_config.tools: - raise ValueError('All tools must be set via LlmAgent.tools.') - if generate_content_config.system_instruction: - raise ValueError( - 'System instruction must be set via LlmAgent.instruction.' - ) - if generate_content_config.response_schema: - raise ValueError( - 'Response schema must be set via LlmAgent.output_schema.' - ) - return generate_content_config - - @classmethod - @experimental - def _resolve_tools( - cls, tool_configs: list[ToolConfig], config_abs_path: str - ) -> list[Any]: - """Resolve tools from configuration. - - Args: - tool_configs: List of tool configurations (ToolConfig objects). - config_abs_path: The absolute path to the agent config file. - - Returns: - List of resolved tool objects. - """ - - resolved_tools = [] - for tool_config in tool_configs: - if '.' not in tool_config.name: - # ADK built-in tools - module = importlib.import_module('google.adk.tools') - obj = getattr(module, tool_config.name) - else: - # User-defined tools - module_path, obj_name = tool_config.name.rsplit('.', 1) - module = importlib.import_module(module_path) - obj = getattr(module, obj_name) - - if isinstance(obj, BaseTool) or isinstance(obj, BaseToolset): - logger.debug( - 'Tool %s is an instance of BaseTool/BaseToolset.', tool_config.name - ) - resolved_tools.append(obj) - elif inspect.isclass(obj) and ( - issubclass(obj, BaseTool) or issubclass(obj, BaseToolset) - ): - logger.debug( - 'Tool %s is a sub-class of BaseTool/BaseToolset.', tool_config.name - ) - resolved_tools.append( - obj.from_config(tool_config.args, config_abs_path) - ) - elif callable(obj): - if tool_config.args: - logger.debug( - 'Tool %s is a user-defined tool-generating function.', - tool_config.name, - ) - resolved_tools.append(obj(tool_config.args)) - else: - logger.debug( - 'Tool %s is a user-defined function tool.', tool_config.name - ) - resolved_tools.append(obj) - else: - raise ValueError(f'Invalid tool YAML config: {tool_config}.') - - return resolved_tools - - @override - @classmethod - @experimental - def _parse_config( - cls: Type[LlmAgent], - config: LlmAgentConfig, - config_abs_path: str, - kwargs: Dict[str, Any], - ) -> Dict[str, Any]: - from .config_agent_utils import resolve_callbacks - from .config_agent_utils import resolve_code_reference - - if config.model_code: - kwargs['model'] = resolve_code_reference(config.model_code) - elif config.model: - kwargs['model'] = config.model - if config.instruction: - kwargs['instruction'] = config.instruction - if config.static_instruction: - kwargs['static_instruction'] = config.static_instruction - if config.disallow_transfer_to_parent: - kwargs['disallow_transfer_to_parent'] = config.disallow_transfer_to_parent - if config.disallow_transfer_to_peers: - kwargs['disallow_transfer_to_peers'] = config.disallow_transfer_to_peers - if config.include_contents != 'default': - kwargs['include_contents'] = config.include_contents - if config.input_schema: - kwargs['input_schema'] = resolve_code_reference(config.input_schema) - if config.output_schema: - kwargs['output_schema'] = resolve_code_reference(config.output_schema) - if config.output_key: - kwargs['output_key'] = config.output_key - if config.tools: - kwargs['tools'] = cls._resolve_tools(config.tools, config_abs_path) - if config.before_model_callbacks: - kwargs['before_model_callback'] = resolve_callbacks( - config.before_model_callbacks - ) - if config.after_model_callbacks: - kwargs['after_model_callback'] = resolve_callbacks( - config.after_model_callbacks - ) - if config.before_tool_callbacks: - kwargs['before_tool_callback'] = resolve_callbacks( - config.before_tool_callbacks - ) - if config.after_tool_callbacks: - kwargs['after_tool_callback'] = resolve_callbacks( - config.after_tool_callbacks - ) - if config.generate_content_config: - kwargs['generate_content_config'] = config.generate_content_config - - return kwargs + def __get_agent_to_run(self, agent_name: str) -> BaseAgent: + """Find the agent to run under the root agent by name.""" + agent_to_run = self.root_agent.find_agent(agent_name) + if not agent_to_run: + available = self._get_available_agent_names() + error_msg = ( + f"Agent '{agent_name}' not found.\n" + f"Available agents: {', '.join(available)}\n\n" + "Possible causes:\n" + " 1. Agent not registered before being referenced\n" + " 2. Agent name mismatch (typo or case sensitivity)\n" + " 3. Timing issue (agent referenced before creation)\n\n" + "Suggested fixes:\n" + " - Verify agent is registered with root agent\n" + " - Check agent name spelling and case\n" + " - Ensure agents are created before being referenced" + ) + raise ValueError(error_msg) + return agent_to_run + + def _get_available_agent_names(self) -> list[str]: + """Helper to get all agent names in the tree for error reporting. + + This is a private helper method used only for error message formatting. + Traverses the agent tree starting from root_agent and collects all + agent names for display in error messages. + + Returns: + List of all agent names in the agent tree. + """ + agents = [] + + def collect_agents(agent): + agents.append(agent.name) + if hasattr(agent, "sub_agents") and agent.sub_agents: + for sub_agent in agent.sub_agents: + collect_agents(sub_agent) + + collect_agents(self.root_agent) + return agents + + def __get_transfer_to_agent_or_none( + self, event: Event, from_agent: str + ) -> Optional[BaseAgent]: + """Returns the agent to run if the event is a transfer to agent response.""" + function_responses = event.get_function_responses() + if not function_responses: + return None + for function_response in function_responses: + if ( + function_response.name == "transfer_to_agent" + and event.author == from_agent + and event.actions.transfer_to_agent != from_agent + ): + return self.__get_agent_to_run(event.actions.transfer_to_agent) + return None + + def __maybe_save_output_to_state(self, event: Event): + """Saves the model output to state if needed.""" + # skip if the event was authored by some other agent (e.g. current agent + # transferred to another agent) + if event.author != self.name: + logger.debug( + "Skipping output save for agent %s: event authored by %s", + self.name, + event.author, + ) + return + if ( + self.output_key + and event.is_final_response() + and event.content + and event.content.parts + ): + + result = "".join( + part.text + for part in event.content.parts + if part.text and not part.thought + ) + if self.output_schema: + # If the result from the final chunk is just whitespace or empty, + # it means this is an empty final chunk of a stream. + # Do not attempt to parse it as JSON. + if not result.strip(): + return + result = self.output_schema.model_validate_json(result).model_dump( + exclude_none=True + ) + event.actions.state_delta[self.output_key] = result + + @model_validator(mode="after") + def __model_validator_after(self) -> LlmAgent: + root_agent = getattr(self, "root_agent", None) or self + disable_telemetry: bool = not is_telemetry_enabled(root_agent) + if hasattr(self.model, "disable_telemetry"): + self.model.disable_telemetry = disable_telemetry + return self + + @field_validator("generate_content_config", mode="after") + @classmethod + def validate_generate_content_config( + cls, generate_content_config: Optional[types.GenerateContentConfig] + ) -> types.GenerateContentConfig: + if not generate_content_config: + return types.GenerateContentConfig() + if generate_content_config.thinking_config: + raise ValueError("Thinking config should be set via LlmAgent.planner.") + if generate_content_config.tools: + raise ValueError("All tools must be set via LlmAgent.tools.") + if generate_content_config.system_instruction: + raise ValueError("System instruction must be set via LlmAgent.instruction.") + if generate_content_config.response_schema: + raise ValueError("Response schema must be set via LlmAgent.output_schema.") + return generate_content_config + + @classmethod + @experimental + def _resolve_tools( + cls, tool_configs: list[ToolConfig], config_abs_path: str + ) -> list[Any]: + """Resolve tools from configuration. + + Args: + tool_configs: List of tool configurations (ToolConfig objects). + config_abs_path: The absolute path to the agent config file. + + Returns: + List of resolved tool objects. + """ + + resolved_tools = [] + for tool_config in tool_configs: + if "." not in tool_config.name: + # ADK built-in tools + module = importlib.import_module("google.adk.tools") + obj = getattr(module, tool_config.name) + else: + # User-defined tools + module_path, obj_name = tool_config.name.rsplit(".", 1) + module = importlib.import_module(module_path) + obj = getattr(module, obj_name) + + if isinstance(obj, BaseTool) or isinstance(obj, BaseToolset): + logger.debug( + "Tool %s is an instance of BaseTool/BaseToolset.", tool_config.name + ) + resolved_tools.append(obj) + elif inspect.isclass(obj) and ( + issubclass(obj, BaseTool) or issubclass(obj, BaseToolset) + ): + logger.debug( + "Tool %s is a sub-class of BaseTool/BaseToolset.", tool_config.name + ) + resolved_tools.append( + obj.from_config(tool_config.args, config_abs_path) + ) + elif callable(obj): + if tool_config.args: + logger.debug( + "Tool %s is a user-defined tool-generating function.", + tool_config.name, + ) + resolved_tools.append(obj(tool_config.args)) + else: + logger.debug( + "Tool %s is a user-defined function tool.", tool_config.name + ) + resolved_tools.append(obj) + else: + raise ValueError(f"Invalid tool YAML config: {tool_config}.") + + return resolved_tools + + @override + @classmethod + @experimental + def _parse_config( + cls: Type[LlmAgent], + config: LlmAgentConfig, + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + from .config_agent_utils import resolve_callbacks + from .config_agent_utils import resolve_code_reference + + if config.model_code: + kwargs["model"] = resolve_code_reference(config.model_code) + elif config.model: + kwargs["model"] = config.model + if config.instruction: + kwargs["instruction"] = config.instruction + if config.static_instruction: + kwargs["static_instruction"] = config.static_instruction + if config.disallow_transfer_to_parent: + kwargs["disallow_transfer_to_parent"] = config.disallow_transfer_to_parent + if config.disallow_transfer_to_peers: + kwargs["disallow_transfer_to_peers"] = config.disallow_transfer_to_peers + if config.include_contents != "default": + kwargs["include_contents"] = config.include_contents + if config.input_schema: + kwargs["input_schema"] = resolve_code_reference(config.input_schema) + if config.output_schema: + kwargs["output_schema"] = resolve_code_reference(config.output_schema) + if config.output_key: + kwargs["output_key"] = config.output_key + if config.tools: + kwargs["tools"] = cls._resolve_tools(config.tools, config_abs_path) + if config.before_model_callbacks: + kwargs["before_model_callback"] = resolve_callbacks( + config.before_model_callbacks + ) + if config.after_model_callbacks: + kwargs["after_model_callback"] = resolve_callbacks( + config.after_model_callbacks + ) + if config.before_tool_callbacks: + kwargs["before_tool_callback"] = resolve_callbacks( + config.before_tool_callbacks + ) + if config.after_tool_callbacks: + kwargs["after_tool_callback"] = resolve_callbacks( + config.after_tool_callbacks + ) + if config.generate_content_config: + kwargs["generate_content_config"] = config.generate_content_config + + return kwargs Agent: TypeAlias = LlmAgent diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 824cd26be1..90f5e1423d 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -50,17 +50,18 @@ from ...tools.google_search_tool import google_search from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing +from ...utils.telemetry_utils import is_telemetry_enabled from .audio_cache_manager import AudioCacheManager if TYPE_CHECKING: - from ...agents.llm_agent import LlmAgent - from ...models.base_llm import BaseLlm - from ._base_llm_processor import BaseLlmRequestProcessor - from ._base_llm_processor import BaseLlmResponseProcessor + from ...agents.llm_agent import LlmAgent + from ...models.base_llm import BaseLlm + from ._base_llm_processor import BaseLlmRequestProcessor + from ._base_llm_processor import BaseLlmResponseProcessor -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) -_ADK_AGENT_NAME_LABEL_KEY = 'adk_agent_name' +_ADK_AGENT_NAME_LABEL_KEY = "adk_agent_name" # Timing configuration DEFAULT_REQUEST_QUEUE_TIMEOUT = 0.25 @@ -72,987 +73,1013 @@ class BaseLlmFlow(ABC): - """A basic flow that calls the LLM in a loop until a final response is generated. - - This flow ends when it transfer to another agent. - """ - - def __init__(self): - self.request_processors: list[BaseLlmRequestProcessor] = [] - self.response_processors: list[BaseLlmResponseProcessor] = [] - - # Initialize configuration and managers - self.audio_cache_manager = AudioCacheManager() - - async def run_live( - self, - invocation_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Runs the flow using live api.""" - llm_request = LlmRequest() - event_id = Event.new_id() - - # Preprocess before calling the LLM. - async with Aclosing( - self._preprocess_async(invocation_context, llm_request) - ) as agen: - async for event in agen: - yield event - if invocation_context.end_invocation: - return - - llm = self.__get_llm(invocation_context) - logger.debug( - 'Establishing live connection for agent: %s with llm request: %s', - invocation_context.agent.name, - llm_request, - ) - - attempt = 1 - while True: - try: - # On subsequent attempts, use the saved token to reconnect - if invocation_context.live_session_resumption_handle: - logger.info('Attempting to reconnect (Attempt %s)...', attempt) - attempt += 1 - if not llm_request.live_connect_config: - llm_request.live_connect_config = types.LiveConnectConfig() - llm_request.live_connect_config.session_resumption.handle = ( - invocation_context.live_session_resumption_handle - ) - llm_request.live_connect_config.session_resumption.transparent = True - - logger.info( - 'Establishing live connection for agent: %s', + """A basic flow that calls the LLM in a loop until a final response is generated. + + This flow ends when it transfer to another agent. + """ + + def __init__(self): + self.request_processors: list[BaseLlmRequestProcessor] = [] + self.response_processors: list[BaseLlmResponseProcessor] = [] + + # Initialize configuration and managers + self.audio_cache_manager = AudioCacheManager() + + async def run_live( + self, + invocation_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Runs the flow using live api.""" + llm_request = LlmRequest() + event_id = Event.new_id() + + # Preprocess before calling the LLM. + async with Aclosing( + self._preprocess_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event + if invocation_context.end_invocation: + return + + llm = self.__get_llm(invocation_context) + logger.debug( + "Establishing live connection for agent: %s with llm request: %s", invocation_context.agent.name, + llm_request, ) - async with llm.connect(llm_request) as llm_connection: - if llm_request.contents: - # Sends the conversation history to the model. - with tracer.start_as_current_span('send_data'): - # Combine regular contents with audio/transcription from session - logger.debug('Sending history to model: %s', llm_request.contents) - await llm_connection.send_history(llm_request.contents) - trace_send_data( - invocation_context, event_id, llm_request.contents - ) - - send_task = asyncio.create_task( - self._send_to_model(llm_connection, invocation_context) - ) - - try: - async with Aclosing( - self._receive_from_model( - llm_connection, - event_id, - invocation_context, - llm_request, + + attempt = 1 + while True: + try: + # On subsequent attempts, use the saved token to reconnect + if invocation_context.live_session_resumption_handle: + logger.info("Attempting to reconnect (Attempt %s)...", attempt) + attempt += 1 + if not llm_request.live_connect_config: + llm_request.live_connect_config = types.LiveConnectConfig() + llm_request.live_connect_config.session_resumption.handle = ( + invocation_context.live_session_resumption_handle + ) + llm_request.live_connect_config.session_resumption.transparent = ( + True + ) + + logger.info( + "Establishing live connection for agent: %s", + invocation_context.agent.name, ) - ) as agen: - async for event in agen: - # Empty event means the queue is closed. - if not event: - break - logger.debug('Receive new event: %s', event) - yield event - # send back the function response to models - if event.get_function_responses(): - logger.debug( - 'Sending back last function response event: %s', event - ) - invocation_context.live_request_queue.send_content( - event.content - ) - # We handle agent transfer here in `run_live` rather than - # in `_postprocess_live` to prevent duplication of function - # response processing. If agent transfer were handled in - # `_postprocess_live`, events yielded from child agent's - # `run_live` would bubble up to parent agent's `run_live`, - # causing `event.get_function_responses()` to be true in both - # child and parent, and `send_content()` to be called twice for - # the same function response. By handling agent transfer here, - # we ensure that only child agent processes its own function - # responses after the transfer. - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == 'transfer_to_agent' - ): - await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - logger.debug('Closing live connection') - await llm_connection.close() - logger.debug('Live connection closed.') - # transfer to the sub agent. - transfer_to_agent = event.actions.transfer_to_agent - if transfer_to_agent: - logger.debug('Transferring to agent: %s', transfer_to_agent) - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent + async with llm.connect(llm_request) as llm_connection: + if llm_request.contents: + # Sends the conversation history to the model. + if is_telemetry_enabled(invocation_context.agent): + with tracer.start_as_current_span("send_data"): + # Combine regular contents with audio/transcription from session + logger.debug( + "Sending history to model: %s", llm_request.contents + ) + await llm_connection.send_history(llm_request.contents) + trace_send_data( + invocation_context, event_id, llm_request.contents + ) + else: + logger.debug( + "Sending history to model: %s", llm_request.contents + ) + await llm_connection.send_history(llm_request.contents) + + send_task = asyncio.create_task( + self._send_to_model(llm_connection, invocation_context) ) - async with Aclosing( - agent_to_run.run_live(invocation_context) - ) as agen: - async for item in agen: - yield item - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == 'task_completed' - ): - # this is used for sequential agent to signal the end of the agent. - await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - return - finally: - # Clean up - if not send_task.done(): - send_task.cancel() + + try: + async with Aclosing( + self._receive_from_model( + llm_connection, + event_id, + invocation_context, + llm_request, + ) + ) as agen: + async for event in agen: + # Empty event means the queue is closed. + if not event: + break + logger.debug("Receive new event: %s", event) + yield event + # send back the function response to models + if event.get_function_responses(): + logger.debug( + "Sending back last function response event: %s", + event, + ) + invocation_context.live_request_queue.send_content( + event.content + ) + # We handle agent transfer here in `run_live` rather than + # in `_postprocess_live` to prevent duplication of function + # response processing. If agent transfer were handled in + # `_postprocess_live`, events yielded from child agent's + # `run_live` would bubble up to parent agent's `run_live`, + # causing `event.get_function_responses()` to be true in both + # child and parent, and `send_content()` to be called twice for + # the same function response. By handling agent transfer here, + # we ensure that only child agent processes its own function + # responses after the transfer. + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == "transfer_to_agent" + ): + await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) + # cancel the tasks that belongs to the closed connection. + send_task.cancel() + logger.debug("Closing live connection") + await llm_connection.close() + logger.debug("Live connection closed.") + # transfer to the sub agent. + transfer_to_agent = event.actions.transfer_to_agent + if transfer_to_agent: + logger.debug( + "Transferring to agent: %s", + transfer_to_agent, + ) + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing( + agent_to_run.run_live(invocation_context) + ) as agen: + async for item in agen: + yield item + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == "task_completed" + ): + # this is used for sequential agent to signal the end of the agent. + await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY) + # cancel the tasks that belongs to the closed connection. + send_task.cancel() + return + finally: + # Clean up + if not send_task.done(): + send_task.cancel() + try: + await send_task + except asyncio.CancelledError: + pass + except (ConnectionClosed, ConnectionClosedOK) as e: + # when the session timeout, it will just close and not throw exception. + # so this is for bad cases + logger.error("Connection closed: %s.", e) + raise + except Exception as e: + logger.error( + "An unexpected error occurred in live flow: %s", e, exc_info=True + ) + raise + + async def _send_to_model( + self, + llm_connection: BaseLlmConnection, + invocation_context: InvocationContext, + ): + """Sends data to model.""" + while True: + live_request_queue = invocation_context.live_request_queue try: - await send_task - except asyncio.CancelledError: - pass - except (ConnectionClosed, ConnectionClosedOK) as e: - # when the session timeout, it will just close and not throw exception. - # so this is for bad cases - logger.error('Connection closed: %s.', e) - raise - except Exception as e: - logger.error( - 'An unexpected error occurred in live flow: %s', e, exc_info=True - ) - raise - - async def _send_to_model( - self, - llm_connection: BaseLlmConnection, - invocation_context: InvocationContext, - ): - """Sends data to model.""" - while True: - live_request_queue = invocation_context.live_request_queue - try: - # Streamlit's execution model doesn't preemptively yield to the event - # loop. Therefore, we must explicitly introduce timeouts to allow the - # event loop to process events. - # TODO: revert back(remove timeout) once we move off streamlit. - live_request = await asyncio.wait_for( - live_request_queue.get(), timeout=DEFAULT_REQUEST_QUEUE_TIMEOUT - ) - # duplicate the live_request to all the active streams - logger.debug( - 'Sending live request %s to active streams: %s', - live_request, - invocation_context.active_streaming_tools, - ) - if invocation_context.active_streaming_tools: - for active_streaming_tool in ( - invocation_context.active_streaming_tools - ).values(): - if active_streaming_tool.stream: - active_streaming_tool.stream.send(live_request) - await asyncio.sleep(0) - except asyncio.TimeoutError: - continue - if live_request.close: - await llm_connection.close() - return - - if live_request.activity_start: - await llm_connection.send_realtime(types.ActivityStart()) - elif live_request.activity_end: - await llm_connection.send_realtime(types.ActivityEnd()) - elif live_request.blob: - # Cache input audio chunks before flushing - self.audio_cache_manager.cache_audio( - invocation_context, live_request.blob, cache_type='input' + # Streamlit's execution model doesn't preemptively yield to the event + # loop. Therefore, we must explicitly introduce timeouts to allow the + # event loop to process events. + # TODO: revert back(remove timeout) once we move off streamlit. + live_request = await asyncio.wait_for( + live_request_queue.get(), timeout=DEFAULT_REQUEST_QUEUE_TIMEOUT + ) + # duplicate the live_request to all the active streams + logger.debug( + "Sending live request %s to active streams: %s", + live_request, + invocation_context.active_streaming_tools, + ) + if invocation_context.active_streaming_tools: + for active_streaming_tool in ( + invocation_context.active_streaming_tools + ).values(): + if active_streaming_tool.stream: + active_streaming_tool.stream.send(live_request) + await asyncio.sleep(0) + except asyncio.TimeoutError: + continue + if live_request.close: + await llm_connection.close() + return + + if live_request.activity_start: + await llm_connection.send_realtime(types.ActivityStart()) + elif live_request.activity_end: + await llm_connection.send_realtime(types.ActivityEnd()) + elif live_request.blob: + # Cache input audio chunks before flushing + self.audio_cache_manager.cache_audio( + invocation_context, live_request.blob, cache_type="input" + ) + + await llm_connection.send_realtime(live_request.blob) + + if live_request.content: + await llm_connection.send_content(live_request.content) + + async def _receive_from_model( + self, + llm_connection: BaseLlmConnection, + event_id: str, + invocation_context: InvocationContext, + llm_request: LlmRequest, + ) -> AsyncGenerator[Event, None]: + """Receive data from model and process events using BaseLlmConnection.""" + + def get_author_for_event(llm_response): + """Get the author of the event. + + When the model returns transcription, the author is "user". Otherwise, the + author is the agent name(not 'model'). + + Args: + llm_response: The LLM response from the LLM call. + """ + if ( + llm_response + and llm_response.content + and llm_response.content.role == "user" + ): + return "user" + else: + return invocation_context.agent.name + + assert invocation_context.live_request_queue + try: + while True: + async with Aclosing(llm_connection.receive()) as agen: + async for llm_response in agen: + if llm_response.live_session_resumption_update: + logger.info( + "Update session resumption handle:" + f" {llm_response.live_session_resumption_update}." + ) + invocation_context.live_session_resumption_handle = ( + llm_response.live_session_resumption_update.new_handle + ) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=get_author_for_event(llm_response), + ) + + async with Aclosing( + self._postprocess_live( + invocation_context, + llm_request, + llm_response, + model_response_event, + ) + ) as agen: + async for event in agen: + # Cache output audio chunks from model responses + # TODO: support video data + if ( + invocation_context.run_config.save_live_blob + and event.content + and event.content.parts + and event.content.parts[0].inline_data + and event.content.parts[ + 0 + ].inline_data.mime_type.startswith("audio/") + ): + audio_blob = types.Blob( + data=event.content.parts[0].inline_data.data, + mime_type=event.content.parts[ + 0 + ].inline_data.mime_type, + ) + self.audio_cache_manager.cache_audio( + invocation_context, + audio_blob, + cache_type="output", + ) + + yield event + # Give opportunity for other tasks to run. + await asyncio.sleep(0) + except ConnectionClosedOK: + pass + + async def run_async( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Runs the flow.""" + while True: + last_event = None + async with Aclosing(self._run_one_step_async(invocation_context)) as agen: + async for event in agen: + last_event = event + yield event + if not last_event or last_event.is_final_response() or last_event.partial: + if last_event and last_event.partial: + logger.warning("The last event is partial, which is not expected.") + break + + async def _run_one_step_async( + self, + invocation_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """One step means one LLM call.""" + llm_request = LlmRequest() + + # Preprocess before calling the LLM. + async with Aclosing( + self._preprocess_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event + if invocation_context.end_invocation: + return + + # Resume the LLM agent based on the last event from the current branch. + # 1. User content: continue the normal flow + # 2. Function call: call the tool and get the response event. + events = invocation_context._get_events( + current_invocation=True, current_branch=True ) - await llm_connection.send_realtime(live_request.blob) - - if live_request.content: - await llm_connection.send_content(live_request.content) - - async def _receive_from_model( - self, - llm_connection: BaseLlmConnection, - event_id: str, - invocation_context: InvocationContext, - llm_request: LlmRequest, - ) -> AsyncGenerator[Event, None]: - """Receive data from model and process events using BaseLlmConnection.""" - - def get_author_for_event(llm_response): - """Get the author of the event. - - When the model returns transcription, the author is "user". Otherwise, the - author is the agent name(not 'model'). - - Args: - llm_response: The LLM response from the LLM call. - """ - if ( - llm_response - and llm_response.content - and llm_response.content.role == 'user' - ): - return 'user' - else: - return invocation_context.agent.name - - assert invocation_context.live_request_queue - try: - while True: - async with Aclosing(llm_connection.receive()) as agen: - async for llm_response in agen: - if llm_response.live_session_resumption_update: - logger.info( - 'Update session resumption handle:' - f' {llm_response.live_session_resumption_update}.' - ) - invocation_context.live_session_resumption_handle = ( - llm_response.live_session_resumption_update.new_handle - ) - model_response_event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=get_author_for_event(llm_response), + # Long running tool calls should have been handled before this point. + # If there are still long running tool calls, it means the agent is paused + # before, and its branch hasn't been resumed yet. + if ( + invocation_context.is_resumable + and events + and len(events) > 1 + # TODO: here we are using the last 2 events to decide whether to pause + # the invocation. But this is just being optimistic, we should find a + # way to pause when the long running tool call is followed by more than + # one text responses. + and ( + invocation_context.should_pause_invocation(events[-1]) + or invocation_context.should_pause_invocation(events[-2]) ) + ): + return + if ( + invocation_context.is_resumable + and events + and events[-1].get_function_calls() + ): + model_response_event = events[-1] async with Aclosing( - self._postprocess_live( - invocation_context, - llm_request, - llm_response, - model_response_event, + self._postprocess_handle_function_calls_async( + invocation_context, model_response_event, llm_request ) ) as agen: - async for event in agen: - # Cache output audio chunks from model responses - # TODO: support video data - if ( - invocation_context.run_config.save_live_blob - and event.content - and event.content.parts - and event.content.parts[0].inline_data - and event.content.parts[0].inline_data.mime_type.startswith( - 'audio/' + async for event in agen: + event.id = Event.new_id() + yield event + return + + # Calls the LLM. + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + ) + async with Aclosing( + self._call_llm_async(invocation_context, llm_request, model_response_event) + ) as agen: + async for llm_response in agen: + # Postprocess after calling the LLM. + async with Aclosing( + self._postprocess_async( + invocation_context, + llm_request, + llm_response, + model_response_event, ) - ): - audio_blob = types.Blob( - data=event.content.parts[0].inline_data.data, - mime_type=event.content.parts[0].inline_data.mime_type, - ) - self.audio_cache_manager.cache_audio( - invocation_context, audio_blob, cache_type='output' - ) + ) as agen: + async for event in agen: + # Update the mutable event id to avoid conflict + model_response_event.id = Event.new_id() + model_response_event.timestamp = ( + datetime.datetime.now().timestamp() + ) + yield event + + async def _preprocess_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + raise TypeError(f"Expected agent to be an LlmAgent, but got {type(agent)}") + + # Runs processors. + for processor in self.request_processors: + async with Aclosing( + processor.run_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event + + # Run processors for tools. + + # We may need to wrap some built-in tools if there are other tools + # because the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + multiple_tools = len(agent.tools) > 1 + for tool_union in agent.tools: + tool_context = ToolContext(invocation_context) + + # If it's a toolset, process it first + if isinstance(tool_union, BaseToolset): + await tool_union.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) - yield event - # Give opportunity for other tasks to run. - await asyncio.sleep(0) - except ConnectionClosedOK: - pass - - async def run_async( - self, invocation_context: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Runs the flow.""" - while True: - last_event = None - async with Aclosing(self._run_one_step_async(invocation_context)) as agen: - async for event in agen: - last_event = event - yield event - if not last_event or last_event.is_final_response() or last_event.partial: - if last_event and last_event.partial: - logger.warning('The last event is partial, which is not expected.') - break - - async def _run_one_step_async( - self, - invocation_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """One step means one LLM call.""" - llm_request = LlmRequest() - - # Preprocess before calling the LLM. - async with Aclosing( - self._preprocess_async(invocation_context, llm_request) - ) as agen: - async for event in agen: - yield event - if invocation_context.end_invocation: - return - - # Resume the LLM agent based on the last event from the current branch. - # 1. User content: continue the normal flow - # 2. Function call: call the tool and get the response event. - events = invocation_context._get_events( - current_invocation=True, current_branch=True - ) - - # Long running tool calls should have been handled before this point. - # If there are still long running tool calls, it means the agent is paused - # before, and its branch hasn't been resumed yet. - if ( - invocation_context.is_resumable - and events - and len(events) > 1 - # TODO: here we are using the last 2 events to decide whether to pause - # the invocation. But this is just being optimistic, we should find a - # way to pause when the long running tool call is followed by more than - # one text responses. - and ( - invocation_context.should_pause_invocation(events[-1]) - or invocation_context.should_pause_invocation(events[-2]) - ) - ): - return + from ...agents.llm_agent import _convert_tool_union_to_tools - if ( - invocation_context.is_resumable - and events - and events[-1].get_function_calls() - ): - model_response_event = events[-1] - async with Aclosing( - self._postprocess_handle_function_calls_async( - invocation_context, model_response_event, llm_request - ) - ) as agen: - async for event in agen: - event.id = Event.new_id() - yield event - return - - # Calls the LLM. - model_response_event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - branch=invocation_context.branch, - ) - async with Aclosing( - self._call_llm_async( - invocation_context, llm_request, model_response_event - ) - ) as agen: - async for llm_response in agen: - # Postprocess after calling the LLM. - async with Aclosing( - self._postprocess_async( - invocation_context, - llm_request, - llm_response, - model_response_event, + # Then process all tools from this tool union + tools = await _convert_tool_union_to_tools( + tool_union, + ReadonlyContext(invocation_context), + agent.model, + multiple_tools, ) + for tool in tools: + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + async def _postprocess_async( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + llm_response: LlmResponse, + model_response_event: Event, + ) -> AsyncGenerator[Event, None]: + """Postprocess after calling the LLM. + + Args: + invocation_context: The invocation context. + llm_request: The original LLM request. + llm_response: The LLM response from the LLM call. + model_response_event: A mutable event for the LLM response. + + Yields: + A generator of events. + """ + + # Runs processors. + async with Aclosing( + self._postprocess_run_processors_async(invocation_context, llm_response) ) as agen: - async for event in agen: - # Update the mutable event id to avoid conflict - model_response_event.id = Event.new_id() - model_response_event.timestamp = datetime.datetime.now().timestamp() - yield event - - async def _preprocess_async( - self, invocation_context: InvocationContext, llm_request: LlmRequest - ) -> AsyncGenerator[Event, None]: - from ...agents.llm_agent import LlmAgent + async for event in agen: + yield event - agent = invocation_context.agent - if not isinstance(agent, LlmAgent): - raise TypeError( - f'Expected agent to be an LlmAgent, but got {type(agent)}' - ) - - # Runs processors. - for processor in self.request_processors: - async with Aclosing( - processor.run_async(invocation_context, llm_request) - ) as agen: - async for event in agen: - yield event - - # Run processors for tools. - - # We may need to wrap some built-in tools if there are other tools - # because the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - multiple_tools = len(agent.tools) > 1 - for tool_union in agent.tools: - tool_context = ToolContext(invocation_context) - - # If it's a toolset, process it first - if isinstance(tool_union, BaseToolset): - await tool_union.process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) + # Skip the model response event if there is no content and no error code. + # This is needed for the code executor to trigger another loop. + if ( + not llm_response.content + and not llm_response.error_code + and not llm_response.interrupted + ): + return - from ...agents.llm_agent import _convert_tool_union_to_tools - - # Then process all tools from this tool union - tools = await _convert_tool_union_to_tools( - tool_union, - ReadonlyContext(invocation_context), - agent.model, - multiple_tools, - ) - for tool in tools: - await tool.process_llm_request( - tool_context=tool_context, llm_request=llm_request + # Builds the event. + model_response_event = self._finalize_model_response_event( + llm_request, llm_response, model_response_event ) + yield model_response_event - async def _postprocess_async( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - llm_response: LlmResponse, - model_response_event: Event, - ) -> AsyncGenerator[Event, None]: - """Postprocess after calling the LLM. - - Args: - invocation_context: The invocation context. - llm_request: The original LLM request. - llm_response: The LLM response from the LLM call. - model_response_event: A mutable event for the LLM response. - - Yields: - A generator of events. - """ + # Handles function calls. + if model_response_event.get_function_calls(): - # Runs processors. - async with Aclosing( - self._postprocess_run_processors_async(invocation_context, llm_response) - ) as agen: - async for event in agen: - yield event - - # Skip the model response event if there is no content and no error code. - # This is needed for the code executor to trigger another loop. - if ( - not llm_response.content - and not llm_response.error_code - and not llm_response.interrupted - ): - return + if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): + # In progressive SSE streaming mode stage 1, we skip partial FC events + # Only execute FCs in the final aggregated event (partial=False) + if ( + invocation_context.run_config.streaming_mode == StreamingMode.SSE + and model_response_event.partial + ): + return - # Builds the event. - model_response_event = self._finalize_model_response_event( - llm_request, llm_response, model_response_event - ) - yield model_response_event + async with Aclosing( + self._postprocess_handle_function_calls_async( + invocation_context, model_response_event, llm_request + ) + ) as agen: + async for event in agen: + yield event - # Handles function calls. - if model_response_event.get_function_calls(): + async def _postprocess_live( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + llm_response: LlmResponse, + model_response_event: Event, + ) -> AsyncGenerator[Event, None]: + """Postprocess after calling the LLM asynchronously. + + Args: + invocation_context: The invocation context. + llm_request: The original LLM request. + llm_response: The LLM response from the LLM call. + model_response_event: A mutable event for the LLM response. + + Yields: + A generator of events. + """ + + # Runs processors. + async with Aclosing( + self._postprocess_run_processors_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event - if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): - # In progressive SSE streaming mode stage 1, we skip partial FC events - # Only execute FCs in the final aggregated event (partial=False) + # Skip the model response event if there is no content and no error code. + # This is needed for the code executor to trigger another loop. + # But don't skip control events like turn_complete or transcription events. if ( - invocation_context.run_config.streaming_mode == StreamingMode.SSE - and model_response_event.partial + not llm_response.content + and not llm_response.error_code + and not llm_response.interrupted + and not llm_response.turn_complete + and not llm_response.input_transcription + and not llm_response.output_transcription + and not llm_response.usage_metadata ): - return - - async with Aclosing( - self._postprocess_handle_function_calls_async( - invocation_context, model_response_event, llm_request - ) - ) as agen: - async for event in agen: - yield event - - async def _postprocess_live( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - llm_response: LlmResponse, - model_response_event: Event, - ) -> AsyncGenerator[Event, None]: - """Postprocess after calling the LLM asynchronously. - - Args: - invocation_context: The invocation context. - llm_request: The original LLM request. - llm_response: The LLM response from the LLM call. - model_response_event: A mutable event for the LLM response. - - Yields: - A generator of events. - """ - - # Runs processors. - async with Aclosing( - self._postprocess_run_processors_async(invocation_context, llm_response) - ) as agen: - async for event in agen: - yield event - - # Skip the model response event if there is no content and no error code. - # This is needed for the code executor to trigger another loop. - # But don't skip control events like turn_complete or transcription events. - if ( - not llm_response.content - and not llm_response.error_code - and not llm_response.interrupted - and not llm_response.turn_complete - and not llm_response.input_transcription - and not llm_response.output_transcription - and not llm_response.usage_metadata - ): - return - - # Handle transcription events ONCE per llm_response, outside the event loop - if llm_response.input_transcription: - model_response_event.input_transcription = ( - llm_response.input_transcription - ) - model_response_event.partial = llm_response.partial - yield model_response_event - return - - if llm_response.output_transcription: - model_response_event.output_transcription = ( - llm_response.output_transcription - ) - model_response_event.partial = llm_response.partial - yield model_response_event - return - - # Flush audio caches based on control events using configurable settings - if invocation_context.run_config.save_live_blob: - flushed_events = await self._handle_control_event_flush( - invocation_context, llm_response - ) - for event in flushed_events: - yield event - if flushed_events: - return - - # Builds the event. - model_response_event = self._finalize_model_response_event( - llm_request, llm_response, model_response_event - ) - yield model_response_event - - # Handles function calls. - if model_response_event.get_function_calls(): - function_response_event = await functions.handle_function_calls_live( - invocation_context, model_response_event, llm_request.tools_dict - ) - # Always yield the function response event first - yield function_response_event - - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event - ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response + return + + # Handle transcription events ONCE per llm_response, outside the event loop + if llm_response.input_transcription: + model_response_event.input_transcription = llm_response.input_transcription + model_response_event.partial = llm_response.partial + yield model_response_event + return + + if llm_response.output_transcription: + model_response_event.output_transcription = ( + llm_response.output_transcription ) - ) - yield final_event - - async def _postprocess_run_processors_async( - self, invocation_context: InvocationContext, llm_response: LlmResponse - ) -> AsyncGenerator[Event, None]: - for processor in self.response_processors: - async with Aclosing( - processor.run_async(invocation_context, llm_response) - ) as agen: - async for event in agen: - yield event - - async def _postprocess_handle_function_calls_async( - self, - invocation_context: InvocationContext, - function_call_event: Event, - llm_request: LlmRequest, - ) -> AsyncGenerator[Event, None]: - if function_response_event := await functions.handle_function_calls_async( - invocation_context, function_call_event, llm_request.tools_dict - ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event - ) - if auth_event: - yield auth_event - - tool_confirmation_event = functions.generate_request_confirmation_event( - invocation_context, function_call_event, function_response_event - ) - if tool_confirmation_event: - yield tool_confirmation_event - - # Always yield the function response event first - yield function_response_event - - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event - ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response + model_response_event.partial = llm_response.partial + yield model_response_event + return + + # Flush audio caches based on control events using configurable settings + if invocation_context.run_config.save_live_blob: + flushed_events = await self._handle_control_event_flush( + invocation_context, llm_response ) + for event in flushed_events: + yield event + if flushed_events: + return + + # Builds the event. + model_response_event = self._finalize_model_response_event( + llm_request, llm_response, model_response_event ) - yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: - async for event in agen: - yield event - - def _get_agent_to_run( - self, invocation_context: InvocationContext, agent_name: str - ) -> BaseAgent: - root_agent = invocation_context.agent.root_agent - agent_to_run = root_agent.find_agent(agent_name) - if not agent_to_run: - raise ValueError(f'Agent {agent_name} not found in the agent tree.') - return agent_to_run - - async def _call_llm_async( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> AsyncGenerator[LlmResponse, None]: - # Runs before_model_callback if it exists. - if response := await self._handle_before_model_callback( - invocation_context, llm_request, model_response_event - ): - yield response - return - - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.labels = llm_request.config.labels or {} - - # Add agent name as a label to the llm_request. This will help with slicing - # the billing reports on a per-agent basis. - if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: - llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( - invocation_context.agent.name - ) - - # Calls the LLM. - llm = self.__get_llm(invocation_context) - - async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: - with tracer.start_as_current_span('call_llm'): - if invocation_context.run_config.support_cfc: - invocation_context.live_request_queue = LiveRequestQueue() - responses_generator = self.run_live(invocation_context) - async with Aclosing( - self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ) - ) as agen: - async for llm_response in agen: - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - # only yield partial response in SSE streaming mode - if ( - invocation_context.run_config.streaming_mode - == StreamingMode.SSE - or not llm_response.partial - ): - yield llm_response - if llm_response.turn_complete: - invocation_context.live_request_queue.close() - else: - # Check if we can make this llm call or not. If the current call - # pushes the counter beyond the max set value, then the execution is - # stopped right here, and exception is thrown. - invocation_context.increment_llm_call_count() - responses_generator = llm.generate_content_async( - llm_request, - stream=invocation_context.run_config.streaming_mode - == StreamingMode.SSE, - ) - async with Aclosing( - self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ) - ) as agen: - async for llm_response in agen: - trace_call_llm( - invocation_context, - model_response_event.id, - llm_request, - llm_response, - ) - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - - yield llm_response - - async with Aclosing(_call_llm_with_tracing()) as agen: - async for event in agen: - yield event - - async def _handle_before_model_callback( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> Optional[LlmResponse]: - from ...agents.llm_agent import LlmAgent + yield model_response_event + + # Handles function calls. + if model_response_event.get_function_calls(): + function_response_event = await functions.handle_function_calls_live( + invocation_context, model_response_event, llm_request.tools_dict + ) + # Always yield the function response event first + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event - agent = invocation_context.agent + async def _postprocess_run_processors_async( + self, invocation_context: InvocationContext, llm_response: LlmResponse + ) -> AsyncGenerator[Event, None]: + for processor in self.response_processors: + async with Aclosing( + processor.run_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) + async def _postprocess_handle_function_calls_async( + self, + invocation_context: InvocationContext, + function_call_event: Event, + llm_request: LlmRequest, + ) -> AsyncGenerator[Event, None]: + if function_response_event := await functions.handle_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + auth_event = functions.generate_auth_event( + invocation_context, function_response_event + ) + if auth_event: + yield auth_event - # First run callbacks from the plugins. - callback_response = ( - await invocation_context.plugin_manager.run_before_model_callback( - callback_context=callback_context, - llm_request=llm_request, - ) - ) - if callback_response: - return callback_response - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not agent.canonical_before_model_callbacks: - return - for callback in agent.canonical_before_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_request=llm_request - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return callback_response - - async def _handle_after_model_callback( - self, - invocation_context: InvocationContext, - llm_response: LlmResponse, - model_response_event: Event, - ) -> Optional[LlmResponse]: - from ...agents.llm_agent import LlmAgent + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, function_response_event + ) + if tool_confirmation_event: + yield tool_confirmation_event + + # Always yield the function response event first + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event + transfer_to_agent = function_response_event.actions.transfer_to_agent + if transfer_to_agent: + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event + + def _get_agent_to_run( + self, invocation_context: InvocationContext, agent_name: str + ) -> BaseAgent: + root_agent = invocation_context.agent.root_agent + agent_to_run = root_agent.find_agent(agent_name) + if not agent_to_run: + raise ValueError(f"Agent {agent_name} not found in the agent tree.") + return agent_to_run + + async def _call_llm_async( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> AsyncGenerator[LlmResponse, None]: + # Runs before_model_callback if it exists. + if response := await self._handle_before_model_callback( + invocation_context, llm_request, model_response_event + ): + yield response + return + + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.labels = llm_request.config.labels or {} + + # Add agent name as a label to the llm_request. This will help with slicing + # the billing reports on a per-agent basis. + if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: + llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( + invocation_context.agent.name + ) + + # Calls the LLM. + llm = self.__get_llm(invocation_context) + + async def _call_llm_body() -> AsyncGenerator[LlmResponse, None]: + if invocation_context.run_config.support_cfc: + invocation_context.live_request_queue = LiveRequestQueue() + responses_generator = self.run_live(invocation_context) + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, + ) + ) as agen: + async for llm_response in agen: + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response + # only yield partial response in SSE streaming mode + if ( + invocation_context.run_config.streaming_mode + == StreamingMode.SSE + or not llm_response.partial + ): + yield llm_response + if llm_response.turn_complete: + invocation_context.live_request_queue.close() + else: + # Check if we can make this llm call or not. If the current call + # pushes the counter beyond the max set value, then the execution is + # stopped right here, and exception is thrown. + invocation_context.increment_llm_call_count() + responses_generator = llm.generate_content_async( + llm_request, + stream=invocation_context.run_config.streaming_mode + == StreamingMode.SSE, + ) + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, + ) + ) as agen: + async for llm_response in agen: + trace_call_llm( + invocation_context, + model_response_event.id, + llm_request, + llm_response, + ) + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response + + yield llm_response + + async def _call_llm_with_optional_tracing() -> ( + AsyncGenerator[LlmResponse, None] + ): + if is_telemetry_enabled(invocation_context.agent): + with tracer.start_as_current_span("call_llm"): + async with Aclosing(_call_llm_body()) as agen: + async for r in agen: + yield r + else: + async with Aclosing(_call_llm_body()) as agen: + async for r in agen: + yield r + + async with Aclosing(_call_llm_with_optional_tracing()) as agen: + async for event in agen: + yield event - agent = invocation_context.agent + async def _handle_before_model_callback( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> Optional[LlmResponse]: + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) - # Add grounding metadata to the response if needed. - # TODO(b/448114567): Remove this function once the workaround is no longer needed. - async def _maybe_add_grounding_metadata( - response: Optional[LlmResponse] = None, + # First run callbacks from the plugins. + callback_response = ( + await invocation_context.plugin_manager.run_before_model_callback( + callback_context=callback_context, + llm_request=llm_request, + ) + ) + if callback_response: + return callback_response + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not agent.canonical_before_model_callbacks: + return + for callback in agent.canonical_before_model_callbacks: + callback_response = callback( + callback_context=callback_context, llm_request=llm_request + ) + if inspect.isawaitable(callback_response): + callback_response = await callback_response + if callback_response: + return callback_response + + async def _handle_after_model_callback( + self, + invocation_context: InvocationContext, + llm_response: LlmResponse, + model_response_event: Event, ) -> Optional[LlmResponse]: - readonly_context = ReadonlyContext(invocation_context) - if (tools := invocation_context.canonical_tools_cache) is None: - tools = await agent.canonical_tools(readonly_context) - invocation_context.canonical_tools_cache = tools - - if not any(tool.name == 'google_search_agent' for tool in tools): - return response - ground_metadata = invocation_context.session.state.get( - 'temp:_adk_grounding_metadata', None - ) - if not ground_metadata: - return response - - if not response: - response = llm_response - response.grounding_metadata = ground_metadata - return response - - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) - - # First run callbacks from the plugins. - callback_response = ( - await invocation_context.plugin_manager.run_after_model_callback( - callback_context=CallbackContext(invocation_context), - llm_response=llm_response, + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + + # Add grounding metadata to the response if needed. + # TODO(b/448114567): Remove this function once the workaround is no longer needed. + async def _maybe_add_grounding_metadata( + response: Optional[LlmResponse] = None, + ) -> Optional[LlmResponse]: + readonly_context = ReadonlyContext(invocation_context) + if (tools := invocation_context.canonical_tools_cache) is None: + tools = await agent.canonical_tools(readonly_context) + invocation_context.canonical_tools_cache = tools + + if not any(tool.name == "google_search_agent" for tool in tools): + return response + ground_metadata = invocation_context.session.state.get( + "temp:_adk_grounding_metadata", None + ) + if not ground_metadata: + return response + + if not response: + response = llm_response + response.grounding_metadata = ground_metadata + return response + + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions ) - ) - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not agent.canonical_after_model_callbacks: - return await _maybe_add_grounding_metadata() - for callback in agent.canonical_after_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_response=llm_response - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) - return await _maybe_add_grounding_metadata() - - def _finalize_model_response_event( - self, - llm_request: LlmRequest, - llm_response: LlmResponse, - model_response_event: Event, - ) -> Event: - model_response_event = Event.model_validate({ - **model_response_event.model_dump(exclude_none=True), - **llm_response.model_dump(exclude_none=True), - }) - - if model_response_event.content: - function_calls = model_response_event.get_function_calls() - if function_calls: - functions.populate_client_function_call_id(model_response_event) - model_response_event.long_running_tool_ids = ( - functions.get_long_running_function_calls( - function_calls, llm_request.tools_dict + + # First run callbacks from the plugins. + callback_response = ( + await invocation_context.plugin_manager.run_after_model_callback( + callback_context=CallbackContext(invocation_context), + llm_response=llm_response, ) ) + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not agent.canonical_after_model_callbacks: + return await _maybe_add_grounding_metadata() + for callback in agent.canonical_after_model_callbacks: + callback_response = callback( + callback_context=callback_context, llm_response=llm_response + ) + if inspect.isawaitable(callback_response): + callback_response = await callback_response + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) + return await _maybe_add_grounding_metadata() + + def _finalize_model_response_event( + self, + llm_request: LlmRequest, + llm_response: LlmResponse, + model_response_event: Event, + ) -> Event: + model_response_event = Event.model_validate( + { + **model_response_event.model_dump(exclude_none=True), + **llm_response.model_dump(exclude_none=True), + } + ) - return model_response_event + if model_response_event.content: + function_calls = model_response_event.get_function_calls() + if function_calls: + functions.populate_client_function_call_id(model_response_event) + model_response_event.long_running_tool_ids = ( + functions.get_long_running_function_calls( + function_calls, llm_request.tools_dict + ) + ) - async def _handle_control_event_flush( - self, invocation_context: InvocationContext, llm_response: LlmResponse - ) -> list[Event]: - """Handle audio cache flushing based on control events. + return model_response_event - Args: - invocation_context: The invocation context containing audio caches. - llm_response: The LLM response containing control event information. + async def _handle_control_event_flush( + self, invocation_context: InvocationContext, llm_response: LlmResponse + ) -> list[Event]: + """Handle audio cache flushing based on control events. - Returns: - A list of Event objects created from the flushed caches. - """ + Args: + invocation_context: The invocation context containing audio caches. + llm_response: The LLM response containing control event information. - # Log cache statistics if enabled - if DEFAULT_ENABLE_CACHE_STATISTICS: - stats = self.audio_cache_manager.get_cache_stats(invocation_context) - logger.debug('Audio cache stats: %s', stats) - - if llm_response.interrupted: - # user interrupts so the model will stop. we can flush model audio here - return await self.audio_cache_manager.flush_caches( - invocation_context, - flush_user_audio=False, - flush_model_audio=True, - ) - elif llm_response.turn_complete: - # turn completes so we can flush both user and model - return await self.audio_cache_manager.flush_caches( - invocation_context, - flush_user_audio=True, - flush_model_audio=True, - ) - elif getattr(llm_response, 'generation_complete', False): - # model generation complete so we can flush model audio - return await self.audio_cache_manager.flush_caches( - invocation_context, - flush_user_audio=False, - flush_model_audio=True, - ) - return [] - - async def _run_and_handle_error( - self, - response_generator: AsyncGenerator[LlmResponse, None], - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> AsyncGenerator[LlmResponse, None]: - """Runs the response generator and processes the error with plugins. - - Args: - response_generator: The response generator to run. - invocation_context: The invocation context. - llm_request: The LLM request. - model_response_event: The model response event. - - Yields: - A generator of LlmResponse. - """ + Returns: + A list of Event objects created from the flushed caches. + """ - from ...agents.llm_agent import LlmAgent + # Log cache statistics if enabled + if DEFAULT_ENABLE_CACHE_STATISTICS: + stats = self.audio_cache_manager.get_cache_stats(invocation_context) + logger.debug("Audio cache stats: %s", stats) - agent = invocation_context.agent - if not isinstance(agent, LlmAgent): - raise TypeError( - f'Expected agent to be an LlmAgent, but got {type(agent)}' - ) + if llm_response.interrupted: + # user interrupts so the model will stop. we can flush model audio here + return await self.audio_cache_manager.flush_caches( + invocation_context, + flush_user_audio=False, + flush_model_audio=True, + ) + elif llm_response.turn_complete: + # turn completes so we can flush both user and model + return await self.audio_cache_manager.flush_caches( + invocation_context, + flush_user_audio=True, + flush_model_audio=True, + ) + elif getattr(llm_response, "generation_complete", False): + # model generation complete so we can flush model audio + return await self.audio_cache_manager.flush_caches( + invocation_context, + flush_user_audio=False, + flush_model_audio=True, + ) + return [] - async def _run_on_model_error_callbacks( - *, - callback_context: CallbackContext, + async def _run_and_handle_error( + self, + response_generator: AsyncGenerator[LlmResponse, None], + invocation_context: InvocationContext, llm_request: LlmRequest, - error: Exception, - ) -> Optional[LlmResponse]: - error_response = ( - await invocation_context.plugin_manager.run_on_model_error_callback( - callback_context=callback_context, - llm_request=llm_request, - error=error, - ) - ) - if error_response is not None: - return error_response - - for callback in agent.canonical_on_model_error_callbacks: - error_response = callback( - callback_context=callback_context, - llm_request=llm_request, - error=error, - ) - if inspect.isawaitable(error_response): - error_response = await error_response - if error_response is not None: - return error_response - - return None - - try: - async with Aclosing(response_generator) as agen: - async for response in agen: - yield response - except Exception as model_error: - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) - error_response = await _run_on_model_error_callbacks( - callback_context=callback_context, - llm_request=llm_request, - error=model_error, - ) - if error_response is not None: - yield error_response - else: - raise model_error - - def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm: - from ...agents.llm_agent import LlmAgent + model_response_event: Event, + ) -> AsyncGenerator[LlmResponse, None]: + """Runs the response generator and processes the error with plugins. + + Args: + response_generator: The response generator to run. + invocation_context: The invocation context. + llm_request: The LLM request. + model_response_event: The model response event. + + Yields: + A generator of LlmResponse. + """ + + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + raise TypeError(f"Expected agent to be an LlmAgent, but got {type(agent)}") + + async def _run_on_model_error_callbacks( + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + error_response = ( + await invocation_context.plugin_manager.run_on_model_error_callback( + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_model_error_callbacks: + error_response = callback( + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response + + return None + + try: + async with Aclosing(response_generator) as agen: + async for response in agen: + yield response + except Exception as model_error: + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + error_response = await _run_on_model_error_callbacks( + callback_context=callback_context, + llm_request=llm_request, + error=model_error, + ) + if error_response is not None: + yield error_response + else: + raise model_error + + def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm: + from ...agents.llm_agent import LlmAgent - return cast(LlmAgent, invocation_context.agent).canonical_model + return cast(LlmAgent, invocation_context.agent).canonical_model diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ffe1657be1..890be12d40 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -42,102 +42,101 @@ from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing +from ...utils.telemetry_utils import is_telemetry_enabled if TYPE_CHECKING: - from ...agents.llm_agent import LlmAgent + from ...agents.llm_agent import LlmAgent -AF_FUNCTION_CALL_ID_PREFIX = 'adk-' -REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential' -REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation' +AF_FUNCTION_CALL_ID_PREFIX = "adk-" +REQUEST_EUC_FUNCTION_CALL_NAME = "adk_request_credential" +REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation" -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) def generate_client_function_call_id() -> str: - return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}' + return f"{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}" def populate_client_function_call_id(model_response_event: Event) -> None: - if not model_response_event.get_function_calls(): - return - for function_call in model_response_event.get_function_calls(): - if not function_call.id: - function_call.id = generate_client_function_call_id() + if not model_response_event.get_function_calls(): + return + for function_call in model_response_event.get_function_calls(): + if not function_call.id: + function_call.id = generate_client_function_call_id() def remove_client_function_call_id(content: Optional[types.Content]) -> None: - """Removes ADK-generated function call IDs from content before sending to LLM. - - Strips client-side function call/response IDs that start with 'adk-' prefix - to avoid sending internal tracking IDs to the model. - - Args: - content: Content containing function calls/responses to clean. - """ - if content and content.parts: - for part in content.parts: - if ( - part.function_call - and part.function_call.id - and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) - ): - part.function_call.id = None - if ( - part.function_response - and part.function_response.id - and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) - ): - part.function_response.id = None + """Removes ADK-generated function call IDs from content before sending to LLM. + + Strips client-side function call/response IDs that start with 'adk-' prefix + to avoid sending internal tracking IDs to the model. + + Args: + content: Content containing function calls/responses to clean. + """ + if content and content.parts: + for part in content.parts: + if ( + part.function_call + and part.function_call.id + and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) + ): + part.function_call.id = None + if ( + part.function_response + and part.function_response.id + and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) + ): + part.function_response.id = None def get_long_running_function_calls( function_calls: list[types.FunctionCall], tools_dict: dict[str, BaseTool], ) -> set[str]: - long_running_tool_ids = set() - for function_call in function_calls: - if ( - function_call.name in tools_dict - and tools_dict[function_call.name].is_long_running - ): - long_running_tool_ids.add(function_call.id) + long_running_tool_ids = set() + for function_call in function_calls: + if ( + function_call.name in tools_dict + and tools_dict[function_call.name].is_long_running + ): + long_running_tool_ids.add(function_call.id) - return long_running_tool_ids + return long_running_tool_ids def generate_auth_event( invocation_context: InvocationContext, function_response_event: Event, ) -> Optional[Event]: - if not function_response_event.actions.requested_auth_configs: - return None - parts = [] - long_running_tool_ids = set() - for ( - function_call_id, - auth_config, - ) in function_response_event.actions.requested_auth_configs.items(): - - request_euc_function_call = types.FunctionCall( - name=REQUEST_EUC_FUNCTION_CALL_NAME, - args=AuthToolArguments( - function_call_id=function_call_id, - auth_config=auth_config, - ).model_dump(exclude_none=True, by_alias=True), + if not function_response_event.actions.requested_auth_configs: + return None + parts = [] + long_running_tool_ids = set() + for ( + function_call_id, + auth_config, + ) in function_response_event.actions.requested_auth_configs.items(): + + request_euc_function_call = types.FunctionCall( + name=REQUEST_EUC_FUNCTION_CALL_NAME, + args=AuthToolArguments( + function_call_id=function_call_id, + auth_config=auth_config, + ).model_dump(exclude_none=True, by_alias=True), + ) + request_euc_function_call.id = generate_client_function_call_id() + long_running_tool_ids.add(request_euc_function_call.id) + parts.append(types.Part(function_call=request_euc_function_call)) + + return Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=types.Content(parts=parts, role=function_response_event.content.role), + long_running_tool_ids=long_running_tool_ids, ) - request_euc_function_call.id = generate_client_function_call_id() - long_running_tool_ids.add(request_euc_function_call.id) - parts.append(types.Part(function_call=request_euc_function_call)) - - return Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - branch=invocation_context.branch, - content=types.Content( - parts=parts, role=function_response_event.content.role - ), - long_running_tool_ids=long_running_tool_ids, - ) def generate_request_confirmation_event( @@ -145,45 +144,43 @@ def generate_request_confirmation_event( function_call_event: Event, function_response_event: Event, ) -> Optional[Event]: - """Generates a request confirmation event from a function response event.""" - if not function_response_event.actions.requested_tool_confirmations: - return None - parts = [] - long_running_tool_ids = set() - function_calls = function_call_event.get_function_calls() - for ( - function_call_id, - tool_confirmation, - ) in function_response_event.actions.requested_tool_confirmations.items(): - original_function_call = next( - (fc for fc in function_calls if fc.id == function_call_id), None - ) - if not original_function_call: - continue - request_confirmation_function_call = types.FunctionCall( - name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, - args={ - 'originalFunctionCall': original_function_call.model_dump( - exclude_none=True, by_alias=True - ), - 'toolConfirmation': tool_confirmation.model_dump( - by_alias=True, exclude_none=True - ), - }, + """Generates a request confirmation event from a function response event.""" + if not function_response_event.actions.requested_tool_confirmations: + return None + parts = [] + long_running_tool_ids = set() + function_calls = function_call_event.get_function_calls() + for ( + function_call_id, + tool_confirmation, + ) in function_response_event.actions.requested_tool_confirmations.items(): + original_function_call = next( + (fc for fc in function_calls if fc.id == function_call_id), None + ) + if not original_function_call: + continue + request_confirmation_function_call = types.FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + "originalFunctionCall": original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + "toolConfirmation": tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + }, + ) + request_confirmation_function_call.id = generate_client_function_call_id() + long_running_tool_ids.add(request_confirmation_function_call.id) + parts.append(types.Part(function_call=request_confirmation_function_call)) + + return Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=types.Content(parts=parts, role=function_response_event.content.role), + long_running_tool_ids=long_running_tool_ids, ) - request_confirmation_function_call.id = generate_client_function_call_id() - long_running_tool_ids.add(request_confirmation_function_call.id) - parts.append(types.Part(function_call=request_confirmation_function_call)) - - return Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - branch=invocation_context.branch, - content=types.Content( - parts=parts, role=function_response_event.content.role - ), - long_running_tool_ids=long_running_tool_ids, - ) async def handle_function_calls_async( @@ -193,15 +190,15 @@ async def handle_function_calls_async( filters: Optional[set[str]] = None, tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, ) -> Optional[Event]: - """Calls the functions and returns the function response event.""" - function_calls = function_call_event.get_function_calls() - return await handle_function_call_list_async( - invocation_context, - function_calls, - tools_dict, - filters, - tool_confirmation_dict, - ) + """Calls the functions and returns the function response event.""" + function_calls = function_call_event.get_function_calls() + return await handle_function_call_list_async( + invocation_context, + function_calls, + tools_dict, + filters, + tool_confirmation_dict, + ) async def handle_function_call_list_async( @@ -211,60 +208,58 @@ async def handle_function_call_list_async( filters: Optional[set[str]] = None, tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, ) -> Optional[Event]: - """Calls the functions and returns the function response event.""" - from ...agents.llm_agent import LlmAgent + """Calls the functions and returns the function response event.""" + from ...agents.llm_agent import LlmAgent - agent = invocation_context.agent + agent = invocation_context.agent - # Filter function calls - filtered_calls = [ - fc for fc in function_calls if not filters or fc.id in filters - ] + # Filter function calls + filtered_calls = [fc for fc in function_calls if not filters or fc.id in filters] - if not filtered_calls: - return None + if not filtered_calls: + return None - # Create tasks for parallel execution - tasks = [ - asyncio.create_task( - _execute_single_function_call_async( - invocation_context, - function_call, - tools_dict, - agent, - tool_confirmation_dict[function_call.id] - if tool_confirmation_dict - else None, - ) - ) - for function_call in filtered_calls - ] - - # Wait for all tasks to complete - function_response_events = await asyncio.gather(*tasks) - - # Filter out None results - function_response_events = [ - event for event in function_response_events if event is not None - ] - - if not function_response_events: - return None + # Create tasks for parallel execution + tasks = [ + asyncio.create_task( + _execute_single_function_call_async( + invocation_context, + function_call, + tools_dict, + agent, + ( + tool_confirmation_dict[function_call.id] + if tool_confirmation_dict + else None + ), + ) + ) + for function_call in filtered_calls + ] + + # Wait for all tasks to complete + function_response_events = await asyncio.gather(*tasks) - merged_event = merge_parallel_function_response_events( - function_response_events - ) + # Filter out None results + function_response_events = [ + event for event in function_response_events if event is not None + ] - if len(function_response_events) > 1: - # this is needed for debug traces of parallel calls - # individual response with tool.name is traced in __build_response_event - # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span('execute_tool (merged)'): - trace_merged_tool_calls( - response_event_id=merged_event.id, - function_response_event=merged_event, - ) - return merged_event + if not function_response_events: + return None + + merged_event = merge_parallel_function_response_events(function_response_events) + + if len(function_response_events) > 1 and is_telemetry_enabled(agent): + # this is needed for debug traces of parallel calls + # individual response with tool.name is traced in __build_response_event + # (we drop tool.name from span name here as this is merged event) + with tracer.start_as_current_span("execute_tool (merged)"): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + ) + return merged_event async def _execute_single_function_call_async( @@ -274,99 +269,54 @@ async def _execute_single_function_call_async( agent: LlmAgent, tool_confirmation: Optional[ToolConfirmation] = None, ) -> Optional[Event]: - """Execute a single function call with thread safety for state modifications.""" - - async def _run_on_tool_error_callbacks( - *, - tool: BaseTool, - tool_args: dict[str, Any], - tool_context: ToolContext, - error: Exception, - ) -> Optional[dict[str, Any]]: - """Runs the on_tool_error_callbacks for the given tool.""" - error_response = ( - await invocation_context.plugin_manager.run_on_tool_error_callback( - tool=tool, - tool_args=tool_args, - tool_context=tool_context, - error=error, + """Execute a single function call with thread safety for state modifications.""" + + async def _run_on_tool_error_callbacks( + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict[str, Any]]: + """Runs the on_tool_error_callbacks for the given tool.""" + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, + ) ) - ) - if error_response is not None: - return error_response - - for callback in agent.canonical_on_tool_error_callbacks: - error_response = callback( - tool=tool, - args=tool_args, - tool_context=tool_context, - error=error, - ) - if inspect.isawaitable(error_response): - error_response = await error_response - if error_response is not None: - return error_response - - return None - - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} - ) + if error_response is not None: + return error_response - tool_context = _create_tool_context( - invocation_context, function_call, tool_confirmation - ) + for callback in agent.canonical_on_tool_error_callbacks: + error_response = callback( + tool=tool, + args=tool_args, + tool_context=tool_context, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response - try: - tool = _get_tool(function_call, tools_dict) - except ValueError as tool_error: - tool = BaseTool(name=function_call.name, description='Tool not found') - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - return __build_response_event( - tool, error_response, tool_context, invocation_context - ) - else: - raise tool_error + return None - async def _run_with_trace(): - nonlocal function_args + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_args = copy.deepcopy(function_call.args) if function_call.args else {} - # Step 1: Check if plugin before_tool_callback overrides the function - # response. - function_response = ( - await invocation_context.plugin_manager.run_before_tool_callback( - tool=tool, tool_args=function_args, tool_context=tool_context - ) + tool_context = _create_tool_context( + invocation_context, function_call, tool_confirmation ) - # Step 2: If no overrides are provided from the plugins, further run the - # canonical callback. - if function_response is None: - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break - - # Step 3: Otherwise, proceed calling the tool normally. - if function_response is None: - try: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) - except Exception as tool_error: + try: + tool = _get_tool(function_call, tools_dict) + except ValueError as tool_error: + tool = BaseTool(name=function_call.name, description="Tool not found") error_response = await _run_on_tool_error_callbacks( tool=tool, tool_args=function_args, @@ -374,71 +324,117 @@ async def _run_with_trace(): error=tool_error, ) if error_response is not None: - function_response = error_response + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) else: - raise tool_error + raise tool_error - # Step 4: Check if plugin after_tool_callback overrides the function - # response. - altered_function_response = ( - await invocation_context.plugin_manager.run_after_tool_callback( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - result=function_response, - ) - ) + async def _run_with_trace(): + nonlocal function_args - # Step 5: If no overrides are provided from the plugins, further run the - # canonical after_tool_callbacks. - if altered_function_response is None: - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, + # Step 1: Check if plugin before_tool_callback overrides the function + # response. + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - # Step 6: If alternative response exists from after_tool_callback, use it - # instead of the original function response. - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow long running function to return None to not provide function - # response. - if not function_response: - return None - - # Note: State deltas are not applied here - they are collected in - # tool_context.actions.state_delta and applied later when the session - # service processes the events - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - return function_response_event + # Step 2: If no overrides are provided from the plugins, further run the + # canonical callback. + if function_response is None: + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + # Step 3: Otherwise, proceed calling the tool normally. + if function_response is None: + try: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + except Exception as tool_error: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_error + + # Step 4: Check if plugin after_tool_callback overrides the function + # response. + altered_function_response = ( + await invocation_context.plugin_manager.run_after_tool_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + result=function_response, + ) + ) - with tracer.start_as_current_span(f'execute_tool {tool.name}'): - try: - function_response_event = await _run_with_trace() - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise + # Step 5: If no overrides are provided from the plugins, further run the + # canonical after_tool_callbacks. + if altered_function_response is None: + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + # Step 6: If alternative response exists from after_tool_callback, use it + # instead of the original function response. + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow long running function to return None to not provide function + # response. + if not function_response: + return None + + # Note: State deltas are not applied here - they are collected in + # tool_context.actions.state_delta and applied later when the session + # service processes the events + + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + return function_response_event + + if is_telemetry_enabled(agent): + with tracer.start_as_current_span(f"execute_tool {tool.name}"): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise + else: + return await _run_with_trace() async def handle_function_calls_live( @@ -446,56 +442,55 @@ async def handle_function_calls_live( function_call_event: Event, tools_dict: dict[str, BaseTool], ) -> Event: - """Calls the functions and returns the function response event.""" - from ...agents.llm_agent import LlmAgent + """Calls the functions and returns the function response event.""" + from ...agents.llm_agent import LlmAgent - agent = cast(LlmAgent, invocation_context.agent) - function_calls = function_call_event.get_function_calls() + agent = cast(LlmAgent, invocation_context.agent) + function_calls = function_call_event.get_function_calls() - if not function_calls: - return None + if not function_calls: + return None - # Create async lock for active_streaming_tools modifications - streaming_lock = asyncio.Lock() - - # Create tasks for parallel execution - tasks = [ - asyncio.create_task( - _execute_single_function_call_live( - invocation_context, - function_call, - tools_dict, - agent, - streaming_lock, - ) - ) - for function_call in function_calls - ] - - # Wait for all tasks to complete - function_response_events = await asyncio.gather(*tasks) - - # Filter out None results - function_response_events = [ - event for event in function_response_events if event is not None - ] - - if not function_response_events: - return None + # Create async lock for active_streaming_tools modifications + streaming_lock = asyncio.Lock() + + # Create tasks for parallel execution + tasks = [ + asyncio.create_task( + _execute_single_function_call_live( + invocation_context, + function_call, + tools_dict, + agent, + streaming_lock, + ) + ) + for function_call in function_calls + ] - merged_event = merge_parallel_function_response_events( - function_response_events - ) - if len(function_response_events) > 1: - # this is needed for debug traces of parallel calls - # individual response with tool.name is traced in __build_response_event - # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span('execute_tool (merged)'): - trace_merged_tool_calls( - response_event_id=merged_event.id, - function_response_event=merged_event, - ) - return merged_event + # Wait for all tasks to complete + function_response_events = await asyncio.gather(*tasks) + + # Filter out None results + function_response_events = [ + event for event in function_response_events if event is not None + ] + + if not function_response_events: + return None + + merged_event = merge_parallel_function_response_events(function_response_events) + + if len(function_response_events) > 1 and is_telemetry_enabled(agent): + # this is needed for debug traces of parallel calls + # individual response with tool.name is traced in __build_response_event + # (we drop tool.name from span name here as this is merged event) + with tracer.start_as_current_span("execute_tool (merged)"): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + ) + return merged_event async def _execute_single_function_call_live( @@ -505,90 +500,91 @@ async def _execute_single_function_call_live( agent: LlmAgent, streaming_lock: asyncio.Lock, ) -> Optional[Event]: - """Execute a single function call for live mode with thread safety.""" - tool, tool_context = _get_tool_and_context( - invocation_context, function_call, tools_dict - ) + """Execute a single function call for live mode with thread safety.""" + tool, tool_context = _get_tool_and_context( + invocation_context, function_call, tools_dict + ) - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} - ) + function_args = copy.deepcopy(function_call.args) if function_call.args else {} - async def _run_with_trace(): - nonlocal function_args + async def _run_with_trace(): + nonlocal function_args - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_response = None + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_response = None - # Handle before_tool_callbacks - iterate through the canonical callback - # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break - - if function_response is None: - function_response = await _process_function_live_helper( - tool, - tool_context, - function_call, - function_args, - invocation_context, - streaming_lock, - ) - - # Calls after_tool_callback if it exists. - altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow async function to return None to not provide function response. - if not function_response: - return None - - # Note: State deltas are not applied here - they are collected in - # tool_context.actions.state_delta and applied later when the session - # service processes the events - - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - return function_response_event + # Handle before_tool_callbacks - iterate through the canonical callback + # list + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + if function_response is None: + function_response = await _process_function_live_helper( + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock, + ) - with tracer.start_as_current_span(f'execute_tool {tool.name}'): - try: - function_response_event = await _run_with_trace() - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise + # Calls after_tool_callback if it exists. + altered_function_response = None + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow async function to return None to not provide function response. + if not function_response: + return None + + # Note: State deltas are not applied here - they are collected in + # tool_context.actions.state_delta and applied later when the session + # service processes the events + + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + return function_response_event + + if is_telemetry_enabled(agent): + with tracer.start_as_current_span(f"execute_tool {tool.name}"): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise + else: + return await _run_with_trace() async def _process_function_live_helper( @@ -599,136 +595,134 @@ async def _process_function_live_helper( invocation_context, streaming_lock: asyncio.Lock, ): - function_response = None - # Check if this is a stop_streaming function call - if ( - function_call.name == 'stop_streaming' - and 'function_name' in function_args - ): - function_name = function_args['function_name'] - # Thread-safe access to active_streaming_tools - async with streaming_lock: - active_tasks = invocation_context.active_streaming_tools - if ( - active_tasks - and function_name in active_tasks - and active_tasks[function_name].task - and not active_tasks[function_name].task.done() - ): - task = active_tasks[function_name].task - else: - task = None - - if task: - task.cancel() - try: - # Wait for the task to be cancelled - await asyncio.wait_for(task, timeout=1.0) - except (asyncio.CancelledError, asyncio.TimeoutError): - # Log the specific condition - if task.cancelled(): - logging.info('Task %s was cancelled successfully', function_name) - elif task.done(): - logging.info('Task %s completed during cancellation', function_name) - else: - logging.warning( - 'Task %s might still be running after cancellation timeout', - function_name, - ) - function_response = { - 'status': f'The task is not cancelled yet for {function_name}.' - } - if not function_response: - # Clean up the reference under lock + function_response = None + # Check if this is a stop_streaming function call + if function_call.name == "stop_streaming" and "function_name" in function_args: + function_name = function_args["function_name"] + # Thread-safe access to active_streaming_tools async with streaming_lock: - if ( - invocation_context.active_streaming_tools - and function_name in invocation_context.active_streaming_tools - ): - invocation_context.active_streaming_tools[function_name].task = None + active_tasks = invocation_context.active_streaming_tools + if ( + active_tasks + and function_name in active_tasks + and active_tasks[function_name].task + and not active_tasks[function_name].task.done() + ): + task = active_tasks[function_name].task + else: + task = None + + if task: + task.cancel() + try: + # Wait for the task to be cancelled + await asyncio.wait_for(task, timeout=1.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + # Log the specific condition + if task.cancelled(): + logging.info("Task %s was cancelled successfully", function_name) + elif task.done(): + logging.info("Task %s completed during cancellation", function_name) + else: + logging.warning( + "Task %s might still be running after cancellation timeout", + function_name, + ) + function_response = { + "status": f"The task is not cancelled yet for {function_name}." + } + if not function_response: + # Clean up the reference under lock + async with streaming_lock: + if ( + invocation_context.active_streaming_tools + and function_name in invocation_context.active_streaming_tools + ): + invocation_context.active_streaming_tools[ + function_name + ].task = None + + function_response = { + "status": f"Successfully stopped streaming function {function_name}" + } + else: + function_response = { + "status": f"No active streaming function named {function_name} found" + } + elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func): + # for streaming tool use case + # we require the function to be an async generator function + async def run_tool_and_update_queue(tool, function_args, tool_context): + try: + async with Aclosing( + __call_tool_live( + tool=tool, + args=function_args, + tool_context=tool_context, + invocation_context=invocation_context, + ) + ) as agen: + async for result in agen: + updated_content = types.Content( + role="user", + parts=[ + types.Part.from_text( + text=f"Function {tool.name} returned: {result}" + ) + ], + ) + invocation_context.live_request_queue.send_content( + updated_content + ) + except asyncio.CancelledError: + raise # Re-raise to properly propagate the cancellation + + task = asyncio.create_task( + run_tool_and_update_queue(tool, function_args, tool_context) + ) + # Register streaming tool using original logic + async with streaming_lock: + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + + if tool.name in invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools[tool.name].task = task + else: + invocation_context.active_streaming_tools[tool.name] = ( + ActiveStreamingTool(task=task) + ) + + # Immediately return a pending response. + # This is required by current live model. function_response = { - 'status': f'Successfully stopped streaming function {function_name}' + "status": ( + "The function is running asynchronously and the results are" " pending." + ) } else: - function_response = { - 'status': f'No active streaming function named {function_name} found' - } - elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func): - # for streaming tool use case - # we require the function to be an async generator function - async def run_tool_and_update_queue(tool, function_args, tool_context): - try: - async with Aclosing( - __call_tool_live( - tool=tool, - args=function_args, - tool_context=tool_context, - invocation_context=invocation_context, - ) - ) as agen: - async for result in agen: - updated_content = types.Content( - role='user', - parts=[ - types.Part.from_text( - text=f'Function {tool.name} returned: {result}' - ) - ], - ) - invocation_context.live_request_queue.send_content(updated_content) - except asyncio.CancelledError: - raise # Re-raise to properly propagate the cancellation - - task = asyncio.create_task( - run_tool_and_update_queue(tool, function_args, tool_context) - ) - - # Register streaming tool using original logic - async with streaming_lock: - if invocation_context.active_streaming_tools is None: - invocation_context.active_streaming_tools = {} - - if tool.name in invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools[tool.name].task = task - else: - invocation_context.active_streaming_tools[tool.name] = ( - ActiveStreamingTool(task=task) + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context ) - - # Immediately return a pending response. - # This is required by current live model. - function_response = { - 'status': ( - 'The function is running asynchronously and the results are' - ' pending.' + return function_response + + +def _get_tool(function_call: types.FunctionCall, tools_dict: dict[str, BaseTool]): + """Returns the tool corresponding to the function call.""" + if function_call.name not in tools_dict: + available = list(tools_dict.keys()) + error_msg = ( + f"Tool '{function_call.name}' not found.\nAvailable tools:" + f" {', '.join(available)}\n\nPossible causes:\n 1. LLM hallucinated" + " the function name - review agent instruction clarity\n 2. Tool not" + " registered - verify agent.tools list\n 3. Name mismatch - check for" + " typos\n\nSuggested fixes:\n - Review agent instruction to ensure" + " tool usage is clear\n - Verify tool is included in agent.tools" + " list\n - Check for typos in function name" ) - } - else: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) - return function_response + raise ValueError(error_msg) - -def _get_tool( - function_call: types.FunctionCall, tools_dict: dict[str, BaseTool] -): - """Returns the tool corresponding to the function call.""" - if function_call.name not in tools_dict: - available = list(tools_dict.keys()) - error_msg = ( - f"Tool '{function_call.name}' not found.\nAvailable tools:" - f" {', '.join(available)}\n\nPossible causes:\n 1. LLM hallucinated" - ' the function name - review agent instruction clarity\n 2. Tool not' - ' registered - verify agent.tools list\n 3. Name mismatch - check for' - ' typos\n\nSuggested fixes:\n - Review agent instruction to ensure' - ' tool usage is clear\n - Verify tool is included in agent.tools' - ' list\n - Check for typos in function name' - ) - raise ValueError(error_msg) - - return tools_dict[function_call.name] + return tools_dict[function_call.name] def _create_tool_context( @@ -736,12 +730,12 @@ def _create_tool_context( function_call: types.FunctionCall, tool_confirmation: Optional[ToolConfirmation] = None, ): - """Creates a ToolContext object.""" - return ToolContext( - invocation_context=invocation_context, - function_call_id=function_call.id, - tool_confirmation=tool_confirmation, - ) + """Creates a ToolContext object.""" + return ToolContext( + invocation_context=invocation_context, + function_call_id=function_call.id, + tool_confirmation=tool_confirmation, + ) def _get_tool_and_context( @@ -750,15 +744,15 @@ def _get_tool_and_context( tools_dict: dict[str, BaseTool], tool_confirmation: Optional[ToolConfirmation] = None, ): - """Returns the tool and tool context corresponding to the function call.""" - tool = _get_tool(function_call, tools_dict) - tool_context = _create_tool_context( - invocation_context, - function_call, - tool_confirmation, - ) + """Returns the tool and tool context corresponding to the function call.""" + tool = _get_tool(function_call, tools_dict) + tool_context = _create_tool_context( + invocation_context, + function_call, + tool_confirmation, + ) - return (tool, tool_context) + return (tool, tool_context) async def __call_tool_live( @@ -767,16 +761,16 @@ async def __call_tool_live( tool_context: ToolContext, invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: - """Calls the tool asynchronously (awaiting the coroutine).""" - async with Aclosing( - tool._call_live( - args=args, - tool_context=tool_context, - invocation_context=invocation_context, - ) - ) as agen: - async for item in agen: - yield item + """Calls the tool asynchronously (awaiting the coroutine).""" + async with Aclosing( + tool._call_live( + args=args, + tool_context=tool_context, + invocation_context=invocation_context, + ) + ) as agen: + async for item in agen: + yield item async def __call_tool_async( @@ -784,8 +778,8 @@ async def __call_tool_async( args: dict[str, Any], tool_context: ToolContext, ) -> Any: - """Calls the tool.""" - return await tool.run_async(args=args, tool_context=tool_context) + """Calls the tool.""" + return await tool.run_async(args=args, tool_context=tool_context) def __build_response_event( @@ -794,111 +788,111 @@ def __build_response_event( tool_context: ToolContext, invocation_context: InvocationContext, ) -> Event: - # Specs requires the result to be a dict. - if not isinstance(function_result, dict): - function_result = {'result': function_result} + # Specs requires the result to be a dict. + if not isinstance(function_result, dict): + function_result = {"result": function_result} - part_function_response = types.Part.from_function_response( - name=tool.name, response=function_result - ) - part_function_response.function_response.id = tool_context.function_call_id + part_function_response = types.Part.from_function_response( + name=tool.name, response=function_result + ) + part_function_response.function_response.id = tool_context.function_call_id - content = types.Content( - role='user', - parts=[part_function_response], - ) + content = types.Content( + role="user", + parts=[part_function_response], + ) - function_response_event = Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - content=content, - actions=tool_context.actions, - branch=invocation_context.branch, - ) + function_response_event = Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + content=content, + actions=tool_context.actions, + branch=invocation_context.branch, + ) - return function_response_event + return function_response_event def deep_merge_dicts(d1: dict, d2: dict) -> dict: - """Recursively merges d2 into d1.""" - for key, value in d2.items(): - if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict): - d1[key] = deep_merge_dicts(d1[key], value) - else: - d1[key] = value - return d1 + """Recursively merges d2 into d1.""" + for key, value in d2.items(): + if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict): + d1[key] = deep_merge_dicts(d1[key], value) + else: + d1[key] = value + return d1 def merge_parallel_function_response_events( - function_response_events: list['Event'], -) -> 'Event': - if not function_response_events: - raise ValueError('No function response events provided.') - - if len(function_response_events) == 1: - return function_response_events[0] - merged_parts = [] - for event in function_response_events: - if event.content: - for part in event.content.parts or []: - merged_parts.append(part) - - # Use the first event as the "base" for common attributes - base_event = function_response_events[0] - - # Merge actions from all events - merged_actions_data: dict[str, Any] = {} - for event in function_response_events: - if event.actions: - # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate. - merged_actions_data = deep_merge_dicts( - merged_actions_data, - event.actions.model_dump(exclude_none=True, by_alias=True), - ) - - merged_actions = EventActions.model_validate(merged_actions_data) - - # Create the new merged event - merged_event = Event( - invocation_id=base_event.invocation_id, - author=base_event.author, - branch=base_event.branch, - content=types.Content(role='user', parts=merged_parts), - actions=merged_actions, # Optionally merge actions if required - ) - - # Use the base_event as the timestamp - merged_event.timestamp = base_event.timestamp - return merged_event + function_response_events: list["Event"], +) -> "Event": + if not function_response_events: + raise ValueError("No function response events provided.") + + if len(function_response_events) == 1: + return function_response_events[0] + merged_parts = [] + for event in function_response_events: + if event.content: + for part in event.content.parts or []: + merged_parts.append(part) + + # Use the first event as the "base" for common attributes + base_event = function_response_events[0] + + # Merge actions from all events + merged_actions_data: dict[str, Any] = {} + for event in function_response_events: + if event.actions: + # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate. + merged_actions_data = deep_merge_dicts( + merged_actions_data, + event.actions.model_dump(exclude_none=True, by_alias=True), + ) + + merged_actions = EventActions.model_validate(merged_actions_data) + + # Create the new merged event + merged_event = Event( + invocation_id=base_event.invocation_id, + author=base_event.author, + branch=base_event.branch, + content=types.Content(role="user", parts=merged_parts), + actions=merged_actions, # Optionally merge actions if required + ) + + # Use the base_event as the timestamp + merged_event.timestamp = base_event.timestamp + return merged_event def find_matching_function_call( events: list[Event], ) -> Optional[Event]: - """Finds the function call event that matches the function response id of the last event.""" - if not events: - return None + """Finds the function call event that matches the function response id of the last event.""" + if not events: + return None - last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): - - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event - return None + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index cd842cf494..1606ea11f4 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -19,6 +19,7 @@ import hashlib import json import logging +from opentelemetry.trace import Span import time from typing import Optional from typing import TYPE_CHECKING @@ -33,436 +34,461 @@ logger = logging.getLogger("google_adk." + __name__) if TYPE_CHECKING: - from google.genai import Client + from google.genai import Client @experimental class GeminiContextCacheManager: - """Manages context cache lifecycle for Gemini models. + """Manages context cache lifecycle for Gemini models. - This manager handles cache creation, validation, cleanup, and metadata - population for Gemini context caching. It uses content hashing to determine - cache compatibility and implements efficient caching strategies. - """ - - def __init__(self, genai_client: Client): - """Initialize cache manager with shared client. - - Args: - genai_client: The GenAI client to use for cache operations. - """ - self.genai_client = genai_client - - async def handle_context_caching( - self, llm_request: LlmRequest - ) -> Optional[CacheMetadata]: - """Handle context caching for Gemini models. - - Validates existing cache or creates a new one if needed. Applies - the cache to the request by setting cached_content and removing cached - contents from the request. - - Args: - llm_request: Request that may contain cache config and metadata. - Modified in-place to use the cache. - - Returns: - Cache metadata to be included in response, or None if caching failed + This manager handles cache creation, validation, cleanup, and metadata + population for Gemini context caching. It uses content hashing to determine + cache compatibility and implements efficient caching strategies. """ - # Check if we have existing cache metadata and if it's valid - if llm_request.cache_metadata: - logger.debug( - "Found existing cache metadata: %s", - llm_request.cache_metadata, - ) - if await self._is_cache_valid(llm_request): - # Valid cache found - use it - logger.debug( - "Cache is valid, reusing cache: %s", - llm_request.cache_metadata.cache_name, - ) - cache_name = llm_request.cache_metadata.cache_name - cache_contents_count = llm_request.cache_metadata.contents_count - self._apply_cache_to_request( - llm_request, cache_name, cache_contents_count - ) - return llm_request.cache_metadata.model_copy() - else: - # Invalid cache - clean it up and check if we should create new one - old_cache_metadata = llm_request.cache_metadata - - # Only cleanup if there's an active cache - if old_cache_metadata.cache_name is not None: - logger.debug( - "Cache is invalid, cleaning up: %s", - old_cache_metadata.cache_name, - ) - await self.cleanup_cache(old_cache_metadata.cache_name) - - # Calculate current fingerprint using contents count from old metadata - cache_contents_count = old_cache_metadata.contents_count - current_fingerprint = self._generate_cache_fingerprint( - llm_request, cache_contents_count - ) - # If fingerprints match, create new cache (expired but same content) - if current_fingerprint == old_cache_metadata.fingerprint: - logger.debug( - "Fingerprints match after invalidation, creating new cache" - ) - cache_metadata = await self._create_new_cache_with_contents( - llm_request, cache_contents_count - ) - if cache_metadata: - self._apply_cache_to_request( - llm_request, cache_metadata.cache_name, cache_contents_count + def __init__(self, genai_client: Client, disable_telemetry: bool = False): + """Initialize cache manager with shared client. + + Args: + genai_client: The GenAI client to use for cache operations. + disable_telemetry: A bool to flag whether or not telemetry should be disabled. + """ + self.genai_client = genai_client + self.disable_telemetry = disable_telemetry + + async def handle_context_caching( + self, llm_request: LlmRequest + ) -> Optional[CacheMetadata]: + """Handle context caching for Gemini models. + + Validates existing cache or creates a new one if needed. Applies + the cache to the request by setting cached_content and removing cached + contents from the request. + + Args: + llm_request: Request that may contain cache config and metadata. + Modified in-place to use the cache. + + Returns: + Cache metadata to be included in response, or None if caching failed + """ + # Check if we have existing cache metadata and if it's valid + if llm_request.cache_metadata: + logger.debug( + "Found existing cache metadata: %s", + llm_request.cache_metadata, ) - return cache_metadata - - # Fingerprints don't match - recalculate with total contents - logger.debug( - "Fingerprints don't match, returning fingerprint-only metadata" - ) + if await self._is_cache_valid(llm_request): + # Valid cache found - use it + logger.debug( + "Cache is valid, reusing cache: %s", + llm_request.cache_metadata.cache_name, + ) + cache_name = llm_request.cache_metadata.cache_name + cache_contents_count = llm_request.cache_metadata.contents_count + self._apply_cache_to_request( + llm_request, cache_name, cache_contents_count + ) + return llm_request.cache_metadata.model_copy() + else: + # Invalid cache - clean it up and check if we should create new one + old_cache_metadata = llm_request.cache_metadata + + # Only cleanup if there's an active cache + if old_cache_metadata.cache_name is not None: + logger.debug( + "Cache is invalid, cleaning up: %s", + old_cache_metadata.cache_name, + ) + await self.cleanup_cache(old_cache_metadata.cache_name) + + # Calculate current fingerprint using contents count from old metadata + cache_contents_count = old_cache_metadata.contents_count + current_fingerprint = self._generate_cache_fingerprint( + llm_request, cache_contents_count + ) + + # If fingerprints match, create new cache (expired but same content) + if current_fingerprint == old_cache_metadata.fingerprint: + logger.debug( + "Fingerprints match after invalidation, creating new cache" + ) + cache_metadata = await self._create_new_cache_with_contents( + llm_request, cache_contents_count + ) + if cache_metadata: + self._apply_cache_to_request( + llm_request, cache_metadata.cache_name, cache_contents_count + ) + return cache_metadata + + # Fingerprints don't match - recalculate with total contents + logger.debug( + "Fingerprints don't match, returning fingerprint-only metadata" + ) + total_contents_count = len(llm_request.contents) + fingerprint_for_all = self._generate_cache_fingerprint( + llm_request, total_contents_count + ) + return CacheMetadata( + fingerprint=fingerprint_for_all, + contents_count=total_contents_count, + ) + + # No existing cache metadata - return fingerprint-only metadata + # We don't create cache without previous fingerprint to match + logger.debug("No existing cache metadata, creating fingerprint-only metadata") total_contents_count = len(llm_request.contents) - fingerprint_for_all = self._generate_cache_fingerprint( + fingerprint = self._generate_cache_fingerprint( llm_request, total_contents_count ) return CacheMetadata( - fingerprint=fingerprint_for_all, + fingerprint=fingerprint, contents_count=total_contents_count, ) - # No existing cache metadata - return fingerprint-only metadata - # We don't create cache without previous fingerprint to match - logger.debug( - "No existing cache metadata, creating fingerprint-only metadata" - ) - total_contents_count = len(llm_request.contents) - fingerprint = self._generate_cache_fingerprint( - llm_request, total_contents_count - ) - return CacheMetadata( - fingerprint=fingerprint, - contents_count=total_contents_count, - ) - - def _find_count_of_contents_to_cache( - self, contents: list[types.Content] - ) -> int: - """Find the number of contents to cache based on user content strategy. - - Strategy: Find the last continuous batch of user contents and cache - all contents before them. - - Args: - contents: List of contents from the LLM request - - Returns: - Number of contents to cache (can be 0 if all contents are user contents) - """ - if not contents: - return 0 + def _find_count_of_contents_to_cache(self, contents: list[types.Content]) -> int: + """Find the number of contents to cache based on user content strategy. + + Strategy: Find the last continuous batch of user contents and cache + all contents before them. + + Args: + contents: List of contents from the LLM request + + Returns: + Number of contents to cache (can be 0 if all contents are user contents) + """ + if not contents: + return 0 + + # Find the last continuous batch of user contents + last_user_batch_start = len(contents) + + # Scan backwards to find the start of the last user content batch + for i in range(len(contents) - 1, -1, -1): + if contents[i].role == "user": + last_user_batch_start = i + else: + # Found non-user content, stop the batch + break + + # Cache all contents before the last user batch + # This ensures we always have some user content to send to the API + return last_user_batch_start + + async def _is_cache_valid(self, llm_request: LlmRequest) -> bool: + """Check if the cache from request metadata is still valid. + + Validates that it's an active cache (not fingerprint-only), checks expiry, + cache intervals, and fingerprint compatibility. + + Args: + llm_request: Request containing cache metadata to validate + + Returns: + True if cache is valid, False otherwise + """ + cache_metadata = llm_request.cache_metadata + if not cache_metadata: + return False + + # Fingerprint-only metadata is not a valid active cache + if cache_metadata.cache_name is None: + return False + + # Check if cache has expired + if time.time() >= cache_metadata.expire_time: + logger.info("Cache expired: %s", cache_metadata.cache_name) + return False + + # Check if cache has been used for too many invocations + if cache_metadata.invocations_used > llm_request.cache_config.cache_intervals: + logger.info( + "Cache exceeded cache intervals: %s (%d > %d intervals)", + cache_metadata.cache_name, + cache_metadata.invocations_used, + llm_request.cache_config.cache_intervals, + ) + return False - # Find the last continuous batch of user contents - last_user_batch_start = len(contents) + # Check if fingerprint matches using cached contents count + current_fingerprint = self._generate_cache_fingerprint( + llm_request, cache_metadata.contents_count + ) + if current_fingerprint != cache_metadata.fingerprint: + logger.debug("Cache content fingerprint mismatch") + return False - # Scan backwards to find the start of the last user content batch - for i in range(len(contents) - 1, -1, -1): - if contents[i].role == "user": - last_user_batch_start = i - else: - # Found non-user content, stop the batch - break + return True - # Cache all contents before the last user batch - # This ensures we always have some user content to send to the API - return last_user_batch_start + def _generate_cache_fingerprint( + self, llm_request: LlmRequest, cache_contents_count: int + ) -> str: + """Generate a fingerprint for cache validation. - async def _is_cache_valid(self, llm_request: LlmRequest) -> bool: - """Check if the cache from request metadata is still valid. + Includes system instruction, tools, tool_config, and first N contents. - Validates that it's an active cache (not fingerprint-only), checks expiry, - cache intervals, and fingerprint compatibility. + Args: + llm_request: Request to generate fingerprint for + cache_contents_count: Number of contents to include in fingerprint - Args: - llm_request: Request containing cache metadata to validate + Returns: + 16-character hexadecimal fingerprint representing the cached state + """ + # Create fingerprint from system instruction, tools, tool_config, and first N contents + fingerprint_data = {} + + if llm_request.config and llm_request.config.system_instruction: + fingerprint_data["system_instruction"] = ( + llm_request.config.system_instruction + ) + + if llm_request.config and llm_request.config.tools: + # Simplified: just dump types.Tool instances to JSON + tools_data = [] + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool): + tools_data.append(tool.model_dump()) + fingerprint_data["tools"] = tools_data + + if llm_request.config and llm_request.config.tool_config: + fingerprint_data["tool_config"] = ( + llm_request.config.tool_config.model_dump() + ) + + # Include first N contents in fingerprint + if cache_contents_count > 0 and llm_request.contents: + contents_data = [] + for i in range(min(cache_contents_count, len(llm_request.contents))): + content = llm_request.contents[i] + contents_data.append(content.model_dump()) + fingerprint_data["cached_contents"] = contents_data + + # Generate hash using str() instead of json.dumps() to handle bytes + fingerprint_str = str(fingerprint_data) + return hashlib.sha256(fingerprint_str.encode()).hexdigest()[:16] + + async def _create_new_cache_with_contents( + self, llm_request: LlmRequest, cache_contents_count: int + ) -> Optional[CacheMetadata]: + """Create a new cache with specified number of contents. + + Args: + llm_request: Request to create cache for + cache_contents_count: Number of contents to include in cache + + Returns: + Cache metadata if successful, None otherwise + """ + # Check if we have token count from previous response for cache size validation + if llm_request.cacheable_contents_token_count is None: + logger.info( + "No previous token count available, skipping cache creation for" + " initial request" + ) + return None + + if ( + llm_request.cacheable_contents_token_count + < llm_request.cache_config.min_tokens + ): + logger.info( + "Previous request too small for caching (%d < %d tokens)", + llm_request.cacheable_contents_token_count, + llm_request.cache_config.min_tokens, + ) + return None + + try: + # Create cache using Gemini API directly + return await self._create_gemini_cache(llm_request, cache_contents_count) + except Exception as e: + logger.warning("Failed to create cache: %s", e) + return None + + def _estimate_request_tokens(self, llm_request: LlmRequest) -> int: + """Estimate token count for the request. + + This is a rough estimation based on content text length. + + Args: + llm_request: Request to estimate tokens for + + Returns: + Estimated token count + """ + total_chars = 0 + + # System instruction + if llm_request.config and llm_request.config.system_instruction: + total_chars += len(llm_request.config.system_instruction) + + # Tools + if llm_request.config and llm_request.config.tools: + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool): + tool_str = json.dumps(tool.model_dump()) + total_chars += len(tool_str) + + # Contents + for content in llm_request.contents: + for part in content.parts: + if part.text: + total_chars += len(part.text) + + # Rough estimate: 4 characters per token + return total_chars // 4 + + async def _create_gemini_cache_with_optional_tracing( + self, llm_request: LlmRequest, cache_contents_count: int + ) -> CacheMetadata: + """Create cache using Gemini API. + + Args: + llm_request: Request to create cache for + cache_contents_count: Number of contents to cache + + Returns: + Cache metadata with precise creation timestamp + """ + + if not self.disable_telemetry: + from ..telemetry.tracing import tracer + + with tracer.start_as_current_span("create_cache") as span: + return await self._create_gemini_cache_body( + llm_request=llm_request, + cache_contents_count=cache_contents_count, + span=span, + ) + else: + return await self._create_gemini_cache_body( + llm_request=llm_request, cache_contents_count=cache_contents_count + ) + + async def _create_gemini_cache_body( + self, + llm_request: LlmRequest, + cache_contents_count: int, + span: Optional[Span] = None, + ) -> CacheMetadata: + """Create cache using Gemini API. + + Args: + llm_request: Request to create cache for + cache_contents_count: Number of contents to cache + + Returns: + Cache metadata with precise creation timestamp + """ + + # Prepare cache contents (first N contents + system instruction + tools) + cache_contents = llm_request.contents[:cache_contents_count] + + cache_config = types.CreateCachedContentConfig( + contents=cache_contents, + ttl=llm_request.cache_config.ttl_string, + display_name=( + f"adk-cache-{int(time.time())}-{cache_contents_count}contents" + ), + ) + + # Add system instruction if present + if llm_request.config and llm_request.config.system_instruction: + cache_config.system_instruction = llm_request.config.system_instruction + logger.debug( + "Added system instruction to cache config (length=%d)", + len(llm_request.config.system_instruction), + ) + + # Add tools if present + if llm_request.config and llm_request.config.tools: + cache_config.tools = llm_request.config.tools + + # Add tool config if present + if llm_request.config and llm_request.config.tool_config: + cache_config.tool_config = llm_request.config.tool_config + + if span is not None: + span.set_attribute("cache_contents_count", cache_contents_count) + span.set_attribute("model", llm_request.model) + span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) - Returns: - True if cache is valid, False otherwise - """ - cache_metadata = llm_request.cache_metadata - if not cache_metadata: - return False - - # Fingerprint-only metadata is not a valid active cache - if cache_metadata.cache_name is None: - return False - - # Check if cache has expired - if time.time() >= cache_metadata.expire_time: - logger.info("Cache expired: %s", cache_metadata.cache_name) - return False - - # Check if cache has been used for too many invocations - if ( - cache_metadata.invocations_used - > llm_request.cache_config.cache_intervals - ): - logger.info( - "Cache exceeded cache intervals: %s (%d > %d intervals)", - cache_metadata.cache_name, - cache_metadata.invocations_used, - llm_request.cache_config.cache_intervals, - ) - return False - - # Check if fingerprint matches using cached contents count - current_fingerprint = self._generate_cache_fingerprint( - llm_request, cache_metadata.contents_count - ) - if current_fingerprint != cache_metadata.fingerprint: - logger.debug("Cache content fingerprint mismatch") - return False - - return True - - def _generate_cache_fingerprint( - self, llm_request: LlmRequest, cache_contents_count: int - ) -> str: - """Generate a fingerprint for cache validation. - - Includes system instruction, tools, tool_config, and first N contents. - - Args: - llm_request: Request to generate fingerprint for - cache_contents_count: Number of contents to include in fingerprint - - Returns: - 16-character hexadecimal fingerprint representing the cached state - """ - # Create fingerprint from system instruction, tools, tool_config, and first N contents - fingerprint_data = {} - - if llm_request.config and llm_request.config.system_instruction: - fingerprint_data["system_instruction"] = ( - llm_request.config.system_instruction - ) - - if llm_request.config and llm_request.config.tools: - # Simplified: just dump types.Tool instances to JSON - tools_data = [] - for tool in llm_request.config.tools: - if isinstance(tool, types.Tool): - tools_data.append(tool.model_dump()) - fingerprint_data["tools"] = tools_data - - if llm_request.config and llm_request.config.tool_config: - fingerprint_data["tool_config"] = ( - llm_request.config.tool_config.model_dump() - ) - - # Include first N contents in fingerprint - if cache_contents_count > 0 and llm_request.contents: - contents_data = [] - for i in range(min(cache_contents_count, len(llm_request.contents))): - content = llm_request.contents[i] - contents_data.append(content.model_dump()) - fingerprint_data["cached_contents"] = contents_data - - # Generate hash using str() instead of json.dumps() to handle bytes - fingerprint_str = str(fingerprint_data) - return hashlib.sha256(fingerprint_str.encode()).hexdigest()[:16] - - async def _create_new_cache_with_contents( - self, llm_request: LlmRequest, cache_contents_count: int - ) -> Optional[CacheMetadata]: - """Create a new cache with specified number of contents. - - Args: - llm_request: Request to create cache for - cache_contents_count: Number of contents to include in cache - - Returns: - Cache metadata if successful, None otherwise - """ - # Check if we have token count from previous response for cache size validation - if llm_request.cacheable_contents_token_count is None: - logger.info( - "No previous token count available, skipping cache creation for" - " initial request" - ) - return None - - if ( - llm_request.cacheable_contents_token_count - < llm_request.cache_config.min_tokens - ): - logger.info( - "Previous request too small for caching (%d < %d tokens)", - llm_request.cacheable_contents_token_count, - llm_request.cache_config.min_tokens, - ) - return None - - try: - # Create cache using Gemini API directly - return await self._create_gemini_cache(llm_request, cache_contents_count) - except Exception as e: - logger.warning("Failed to create cache: %s", e) - return None - - def _estimate_request_tokens(self, llm_request: LlmRequest) -> int: - """Estimate token count for the request. - - This is a rough estimation based on content text length. - - Args: - llm_request: Request to estimate tokens for - - Returns: - Estimated token count - """ - total_chars = 0 - - # System instruction - if llm_request.config and llm_request.config.system_instruction: - total_chars += len(llm_request.config.system_instruction) - - # Tools - if llm_request.config and llm_request.config.tools: - for tool in llm_request.config.tools: - if isinstance(tool, types.Tool): - tool_str = json.dumps(tool.model_dump()) - total_chars += len(tool_str) - - # Contents - for content in llm_request.contents: - for part in content.parts: - if part.text: - total_chars += len(part.text) - - # Rough estimate: 4 characters per token - return total_chars // 4 - - async def _create_gemini_cache( - self, llm_request: LlmRequest, cache_contents_count: int - ) -> CacheMetadata: - """Create cache using Gemini API. - - Args: - llm_request: Request to create cache for - cache_contents_count: Number of contents to cache - - Returns: - Cache metadata with precise creation timestamp - """ - from ..telemetry.tracing import tracer - - with tracer.start_as_current_span("create_cache") as span: - # Prepare cache contents (first N contents + system instruction + tools) - cache_contents = llm_request.contents[:cache_contents_count] - - cache_config = types.CreateCachedContentConfig( - contents=cache_contents, - ttl=llm_request.cache_config.ttl_string, - display_name=( - f"adk-cache-{int(time.time())}-{cache_contents_count}contents" - ), - ) - - # Add system instruction if present - if llm_request.config and llm_request.config.system_instruction: - cache_config.system_instruction = llm_request.config.system_instruction logger.debug( - "Added system instruction to cache config (length=%d)", - len(llm_request.config.system_instruction), + "Creating cache with model %s and config: %s", + llm_request.model, + cache_config, ) + cached_content = await self.genai_client.aio.caches.create( + model=llm_request.model, + config=cache_config, + ) + # Set precise creation timestamp right after cache creation + created_at = time.time() + logger.info("Cache created successfully: %s", cached_content.name) - # Add tools if present - if llm_request.config and llm_request.config.tools: - cache_config.tools = llm_request.config.tools - - # Add tool config if present - if llm_request.config and llm_request.config.tool_config: - cache_config.tool_config = llm_request.config.tool_config - - span.set_attribute("cache_contents_count", cache_contents_count) - span.set_attribute("model", llm_request.model) - span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) - - logger.debug( - "Creating cache with model %s and config: %s", - llm_request.model, - cache_config, - ) - cached_content = await self.genai_client.aio.caches.create( - model=llm_request.model, - config=cache_config, - ) - # Set precise creation timestamp right after cache creation - created_at = time.time() - logger.info("Cache created successfully: %s", cached_content.name) - - span.set_attribute("cache_name", cached_content.name) - - # Return complete cache metadata with precise timing - return CacheMetadata( - cache_name=cached_content.name, - expire_time=created_at + llm_request.cache_config.ttl_seconds, - fingerprint=self._generate_cache_fingerprint( - llm_request, cache_contents_count - ), - invocations_used=1, - contents_count=cache_contents_count, - created_at=created_at, - ) - - async def cleanup_cache(self, cache_name: str) -> None: - """Clean up cache by deleting it. - - Args: - cache_name: Name of cache to delete - """ - logger.debug("Attempting to delete cache: %s", cache_name) - try: - await self.genai_client.aio.caches.delete(name=cache_name) - logger.info("Cache cleaned up: %s", cache_name) - except Exception as e: - logger.warning("Failed to cleanup cache %s: %s", cache_name, e) - - def _apply_cache_to_request( - self, - llm_request: LlmRequest, - cache_name: str, - cache_contents_count: int, - ) -> None: - """Apply cache to the request by modifying it to use cached content. - - Args: - llm_request: Request to modify - cache_name: Name of cache to use - cache_contents_count: Number of contents that are cached - """ - # Remove system instruction, tools, and tool config from request config since they're in cache - if llm_request.config: - llm_request.config.system_instruction = None - llm_request.config.tools = None - llm_request.config.tool_config = None - - # Set cached content reference - llm_request.config.cached_content = cache_name - - # Remove cached contents from the request (keep only uncached contents) - llm_request.contents = llm_request.contents[cache_contents_count:] - - def populate_cache_metadata_in_response( - self, llm_response: LlmResponse, cache_metadata: CacheMetadata - ) -> None: - """Populate cache metadata in LLM response. - - Args: - llm_response: Response to populate metadata in - cache_metadata: Cache metadata to copy into response - """ - # Create a copy of cache metadata for the response - llm_response.cache_metadata = cache_metadata.model_copy() + if span is not None: + span.set_attribute("cache_name", cached_content.name) + + # Return complete cache metadata with precise timing + return CacheMetadata( + cache_name=cached_content.name, + expire_time=created_at + llm_request.cache_config.ttl_seconds, + fingerprint=self._generate_cache_fingerprint( + llm_request, cache_contents_count + ), + invocations_used=1, + contents_count=cache_contents_count, + created_at=created_at, + ) + + async def cleanup_cache(self, cache_name: str) -> None: + """Clean up cache by deleting it. + + Args: + cache_name: Name of cache to delete + """ + logger.debug("Attempting to delete cache: %s", cache_name) + try: + await self.genai_client.aio.caches.delete(name=cache_name) + logger.info("Cache cleaned up: %s", cache_name) + except Exception as e: + logger.warning("Failed to cleanup cache %s: %s", cache_name, e) + + def _apply_cache_to_request( + self, + llm_request: LlmRequest, + cache_name: str, + cache_contents_count: int, + ) -> None: + """Apply cache to the request by modifying it to use cached content. + + Args: + llm_request: Request to modify + cache_name: Name of cache to use + cache_contents_count: Number of contents that are cached + """ + # Remove system instruction, tools, and tool config from request config since they're in cache + if llm_request.config: + llm_request.config.system_instruction = None + llm_request.config.tools = None + llm_request.config.tool_config = None + + # Set cached content reference + llm_request.config.cached_content = cache_name + + # Remove cached contents from the request (keep only uncached contents) + llm_request.contents = llm_request.contents[cache_contents_count:] + + def populate_cache_metadata_in_response( + self, llm_response: LlmResponse, cache_metadata: CacheMetadata + ) -> None: + """Populate cache metadata in LLM response. + + Args: + llm_response: Response to populate metadata in + cache_metadata: Cache metadata to copy into response + """ + # Create a copy of cache metadata for the response + llm_response.cache_metadata = cache_metadata.model_copy() diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 9261fada39..c0f5725e0b 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -40,14 +40,14 @@ from .llm_response import LlmResponse if TYPE_CHECKING: - from google.genai import Client + from google.genai import Client - from .llm_request import LlmRequest + from .llm_request import LlmRequest -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) -_NEW_LINE = '\n' -_EXCLUDED_PART_FIELD = {'inline_data': {'data'}} +_NEW_LINE = "\n" +_EXCLUDED_PART_FIELD = {"inline_data": {"data"}} _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ @@ -58,489 +58,499 @@ class _ResourceExhaustedError(ClientError): - """Represents an resources exhausted error received from the Model.""" - - def __init__( - self, - client_error: ClientError, - ): - super().__init__( - code=client_error.code, - response_json=client_error.details, - response=client_error.response, - ) - - def __str__(self): - # We don't get override the actual message on ClientError, so we override - # this method instead. This will ensure that when the exception is - # stringified (for either publishing the exception on console or to logs) - # we put in the required details for the developer. - base_message = super().__str__() - return f'{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}' - - -class Gemini(BaseLlm): - """Integration for Gemini models. + """Represents an resources exhausted error received from the Model.""" - Attributes: - model: The name of the Gemini model. - use_interactions_api: Whether to use the interactions API for model - invocation. - """ + def __init__( + self, + client_error: ClientError, + ): + super().__init__( + code=client_error.code, + response_json=client_error.details, + response=client_error.response, + ) - model: str = 'gemini-2.5-flash' + def __str__(self): + # We don't get override the actual message on ClientError, so we override + # this method instead. This will ensure that when the exception is + # stringified (for either publishing the exception on console or to logs) + # we put in the required details for the developer. + base_message = super().__str__() + return f"{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}" - speech_config: Optional[types.SpeechConfig] = None - use_interactions_api: bool = False - """Whether to use the interactions API for model invocation. +class Gemini(BaseLlm): + """Integration for Gemini models. - When enabled, uses the interactions API (client.aio.interactions.create()) - instead of the traditional generate_content API. The interactions API - provides stateful conversation capabilities, allowing you to chain - interactions using previous_interaction_id instead of sending full history. - The response format will be converted to match the existing LlmResponse - structure for compatibility. + Attributes: + model: The name of the Gemini model. + use_interactions_api: Whether to use the interactions API for model + invocation. + """ - Sample: - ```python - agent = Agent( - model=Gemini(use_interactions_api=True) - ) - ``` - """ + model: str = "gemini-2.5-flash" - retry_options: Optional[types.HttpRetryOptions] = None - """Allow Gemini to retry failed responses. + speech_config: Optional[types.SpeechConfig] = None - Sample: - ```python - from google.genai import types + use_interactions_api: bool = False + """Whether to use the interactions API for model invocation. - # ... + When enabled, uses the interactions API (client.aio.interactions.create()) + instead of the traditional generate_content API. The interactions API + provides stateful conversation capabilities, allowing you to chain + interactions using previous_interaction_id instead of sending full history. + The response format will be converted to match the existing LlmResponse + structure for compatibility. - agent = Agent( - model=Gemini( - retry_options=types.HttpRetryOptions(initial_delay=1, attempts=2), + Sample: + ```python + agent = Agent( + model=Gemini(use_interactions_api=True) ) - ) - ``` - """ - - @classmethod - @override - def supported_models(cls) -> list[str]: - """Provides the list of supported models. - - Returns: - A list of supported models. + ``` """ - return [ - r'gemini-.*', - # model optimizer pattern - r'model-optimizer-.*', - # fine-tuned vertex endpoint pattern - r'projects\/.+\/locations\/.+\/endpoints\/.+', - # vertex gemini long name - r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+', - ] + retry_options: Optional[types.HttpRetryOptions] = None + """Allow Gemini to retry failed responses. - async def generate_content_async( - self, llm_request: LlmRequest, stream: bool = False - ) -> AsyncGenerator[LlmResponse, None]: - """Sends a request to the Gemini model. + Sample: + ```python + from google.genai import types - Args: - llm_request: LlmRequest, the request to send to the Gemini model. - stream: bool = False, whether to do streaming call. + # ... - Yields: - LlmResponse: The model response. - """ - await self._preprocess_request(llm_request) - self._maybe_append_user_content(llm_request) - - # Handle context caching if configured - cache_metadata = None - cache_manager = None - if llm_request.cache_config: - from ..telemetry.tracing import tracer - from .gemini_context_cache_manager import GeminiContextCacheManager - - with tracer.start_as_current_span('handle_context_caching') as span: - cache_manager = GeminiContextCacheManager(self.api_client) - cache_metadata = await cache_manager.handle_context_caching(llm_request) - if cache_metadata: - if cache_metadata.cache_name: - span.set_attribute('cache_action', 'active_cache') - span.set_attribute('cache_name', cache_metadata.cache_name) - else: - span.set_attribute('cache_action', 'fingerprint_only') - - logger.info( - 'Sending out request, model: %s, backend: %s, stream: %s', - llm_request.model, - self._api_backend, - stream, - ) - - # Always add tracking headers to custom headers given it will override - # the headers set in the api client constructor to avoid tracking headers - # being dropped if user provides custom headers or overrides the api client. - if llm_request.config: - if not llm_request.config.http_options: - llm_request.config.http_options = types.HttpOptions() - llm_request.config.http_options.headers = self._merge_tracking_headers( - llm_request.config.http_options.headers + agent = Agent( + model=Gemini( + retry_options=types.HttpRetryOptions(initial_delay=1, attempts=2), ) + ) + ``` + """ - try: - # Use interactions API if enabled - if self.use_interactions_api: - async for llm_response in self._generate_content_via_interactions( - llm_request, stream - ): - yield llm_response - return - - logger.debug(_build_request_log(llm_request)) + disable_telemetry: bool = True + """A bool to flag whether or not telemetry should be being disabled for Gemini LLM interactions. + """ - if stream: - responses = await self.api_client.aio.models.generate_content_stream( - model=llm_request.model, - contents=llm_request.contents, - config=llm_request.config, + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + + return [ + r"gemini-.*", + # model optimizer pattern + r"model-optimizer-.*", + # fine-tuned vertex endpoint pattern + r"projects\/.+\/locations\/.+\/endpoints\/.+", + # vertex gemini long name + r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+", + ] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Sends a request to the Gemini model. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + stream: bool = False, whether to do streaming call. + + Yields: + LlmResponse: The model response. + """ + await self._preprocess_request(llm_request) + self._maybe_append_user_content(llm_request) + + # Handle context caching if configured + cache_metadata = None + cache_manager = None + if llm_request.cache_config: + from .gemini_context_cache_manager import GeminiContextCacheManager + + if not self.disable_telemetry: + from ..telemetry.tracing import tracer + + with tracer.start_as_current_span("handle_context_caching") as span: + cache_manager = GeminiContextCacheManager( + self.api_client, disable_telemetry=self.disable_telemetry + ) + cache_metadata = await cache_manager.handle_context_caching( + llm_request + ) + if cache_metadata: + if cache_metadata.cache_name: + span.set_attribute("cache_action", "active_cache") + span.set_attribute("cache_name", cache_metadata.cache_name) + else: + span.set_attribute("cache_action", "fingerprint_only") + else: + cache_manager = GeminiContextCacheManager( + self.api_client, disable_telemetry=self.disable_telemetry + ) + cache_metadata = await cache_manager.handle_context_caching(llm_request) + + logger.info( + "Sending out request, model: %s, backend: %s, stream: %s", + llm_request.model, + self._api_backend, + stream, ) - # for sse, similar as bidi (see receive method in - # gemini_llm_connection.py), we need to mark those text content as - # partial and after all partial contents are sent, we send an - # accumulated event which contains all the previous partial content. The - # only difference is bidi rely on complete_turn flag to detect end while - # sse depends on finish_reason. - aggregator = StreamingResponseAggregator() - async with Aclosing(responses) as agen: - async for response in agen: - logger.debug(_build_response_log(response)) - async with Aclosing( - aggregator.process_response(response) - ) as aggregator_gen: - async for llm_response in aggregator_gen: - yield llm_response - if (close_result := aggregator.close()) is not None: - # Populate cache metadata in the final aggregated response for - # streaming - if cache_metadata: - cache_manager.populate_cache_metadata_in_response( - close_result, cache_metadata + # Always add tracking headers to custom headers given it will override + # the headers set in the api client constructor to avoid tracking headers + # being dropped if user provides custom headers or overrides the api client. + if llm_request.config: + if not llm_request.config.http_options: + llm_request.config.http_options = types.HttpOptions() + llm_request.config.http_options.headers = self._merge_tracking_headers( + llm_request.config.http_options.headers ) - yield close_result - else: - response = await self.api_client.aio.models.generate_content( - model=llm_request.model, - contents=llm_request.contents, - config=llm_request.config, - ) - logger.info('Response received from the model.') - logger.debug(_build_response_log(response)) - - llm_response = LlmResponse.create(response) - if cache_metadata: - cache_manager.populate_cache_metadata_in_response( - llm_response, cache_metadata - ) - yield llm_response - except ClientError as ce: - if ce.code == 429: - # We expect running into a Resource Exhausted error to be a common - # client error that developers would run into. We enhance the messaging - # with possible fixes to this issue. - raise _ResourceExhaustedError(ce) from ce - - raise ce - - async def _generate_content_via_interactions( - self, - llm_request: LlmRequest, - stream: bool, - ) -> AsyncGenerator[LlmResponse, None]: - """Generate content using the interactions API. - - The interactions API provides stateful conversation capabilities. When - previous_interaction_id is set in the request, the API chains interactions - instead of requiring full conversation history. - - Note: Context caching is not used with the Interactions API since it - maintains conversation state via previous_interaction_id. - - Args: - llm_request: The LLM request to send. - stream: Whether to stream the response. - - Yields: - LlmResponse objects converted from interaction responses. - """ - from .interactions_utils import generate_content_via_interactions + try: + # Use interactions API if enabled + if self.use_interactions_api: + async for llm_response in self._generate_content_via_interactions( + llm_request, stream + ): + yield llm_response + return + + logger.debug(_build_request_log(llm_request)) + + if stream: + responses = await self.api_client.aio.models.generate_content_stream( + model=llm_request.model, + contents=llm_request.contents, + config=llm_request.config, + ) + + # for sse, similar as bidi (see receive method in + # gemini_llm_connection.py), we need to mark those text content as + # partial and after all partial contents are sent, we send an + # accumulated event which contains all the previous partial content. The + # only difference is bidi rely on complete_turn flag to detect end while + # sse depends on finish_reason. + aggregator = StreamingResponseAggregator() + async with Aclosing(responses) as agen: + async for response in agen: + logger.debug(_build_response_log(response)) + async with Aclosing( + aggregator.process_response(response) + ) as aggregator_gen: + async for llm_response in aggregator_gen: + yield llm_response + if (close_result := aggregator.close()) is not None: + # Populate cache metadata in the final aggregated response for + # streaming + if cache_metadata: + cache_manager.populate_cache_metadata_in_response( + close_result, cache_metadata + ) + yield close_result + + else: + response = await self.api_client.aio.models.generate_content( + model=llm_request.model, + contents=llm_request.contents, + config=llm_request.config, + ) + logger.info("Response received from the model.") + logger.debug(_build_response_log(response)) + + llm_response = LlmResponse.create(response) + if cache_metadata: + cache_manager.populate_cache_metadata_in_response( + llm_response, cache_metadata + ) + yield llm_response + except ClientError as ce: + if ce.code == 429: + # We expect running into a Resource Exhausted error to be a common + # client error that developers would run into. We enhance the messaging + # with possible fixes to this issue. + raise _ResourceExhaustedError(ce) from ce + + raise ce + + async def _generate_content_via_interactions( + self, + llm_request: LlmRequest, + stream: bool, + ) -> AsyncGenerator[LlmResponse, None]: + """Generate content using the interactions API. + + The interactions API provides stateful conversation capabilities. When + previous_interaction_id is set in the request, the API chains interactions + instead of requiring full conversation history. + + Note: Context caching is not used with the Interactions API since it + maintains conversation state via previous_interaction_id. + + Args: + llm_request: The LLM request to send. + stream: Whether to stream the response. + + Yields: + LlmResponse objects converted from interaction responses. + """ + from .interactions_utils import generate_content_via_interactions + + async for llm_response in generate_content_via_interactions( + api_client=self.api_client, + llm_request=llm_request, + stream=stream, + ): + yield llm_response - async for llm_response in generate_content_via_interactions( - api_client=self.api_client, - llm_request=llm_request, - stream=stream, - ): - yield llm_response + @cached_property + def api_client(self) -> Client: + """Provides the api client. - @cached_property - def api_client(self) -> Client: - """Provides the api client. + Returns: + The api client. + """ + from google.genai import Client - Returns: - The api client. - """ - from google.genai import Client - - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers(), - retry_options=self.retry_options, + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers(), + retry_options=self.retry_options, + ) ) - ) - @cached_property - def _api_backend(self) -> GoogleLLMVariant: - return ( - GoogleLLMVariant.VERTEX_AI - if self.api_client.vertexai - else GoogleLLMVariant.GEMINI_API - ) - - def _tracking_headers(self) -> dict[str, str]: - labels = get_client_labels() - header_value = ' '.join(labels) - tracking_headers = { - 'x-goog-api-client': header_value, - 'user-agent': header_value, - } - return tracking_headers - - @cached_property - def _live_api_version(self) -> str: - if self._api_backend == GoogleLLMVariant.VERTEX_AI: - # use beta version for vertex api - return 'v1beta1' - else: - # use v1alpha for using API KEY from Google AI Studio - return 'v1alpha' - - @cached_property - def _live_api_client(self) -> Client: - from google.genai import Client - - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers(), api_version=self._live_api_version + @cached_property + def _api_backend(self) -> GoogleLLMVariant: + return ( + GoogleLLMVariant.VERTEX_AI + if self.api_client.vertexai + else GoogleLLMVariant.GEMINI_API ) - ) - @contextlib.asynccontextmanager - async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: - """Connects to the Gemini model and returns an llm connection. - - Args: - llm_request: LlmRequest, the request to send to the Gemini model. - - Yields: - BaseLlmConnection, the connection to the Gemini model. - """ - # add tracking headers to custom headers and set api_version given - # the customized http options will override the one set in the api client - # constructor - if ( - llm_request.live_connect_config - and llm_request.live_connect_config.http_options - ): - if not llm_request.live_connect_config.http_options.headers: - llm_request.live_connect_config.http_options.headers = {} - llm_request.live_connect_config.http_options.headers.update( - self._tracking_headers() - ) - llm_request.live_connect_config.http_options.api_version = ( - self._live_api_version - ) - - if self.speech_config is not None: - llm_request.live_connect_config.speech_config = self.speech_config - - llm_request.live_connect_config.system_instruction = types.Content( - role='system', - parts=[ - types.Part.from_text(text=llm_request.config.system_instruction) - ], - ) - if ( - llm_request.live_connect_config.session_resumption - and llm_request.live_connect_config.session_resumption.transparent - ): - logger.debug( - 'session resumption config: %s', - llm_request.live_connect_config.session_resumption, - ) - logger.debug( - 'self._api_backend: %s', - self._api_backend, - ) - if self._api_backend == GoogleLLMVariant.GEMINI_API: - raise ValueError( - 'Transparent session resumption is only supported for Vertex AI' - ' backend. Please use Vertex AI backend.' + def _tracking_headers(self) -> dict[str, str]: + labels = get_client_labels() + header_value = " ".join(labels) + tracking_headers = { + "x-goog-api-client": header_value, + "user-agent": header_value, + } + return tracking_headers + + @cached_property + def _live_api_version(self) -> str: + if self._api_backend == GoogleLLMVariant.VERTEX_AI: + # use beta version for vertex api + return "v1beta1" + else: + # use v1alpha for using API KEY from Google AI Studio + return "v1alpha" + + @cached_property + def _live_api_client(self) -> Client: + from google.genai import Client + + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers(), api_version=self._live_api_version + ) ) - llm_request.live_connect_config.tools = llm_request.config.tools - logger.info('Connecting to live for model: %s', llm_request.model) - logger.debug('Connecting to live with llm_request:%s', llm_request) - logger.debug('Live connect config: %s', llm_request.live_connect_config) - async with self._live_api_client.aio.live.connect( - model=llm_request.model, config=llm_request.live_connect_config - ) as live_session: - yield GeminiLlmConnection(live_session, api_backend=self._api_backend) - - async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: - """Adapt the google computer use predefined functions to the adk computer use toolset.""" - - from ..tools.computer_use.computer_use_toolset import ComputerUseToolset - async def convert_wait_to_wait_5_seconds(wait_func): - async def wait_5_seconds(): - return await wait_func(5) + @contextlib.asynccontextmanager + async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: + """Connects to the Gemini model and returns an llm connection. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + + Yields: + BaseLlmConnection, the connection to the Gemini model. + """ + # add tracking headers to custom headers and set api_version given + # the customized http options will override the one set in the api client + # constructor + if ( + llm_request.live_connect_config + and llm_request.live_connect_config.http_options + ): + if not llm_request.live_connect_config.http_options.headers: + llm_request.live_connect_config.http_options.headers = {} + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers() + ) + llm_request.live_connect_config.http_options.api_version = ( + self._live_api_version + ) - return wait_5_seconds + if self.speech_config is not None: + llm_request.live_connect_config.speech_config = self.speech_config - await ComputerUseToolset.adapt_computer_use_tool( - 'wait', convert_wait_to_wait_5_seconds, llm_request - ) + llm_request.live_connect_config.system_instruction = types.Content( + role="system", + parts=[types.Part.from_text(text=llm_request.config.system_instruction)], + ) + if ( + llm_request.live_connect_config.session_resumption + and llm_request.live_connect_config.session_resumption.transparent + ): + logger.debug( + "session resumption config: %s", + llm_request.live_connect_config.session_resumption, + ) + logger.debug( + "self._api_backend: %s", + self._api_backend, + ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: + raise ValueError( + "Transparent session resumption is only supported for Vertex AI" + " backend. Please use Vertex AI backend." + ) + llm_request.live_connect_config.tools = llm_request.config.tools + logger.info("Connecting to live for model: %s", llm_request.model) + logger.debug("Connecting to live with llm_request:%s", llm_request) + logger.debug("Live connect config: %s", llm_request.live_connect_config) + async with self._live_api_client.aio.live.connect( + model=llm_request.model, config=llm_request.live_connect_config + ) as live_session: + yield GeminiLlmConnection(live_session, api_backend=self._api_backend) + + async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: + """Adapt the google computer use predefined functions to the adk computer use toolset.""" + + from ..tools.computer_use.computer_use_toolset import ComputerUseToolset + + async def convert_wait_to_wait_5_seconds(wait_func): + async def wait_5_seconds(): + return await wait_func(5) + + return wait_5_seconds + + await ComputerUseToolset.adapt_computer_use_tool( + "wait", convert_wait_to_wait_5_seconds, llm_request + ) - async def _preprocess_request(self, llm_request: LlmRequest) -> None: - - if self._api_backend == GoogleLLMVariant.GEMINI_API: - # Using API key from Google AI Studio to call model doesn't support labels. - if llm_request.config: - llm_request.config.labels = None - - if llm_request.contents: - for content in llm_request.contents: - if not content.parts: - continue - for part in content.parts: - # Create copies to avoid mutating the original objects - if part.inline_data: - part.inline_data = copy.copy(part.inline_data) - _remove_display_name_if_present(part.inline_data) - if part.file_data: - part.file_data = copy.copy(part.file_data) - _remove_display_name_if_present(part.file_data) - - # Initialize config if needed - if llm_request.config and llm_request.config.tools: - # Check if computer use is configured - for tool in llm_request.config.tools: - if isinstance(tool, types.Tool) and tool.computer_use: - llm_request.config.system_instruction = None - await self._adapt_computer_use_tool(llm_request) - - def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: - """Merge tracking headers to the given headers.""" - headers = headers or {} - for key, tracking_header_value in self._tracking_headers().items(): - custom_value = headers.get(key, None) - if not custom_value: - headers[key] = tracking_header_value - continue - - # Merge tracking headers with existing headers and avoid duplicates. - value_parts = tracking_header_value.split(' ') - for custom_value_part in custom_value.split(' '): - if custom_value_part not in value_parts: - value_parts.append(custom_value_part) - headers[key] = ' '.join(value_parts) - return headers + async def _preprocess_request(self, llm_request: LlmRequest) -> None: + + if self._api_backend == GoogleLLMVariant.GEMINI_API: + # Using API key from Google AI Studio to call model doesn't support labels. + if llm_request.config: + llm_request.config.labels = None + + if llm_request.contents: + for content in llm_request.contents: + if not content.parts: + continue + for part in content.parts: + # Create copies to avoid mutating the original objects + if part.inline_data: + part.inline_data = copy.copy(part.inline_data) + _remove_display_name_if_present(part.inline_data) + if part.file_data: + part.file_data = copy.copy(part.file_data) + _remove_display_name_if_present(part.file_data) + + # Initialize config if needed + if llm_request.config and llm_request.config.tools: + # Check if computer use is configured + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool) and tool.computer_use: + llm_request.config.system_instruction = None + await self._adapt_computer_use_tool(llm_request) + + def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: + """Merge tracking headers to the given headers.""" + headers = headers or {} + for key, tracking_header_value in self._tracking_headers().items(): + custom_value = headers.get(key, None) + if not custom_value: + headers[key] = tracking_header_value + continue + + # Merge tracking headers with existing headers and avoid duplicates. + value_parts = tracking_header_value.split(" ") + for custom_value_part in custom_value.split(" "): + if custom_value_part not in value_parts: + value_parts.append(custom_value_part) + headers[key] = " ".join(value_parts) + return headers def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: - param_str = '{}' - if func_decl.parameters and func_decl.parameters.properties: - param_str = str({ - k: v.model_dump(exclude_none=True) - for k, v in func_decl.parameters.properties.items() - }) - elif func_decl.parameters_json_schema: - param_str = str(func_decl.parameters_json_schema) + param_str = "{}" + if func_decl.parameters and func_decl.parameters.properties: + param_str = str( + { + k: v.model_dump(exclude_none=True) + for k, v in func_decl.parameters.properties.items() + } + ) + elif func_decl.parameters_json_schema: + param_str = str(func_decl.parameters_json_schema) - return_str = '' - if func_decl.response: - return_str = '-> ' + str(func_decl.response.model_dump(exclude_none=True)) - elif func_decl.response_json_schema: - return_str = '-> ' + str(func_decl.response_json_schema) + return_str = "" + if func_decl.response: + return_str = "-> " + str(func_decl.response.model_dump(exclude_none=True)) + elif func_decl.response_json_schema: + return_str = "-> " + str(func_decl.response_json_schema) - return f'{func_decl.name}: {param_str} {return_str}' + return f"{func_decl.name}: {param_str} {return_str}" def _build_request_log(req: LlmRequest) -> str: - # Find which tool contains function_declarations - function_decls: list[types.FunctionDeclaration] = [] - function_decl_tool_index: Optional[int] = None - - if req.config.tools: - for idx, tool in enumerate(req.config.tools): - if tool.function_declarations: - function_decls = cast( - list[types.FunctionDeclaration], tool.function_declarations - ) - function_decl_tool_index = idx - break - - function_logs = ( - [ - _build_function_declaration_log(func_decl) - for func_decl in function_decls - ] - if function_decls - else [] - ) - contents_logs = [ - content.model_dump_json( - exclude_none=True, - exclude={ - 'parts': { - i: _EXCLUDED_PART_FIELD for i in range(len(content.parts)) - } - }, - ) - for content in req.contents - ] - - # Build exclusion dict for config logging - tools_exclusion = ( - {function_decl_tool_index: {'function_declarations'}} - if function_decl_tool_index is not None - else True - ) - - try: - config_log = str( - req.config.model_dump( + # Find which tool contains function_declarations + function_decls: list[types.FunctionDeclaration] = [] + function_decl_tool_index: Optional[int] = None + + if req.config.tools: + for idx, tool in enumerate(req.config.tools): + if tool.function_declarations: + function_decls = cast( + list[types.FunctionDeclaration], tool.function_declarations + ) + function_decl_tool_index = idx + break + + function_logs = ( + [_build_function_declaration_log(func_decl) for func_decl in function_decls] + if function_decls + else [] + ) + contents_logs = [ + content.model_dump_json( exclude_none=True, exclude={ - 'system_instruction': True, - 'tools': tools_exclusion if req.config.tools else True, + "parts": {i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))} }, ) + for content in req.contents + ] + + # Build exclusion dict for config logging + tools_exclusion = ( + {function_decl_tool_index: {"function_declarations"}} + if function_decl_tool_index is not None + else True ) - except Exception: - config_log = repr(req.config) - return f""" + try: + config_log = str( + req.config.model_dump( + exclude_none=True, + exclude={ + "system_instruction": True, + "tools": tools_exclusion if req.config.tools else True, + }, + ) + ) + except Exception: + config_log = repr(req.config) + + return f""" LLM Request: ----------------------------------------------------------- System Instruction: @@ -559,13 +569,13 @@ def _build_request_log(req: LlmRequest) -> str: def _build_response_log(resp: types.GenerateContentResponse) -> str: - function_calls_text = [] - if function_calls := resp.function_calls: - for func_call in function_calls: - function_calls_text.append( - f'name: {func_call.name}, args: {func_call.args}' - ) - return f""" + function_calls_text = [] + if function_calls := resp.function_calls: + for func_call in function_calls: + function_calls_text.append( + f"name: {func_call.name}, args: {func_call.args}" + ) + return f""" LLM Response: ----------------------------------------------------------- Text: @@ -583,10 +593,10 @@ def _build_response_log(resp: types.GenerateContentResponse) -> str: def _remove_display_name_if_present( data_obj: Union[types.Blob, types.FileData, None], ): - """Sets display_name to None for the Gemini API (non-Vertex) backend. + """Sets display_name to None for the Gemini API (non-Vertex) backend. - This backend does not support the display_name parameter for file uploads, - so it must be removed to prevent request failures. - """ - if data_obj and data_obj.display_name: - data_obj.display_name = None + This backend does not support the display_name parameter for file uploads, + so it must be removed to prevent request failures. + """ + if data_obj and data_obj.display_name: + data_obj.display_name = None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1773729719..2ad9594ebd 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -63,1467 +63,1481 @@ from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing +from .utils.telemetry_utils import is_telemetry_enabled -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) def _is_tool_call_or_response(event: Event) -> bool: - return bool(event.get_function_calls() or event.get_function_responses()) + return bool(event.get_function_calls() or event.get_function_responses()) def _is_transcription(event: Event) -> bool: - return ( - event.input_transcription is not None - or event.output_transcription is not None - ) + return ( + event.input_transcription is not None or event.output_transcription is not None + ) def _has_non_empty_transcription_text(transcription) -> bool: - return bool( - transcription and transcription.text and transcription.text.strip() - ) + return bool(transcription and transcription.text and transcription.text.strip()) class Runner: - """The Runner class is used to run agents. - - It manages the execution of an agent within a session, handling message - processing, event generation, and interaction with various services like - artifact storage, session management, and memory. - - Attributes: - app_name: The application name of the runner. - agent: The root agent to run. - artifact_service: The artifact service for the runner. - plugin_manager: The plugin manager for the runner. - session_service: The session service for the runner. - memory_service: The memory service for the runner. - credential_service: The credential service for the runner. - context_cache_config: The context cache config for the runner. - resumability_config: The resumability config for the application. - """ - - app_name: str - """The app name of the runner.""" - agent: BaseAgent - """The root agent to run.""" - artifact_service: Optional[BaseArtifactService] = None - """The artifact service for the runner.""" - plugin_manager: PluginManager - """The plugin manager for the runner.""" - session_service: BaseSessionService - """The session service for the runner.""" - memory_service: Optional[BaseMemoryService] = None - """The memory service for the runner.""" - credential_service: Optional[BaseCredentialService] = None - """The credential service for the runner.""" - context_cache_config: Optional[ContextCacheConfig] = None - """The context cache config for the runner.""" - resumability_config: Optional[ResumabilityConfig] = None - """The resumability config for the application.""" - - def __init__( - self, - *, - app: Optional[App] = None, - app_name: Optional[str] = None, - agent: Optional[BaseAgent] = None, - plugins: Optional[List[BasePlugin]] = None, - artifact_service: Optional[BaseArtifactService] = None, - session_service: BaseSessionService, - memory_service: Optional[BaseMemoryService] = None, - credential_service: Optional[BaseCredentialService] = None, - plugin_close_timeout: float = 5.0, - ): - """Initializes the Runner. - - Developers should provide either an `app` instance or both `app_name` and - `agent`. Providing a mix of `app` and `app_name`/`agent` will result in a - `ValueError`. Providing `app` is the recommended way to create a runner. - - Args: - app: An optional `App` instance. If provided, `app_name` and `agent` - should not be specified. - app_name: The application name of the runner. Required if `app` is not - provided. - agent: The root agent to run. Required if `app` is not provided. - plugins: Deprecated. A list of plugins for the runner. Please use the - `app` argument to provide plugins instead. + """The Runner class is used to run agents. + + It manages the execution of an agent within a session, handling message + processing, event generation, and interaction with various services like + artifact storage, session management, and memory. + + Attributes: + app_name: The application name of the runner. + agent: The root agent to run. artifact_service: The artifact service for the runner. + plugin_manager: The plugin manager for the runner. session_service: The session service for the runner. memory_service: The memory service for the runner. credential_service: The credential service for the runner. - plugin_close_timeout: The timeout in seconds for plugin close methods. - - Raises: - ValueError: If `app` is provided along with `app_name` or `plugins`, or - if `app` is not provided but either `app_name` or `agent` is missing. + context_cache_config: The context cache config for the runner. + resumability_config: The resumability config for the application. """ - self.app = app - ( - self.app_name, - self.agent, - self.context_cache_config, - self.resumability_config, - plugins, - ) = self._validate_runner_params(app, app_name, agent, plugins) - self.artifact_service = artifact_service - self.session_service = session_service - self.memory_service = memory_service - self.credential_service = credential_service - self.plugin_manager = PluginManager( - plugins=plugins, close_timeout=plugin_close_timeout - ) - ( - self._agent_origin_app_name, - self._agent_origin_dir, - ) = self._infer_agent_origin(self.agent) - self._app_name_alignment_hint: Optional[str] = None - self._enforce_app_name_alignment() - - def _validate_runner_params( - self, - app: Optional[App], - app_name: Optional[str], - agent: Optional[BaseAgent], - plugins: Optional[List[BasePlugin]], - ) -> tuple[ - str, - BaseAgent, - Optional[ContextCacheConfig], - Optional[ResumabilityConfig], - Optional[List[BasePlugin]], - ]: - """Validates and extracts runner parameters. - - Args: - app: An optional `App` instance. - app_name: The application name of the runner. - agent: The root agent to run. - plugins: A list of plugins for the runner. - Returns: - A tuple containing (app_name, agent, context_cache_config, - resumability_config, plugins). + app_name: str + """The app name of the runner.""" + agent: BaseAgent + """The root agent to run.""" + artifact_service: Optional[BaseArtifactService] = None + """The artifact service for the runner.""" + plugin_manager: PluginManager + """The plugin manager for the runner.""" + session_service: BaseSessionService + """The session service for the runner.""" + memory_service: Optional[BaseMemoryService] = None + """The memory service for the runner.""" + credential_service: Optional[BaseCredentialService] = None + """The credential service for the runner.""" + context_cache_config: Optional[ContextCacheConfig] = None + """The context cache config for the runner.""" + resumability_config: Optional[ResumabilityConfig] = None + """The resumability config for the application.""" + + def __init__( + self, + *, + app: Optional[App] = None, + app_name: Optional[str] = None, + agent: Optional[BaseAgent] = None, + plugins: Optional[List[BasePlugin]] = None, + artifact_service: Optional[BaseArtifactService] = None, + session_service: BaseSessionService, + memory_service: Optional[BaseMemoryService] = None, + credential_service: Optional[BaseCredentialService] = None, + plugin_close_timeout: float = 5.0, + ): + """Initializes the Runner. + + Developers should provide either an `app` instance or both `app_name` and + `agent`. Providing a mix of `app` and `app_name`/`agent` will result in a + `ValueError`. Providing `app` is the recommended way to create a runner. + + Args: + app: An optional `App` instance. If provided, `app_name` and `agent` + should not be specified. + app_name: The application name of the runner. Required if `app` is not + provided. + agent: The root agent to run. Required if `app` is not provided. + plugins: Deprecated. A list of plugins for the runner. Please use the + `app` argument to provide plugins instead. + artifact_service: The artifact service for the runner. + session_service: The session service for the runner. + memory_service: The memory service for the runner. + credential_service: The credential service for the runner. + plugin_close_timeout: The timeout in seconds for plugin close methods. + + Raises: + ValueError: If `app` is provided along with `app_name` or `plugins`, or + if `app` is not provided but either `app_name` or `agent` is missing. + """ + self.app = app + ( + self.app_name, + self.agent, + self.context_cache_config, + self.resumability_config, + plugins, + ) = self._validate_runner_params(app, app_name, agent, plugins) + self.artifact_service = artifact_service + self.session_service = session_service + self.memory_service = memory_service + self.credential_service = credential_service + self.plugin_manager = PluginManager( + plugins=plugins, close_timeout=plugin_close_timeout + ) + ( + self._agent_origin_app_name, + self._agent_origin_dir, + ) = self._infer_agent_origin(self.agent) + self._app_name_alignment_hint: Optional[str] = None + self._enforce_app_name_alignment() + + def _validate_runner_params( + self, + app: Optional[App], + app_name: Optional[str], + agent: Optional[BaseAgent], + plugins: Optional[List[BasePlugin]], + ) -> tuple[ + str, + BaseAgent, + Optional[ContextCacheConfig], + Optional[ResumabilityConfig], + Optional[List[BasePlugin]], + ]: + """Validates and extracts runner parameters. + + Args: + app: An optional `App` instance. + app_name: The application name of the runner. + agent: The root agent to run. + plugins: A list of plugins for the runner. + + Returns: + A tuple containing (app_name, agent, context_cache_config, + resumability_config, plugins). + + Raises: + ValueError: If parameters are invalid. + """ + if plugins is not None: + warnings.warn( + "The `plugins` argument is deprecated. Please use the `app` argument" + " to provide plugins instead.", + DeprecationWarning, + ) - Raises: - ValueError: If parameters are invalid. - """ - if plugins is not None: - warnings.warn( - 'The `plugins` argument is deprecated. Please use the `app` argument' - ' to provide plugins instead.', - DeprecationWarning, - ) - - if app: - if app_name: - raise ValueError( - 'When app is provided, app_name should not be provided.' + if app: + if app_name: + raise ValueError( + "When app is provided, app_name should not be provided." + ) + if agent: + raise ValueError("When app is provided, agent should not be provided.") + if plugins: + raise ValueError( + "When app is provided, plugins should not be provided and should be" + " provided in the app instead." + ) + app_name = app.name + agent = app.root_agent + plugins = app.plugins + context_cache_config = app.context_cache_config + resumability_config = app.resumability_config + elif not app_name or not agent: + raise ValueError("Either app or both app_name and agent must be provided.") + else: + context_cache_config = None + resumability_config = None + + return app_name, agent, context_cache_config, resumability_config, plugins + + def _infer_agent_origin( + self, agent: BaseAgent + ) -> tuple[Optional[str], Optional[Path]]: + """Infer the origin app name and directory from an agent's module location. + + Returns: + A tuple of (origin_app_name, origin_path): + - origin_app_name: The inferred app name (directory name containing the + agent), or None if inference is not possible/applicable. + - origin_path: The directory path where the agent is defined, or None + if the path cannot be determined. + + Both values are None when: + - The agent has no associated module + - The agent is defined in google.adk.* (ADK internal modules) + - The module has no __file__ attribute + """ + # First, check for metadata set by AgentLoader (most reliable source). + # AgentLoader sets these attributes when loading agents. + origin_app_name = getattr(agent, "_adk_origin_app_name", None) + origin_path = getattr(agent, "_adk_origin_path", None) + if origin_app_name is not None and origin_path is not None: + return origin_app_name, origin_path + + # Fall back to heuristic inference for programmatic usage. + module = inspect.getmodule(agent.__class__) + if not module: + return None, None + + # Skip ADK internal modules. When users instantiate LlmAgent directly + # (not subclassed), inspect.getmodule() returns the ADK module. This + # could falsely match 'agents' in 'google/adk/agents/' path. + if module.__name__.startswith("google.adk."): + return None, None + + module_file = getattr(module, "__file__", None) + if not module_file: + return None, None + module_path = Path(module_file).resolve() + project_root = Path.cwd() + try: + relative_path = module_path.relative_to(project_root) + except ValueError: + return None, module_path.parent + origin_dir = module_path.parent + if "agents" not in relative_path.parts: + return None, origin_dir + origin_name = origin_dir.name + if origin_name.startswith("."): + return None, origin_dir + return origin_name, origin_dir + + def _enforce_app_name_alignment(self) -> None: + origin_name = self._agent_origin_app_name + origin_dir = self._agent_origin_dir + if not origin_name or origin_name.startswith("__"): + self._app_name_alignment_hint = None + return + if origin_name == self.app_name: + self._app_name_alignment_hint = None + return + origin_location = str(origin_dir) if origin_dir else origin_name + mismatch_details = ( + "The runner is configured with app name " + f'"{self.app_name}", but the root agent was loaded from ' + f'"{origin_location}", which implies app name "{origin_name}".' ) - if agent: - raise ValueError('When app is provided, agent should not be provided.') - if plugins: - raise ValueError( - 'When app is provided, plugins should not be provided and should be' - ' provided in the app instead.' + resolution = ( + "Ensure the runner app_name matches that directory or pass app_name " + "explicitly when constructing the runner." + ) + self._app_name_alignment_hint = f"{mismatch_details} {resolution}" + logger.warning("App name mismatch detected. %s", mismatch_details) + + def _format_session_not_found_message(self, session_id: str) -> str: + message = f"Session not found: {session_id}" + if not self._app_name_alignment_hint: + return message + return ( + f"{message}. {self._app_name_alignment_hint} " + "The mismatch prevents the runner from locating the session." ) - app_name = app.name - agent = app.root_agent - plugins = app.plugins - context_cache_config = app.context_cache_config - resumability_config = app.resumability_config - elif not app_name or not agent: - raise ValueError( - 'Either app or both app_name and agent must be provided.' - ) - else: - context_cache_config = None - resumability_config = None - - return app_name, agent, context_cache_config, resumability_config, plugins - - def _infer_agent_origin( - self, agent: BaseAgent - ) -> tuple[Optional[str], Optional[Path]]: - """Infer the origin app name and directory from an agent's module location. - - Returns: - A tuple of (origin_app_name, origin_path): - - origin_app_name: The inferred app name (directory name containing the - agent), or None if inference is not possible/applicable. - - origin_path: The directory path where the agent is defined, or None - if the path cannot be determined. - - Both values are None when: - - The agent has no associated module - - The agent is defined in google.adk.* (ADK internal modules) - - The module has no __file__ attribute - """ - # First, check for metadata set by AgentLoader (most reliable source). - # AgentLoader sets these attributes when loading agents. - origin_app_name = getattr(agent, '_adk_origin_app_name', None) - origin_path = getattr(agent, '_adk_origin_path', None) - if origin_app_name is not None and origin_path is not None: - return origin_app_name, origin_path - - # Fall back to heuristic inference for programmatic usage. - module = inspect.getmodule(agent.__class__) - if not module: - return None, None - - # Skip ADK internal modules. When users instantiate LlmAgent directly - # (not subclassed), inspect.getmodule() returns the ADK module. This - # could falsely match 'agents' in 'google/adk/agents/' path. - if module.__name__.startswith('google.adk.'): - return None, None - - module_file = getattr(module, '__file__', None) - if not module_file: - return None, None - module_path = Path(module_file).resolve() - project_root = Path.cwd() - try: - relative_path = module_path.relative_to(project_root) - except ValueError: - return None, module_path.parent - origin_dir = module_path.parent - if 'agents' not in relative_path.parts: - return None, origin_dir - origin_name = origin_dir.name - if origin_name.startswith('.'): - return None, origin_dir - return origin_name, origin_dir - - def _enforce_app_name_alignment(self) -> None: - origin_name = self._agent_origin_app_name - origin_dir = self._agent_origin_dir - if not origin_name or origin_name.startswith('__'): - self._app_name_alignment_hint = None - return - if origin_name == self.app_name: - self._app_name_alignment_hint = None - return - origin_location = str(origin_dir) if origin_dir else origin_name - mismatch_details = ( - 'The runner is configured with app name ' - f'"{self.app_name}", but the root agent was loaded from ' - f'"{origin_location}", which implies app name "{origin_name}".' - ) - resolution = ( - 'Ensure the runner app_name matches that directory or pass app_name ' - 'explicitly when constructing the runner.' - ) - self._app_name_alignment_hint = f'{mismatch_details} {resolution}' - logger.warning('App name mismatch detected. %s', mismatch_details) - def _format_session_not_found_message(self, session_id: str) -> str: - message = f'Session not found: {session_id}' - if not self._app_name_alignment_hint: - return message - return ( - f'{message}. {self._app_name_alignment_hint} ' - 'The mismatch prevents the runner from locating the session.' - ) + def run( + self, + *, + user_id: str, + session_id: str, + new_message: types.Content, + run_config: Optional[RunConfig] = None, + ) -> Generator[Event, None, None]: + """Runs the agent. + + NOTE: + This sync interface is only for local testing and convenience purpose. + Consider using `run_async` for production usage. + + If event compaction is enabled in the App configuration, it will be + performed after all agent events for the current invocation have been + yielded. The generator will only finish iterating after event + compaction is complete. + + Args: + user_id: The user ID of the session. + session_id: The session ID of the session. + new_message: A new message to append to the session. + run_config: The run config for the agent. + + Yields: + The events generated by the agent. + """ + run_config = run_config or RunConfig() + event_queue = queue.Queue() + + async def _invoke_run_async(): + try: + async with Aclosing( + self.run_async( + user_id=user_id, + session_id=session_id, + new_message=new_message, + run_config=run_config, + ) + ) as agen: + async for event in agen: + event_queue.put(event) + finally: + event_queue.put(None) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_run_async()) + finally: + event_queue.put(None) + + thread = create_thread(target=_asyncio_thread_main) + thread.start() + + # consumes and re-yield the events from background thread. + while True: + event = event_queue.get() + if event is None: + break + else: + yield event + + thread.join() + + async def run_async( + self, + *, + user_id: str, + session_id: str, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, + run_config: Optional[RunConfig] = None, + ) -> AsyncGenerator[Event, None]: + """Main entry method to run the agent in this runner. + + If event compaction is enabled in the App configuration, it will be + performed after all agent events for the current invocation have been + yielded. The async generator will only finish iterating after event + compaction is complete. However, this does not block new `run_async` + calls for subsequent user queries, which can be started concurrently. + + Args: + user_id: The user ID of the session. + session_id: The session ID of the session. + invocation_id: The invocation ID of the session, set this to resume an + interrupted invocation. + new_message: A new message to append to the session. + state_delta: Optional state changes to apply to the session. + run_config: The run config for the agent. + + Yields: + The events generated by the agent. + + Raises: + ValueError: If the session is not found; If both invocation_id and + new_message are None. + """ + run_config = run_config or RunConfig() + + if new_message and not new_message.role: + new_message.role = "user" + + async def _run_body( + new_message: Optional[types.Content] = None, + invocation_id: Optional[str] = None, + ) -> AsyncGenerator[Event, None]: + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + message = self._format_session_not_found_message(session_id) + raise ValueError(message) + if not invocation_id and not new_message: + raise ValueError( + "Running an agent requires either a new_message or an " + "invocation_id to resume a previous invocation. " + f"Session: {session_id}, User: {user_id}" + ) - def run( - self, - *, - user_id: str, - session_id: str, - new_message: types.Content, - run_config: Optional[RunConfig] = None, - ) -> Generator[Event, None, None]: - """Runs the agent. - - NOTE: - This sync interface is only for local testing and convenience purpose. - Consider using `run_async` for production usage. - - If event compaction is enabled in the App configuration, it will be - performed after all agent events for the current invocation have been - yielded. The generator will only finish iterating after event - compaction is complete. - - Args: - user_id: The user ID of the session. - session_id: The session ID of the session. - new_message: A new message to append to the session. - run_config: The run config for the agent. - - Yields: - The events generated by the agent. - """ - run_config = run_config or RunConfig() - event_queue = queue.Queue() + if invocation_id: + if ( + not self.resumability_config + or not self.resumability_config.is_resumable + ): + raise ValueError( + f"invocation_id: {invocation_id} is provided but the app is not" + " resumable." + ) + invocation_context = await self._setup_context_for_resumed_invocation( + session=session, + new_message=new_message, + invocation_id=invocation_id, + run_config=run_config, + state_delta=state_delta, + ) + if invocation_context.end_of_agents.get(invocation_context.agent.name): + # Directly return if the current agent in invocation context is + # already final. + return + else: + invocation_context = await self._setup_context_for_new_invocation( + session=session, + new_message=new_message, # new_message is not None. + run_config=run_config, + state_delta=state_delta, + ) - async def _invoke_run_async(): - try: - async with Aclosing( - self.run_async( - user_id=user_id, - session_id=session_id, - new_message=new_message, - run_config=run_config, - ) - ) as agen: - async for event in agen: - event_queue.put(event) - finally: - event_queue.put(None) - - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_run_async()) - finally: - event_queue.put(None) - - thread = create_thread(target=_asyncio_thread_main) - thread.start() - - # consumes and re-yield the events from background thread. - while True: - event = event_queue.get() - if event is None: - break - else: - yield event - - thread.join() - - async def run_async( - self, - *, - user_id: str, - session_id: str, - invocation_id: Optional[str] = None, - new_message: Optional[types.Content] = None, - state_delta: Optional[dict[str, Any]] = None, - run_config: Optional[RunConfig] = None, - ) -> AsyncGenerator[Event, None]: - """Main entry method to run the agent in this runner. - - If event compaction is enabled in the App configuration, it will be - performed after all agent events for the current invocation have been - yielded. The async generator will only finish iterating after event - compaction is complete. However, this does not block new `run_async` - calls for subsequent user queries, which can be started concurrently. - - Args: - user_id: The user ID of the session. - session_id: The session ID of the session. - invocation_id: The invocation ID of the session, set this to resume an - interrupted invocation. - new_message: A new message to append to the session. - state_delta: Optional state changes to apply to the session. - run_config: The run config for the agent. - - Yields: - The events generated by the agent. - - Raises: - ValueError: If the session is not found; If both invocation_id and - new_message are None. - """ - run_config = run_config or RunConfig() + async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + async with Aclosing(ctx.agent.run_async(ctx)) as agen: + async for event in agen: + yield event + + async with Aclosing( + self._exec_with_plugin( + invocation_context=invocation_context, + session=session, + execute_fn=execute, + is_live_call=False, + ) + ) as agen: + async for event in agen: + yield event + # Run compaction after all events are yielded from the agent. + # (We don't compact in the middle of an invocation, we only compact at + # the end of an invocation.) + if self.app and self.app.events_compaction_config: + logger.debug("Running event compactor.") + await _run_compaction_for_sliding_window( + self.app, session, self.session_service + ) - if new_message and not new_message.role: - new_message.role = 'user' + async def _run_with_optional_trace( + agent: BaseAgent, + new_message: Optional[types.Content] = None, + invocation_id: Optional[str] = None, + ) -> AsyncGenerator[Event, None]: + if is_telemetry_enabled(agent): + with tracer.start_as_current_span("invocation"): + async with Aclosing( + _run_body(new_message=new_message, invocation_id=invocation_id) + ) as agen: + async for e in agen: + yield e + else: + async with Aclosing( + _run_body(new_message=new_message, invocation_id=invocation_id) + ) as agen: + async for e in agen: + yield e - async def _run_with_trace( - new_message: Optional[types.Content] = None, - invocation_id: Optional[str] = None, - ) -> AsyncGenerator[Event, None]: - with tracer.start_as_current_span('invocation'): + async with Aclosing( + _run_with_optional_trace(self.agent, new_message, invocation_id) + ) as agen: + async for event in agen: + yield event + + async def rewind_async( + self, + *, + user_id: str, + session_id: str, + rewind_before_invocation_id: str, + ) -> None: + """Rewinds the session to before the specified invocation.""" session = await self.session_service.get_session( app_name=self.app_name, user_id=user_id, session_id=session_id ) if not session: - message = self._format_session_not_found_message(session_id) - raise ValueError(message) - if not invocation_id and not new_message: - raise ValueError( - 'Running an agent requires either a new_message or an ' - 'invocation_id to resume a previous invocation. ' - f'Session: {session_id}, User: {user_id}' - ) - - if invocation_id: - if ( - not self.resumability_config - or not self.resumability_config.is_resumable - ): - raise ValueError( - f'invocation_id: {invocation_id} is provided but the app is not' - ' resumable.' + raise ValueError(f"Session not found: {session_id}") + + rewind_event_index = -1 + for i, event in enumerate(session.events): + if event.invocation_id == rewind_before_invocation_id: + rewind_event_index = i + break + + if rewind_event_index == -1: + raise ValueError(f"Invocation ID not found: {rewind_before_invocation_id}") + + # Compute state delta to reverse changes + state_delta = await self._compute_state_delta_for_rewind( + session, rewind_event_index + ) + + # Compute artifact delta to reverse changes + artifact_delta = await self._compute_artifact_delta_for_rewind( + session, rewind_event_index + ) + + # Create rewind event + rewind_event = Event( + invocation_id=new_invocation_context_id(), + author="user", + actions=EventActions( + rewind_before_invocation_id=rewind_before_invocation_id, + state_delta=state_delta, + artifact_delta=artifact_delta, + ), + ) + + logger.info("Rewinding session to invocation: %s", rewind_event) + + await self.session_service.append_event(session=session, event=rewind_event) + + async def _compute_state_delta_for_rewind( + self, session: Session, rewind_event_index: int + ) -> dict[str, Any]: + """Computes the state delta to reverse changes.""" + state_at_rewind_point: dict[str, Any] = {} + for i in range(rewind_event_index): + if session.events[i].actions.state_delta: + for k, v in session.events[i].actions.state_delta.items(): + if k.startswith("app:") or k.startswith("user:"): + continue + if v is None: + state_at_rewind_point.pop(k, None) + else: + state_at_rewind_point[k] = v + + current_state = session.state + rewind_state_delta = {} + + # 1. Add/update keys in rewind_state_delta to match state_at_rewind_point. + for key, value_at_rewind in state_at_rewind_point.items(): + if key not in current_state or current_state[key] != value_at_rewind: + rewind_state_delta[key] = value_at_rewind + + # 2. Set keys to None in rewind_state_delta if they are in current_state + # but not in state_at_rewind_point. These keys were added after the + # rewind point and need to be removed. + for key in current_state: + if key.startswith("app:") or key.startswith("user:"): + continue + if key not in state_at_rewind_point: + rewind_state_delta[key] = None + + return rewind_state_delta + + async def _compute_artifact_delta_for_rewind( + self, session: Session, rewind_event_index: int + ) -> dict[str, int]: + """Computes the artifact delta to reverse changes.""" + if not self.artifact_service: + return {} + + versions_at_rewind_point: dict[str, int] = {} + for i in range(rewind_event_index): + event = session.events[i] + if event.actions.artifact_delta: + versions_at_rewind_point.update(event.actions.artifact_delta) + + current_versions: dict[str, int] = {} + for event in session.events: + if event.actions.artifact_delta: + current_versions.update(event.actions.artifact_delta) + + rewind_artifact_delta = {} + for filename, vn in current_versions.items(): + if filename.startswith("user:"): + # User artifacts are not restored on rewind. + continue + vt = versions_at_rewind_point.get(filename) + if vt == vn: + continue + + rewind_artifact_delta[filename] = vn + 1 + if vt is None: + # Artifact did not exist at rewind point. Mark it as inaccessible. + artifact = types.Part( + inline_data=types.Blob( + mime_type="application/octet-stream", data=b"" + ) + ) + else: + # Artifact version changed after rewind point. Restore to version at + # rewind point. + artifact_uri = artifact_util.get_artifact_uri( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + filename=filename, + version=vt, + ) + artifact = types.Part(file_data=types.FileData(file_uri=artifact_uri)) + await self.artifact_service.save_artifact( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + filename=filename, + artifact=artifact, + ) + + return rewind_artifact_delta + + def _should_append_event(self, event: Event, is_live_call: bool) -> bool: + """Checks if an event should be appended to the session.""" + # Don't append audio response from model in live mode to session. + # The data is appended to artifacts with a reference in file_data in the + # event. + # We should append non-partial events only.For example, non-finished(partial) + # transcription events should not be appended. + # Function call and function response events should be appended. + # Other control events should be appended. + if is_live_call and contents._is_live_model_audio_event_with_inline_data(event): + # We don't append live model audio events with inline data to avoid + # storing large blobs in the session. However, events with file_data + # (references to artifacts) should be appended. + return False + return True + + async def _exec_with_plugin( + self, + invocation_context: InvocationContext, + session: Session, + execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]], + is_live_call: bool = False, + ) -> AsyncGenerator[Event, None]: + """Wraps execution with plugin callbacks. + + Args: + invocation_context: The invocation context + session: The current session + execute_fn: A callable that returns an AsyncGenerator of Events + is_live_call: Whether this is a live call + + Yields: + Events from the execution, including any generated by plugins + """ + + plugin_manager = invocation_context.plugin_manager + + # Step 1: Run the before_run callbacks to see if we should early exit. + early_exit_result = await plugin_manager.run_before_run_callback( + invocation_context=invocation_context + ) + if isinstance(early_exit_result, types.Content): + early_exit_event = Event( + invocation_id=invocation_context.invocation_id, + author="model", + content=early_exit_result, + ) + if self._should_append_event(early_exit_event, is_live_call): + await self.session_service.append_event( + session=session, + event=early_exit_event, + ) + yield early_exit_event + else: + # Step 2: Otherwise continue with normal execution + # Note for live/bidi: + # the transcription may arrive later then the action(function call + # event and thus function response event). In this case, the order of + # transcription and function call event will be wrong if we just + # append as it arrives. To address this, we should check if there is + # transcription going on. If there is transcription going on, we + # should hold on appending the function call event until the + # transcription is finished. The transcription in progress can be + # identified by checking if the transcription event is partial. When + # the next transcription event is not partial, it means the previous + # transcription is finished. Then if there is any buffered function + # call event, we should append them after this finished(non-parital) + # transcription event. + buffered_events: list[Event] = [] + is_transcribing: bool = False + + async with Aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response event which is + # non-partial + buffered_events.append(event) + continue + # Note for live/bidi: for audio response, it's considered as + # non-paritla event(event.partial=None) + # event.partial=False and event.partial=None are considered as + # non-partial event; event.partial=True is considered as partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text( + event.input_transcription + ) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + "Appending transcription finished event: %s", event + ) + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=event + ) + + for buffered_event in buffered_events: + logger.debug( + "Appending buffered event: %s", buffered_event + ) + await self.session_service.append_event( + session=session, event=buffered_event + ) + buffered_events = [] + else: + # non-transcription event or empty transcription event, for + # example, event that stores blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug( + "Appending non-buffered event: %s", event + ) + await self.session_service.append_event( + session=session, event=event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=session, event=event + ) + + # Step 3: Run the on_event callbacks to optionally modify the event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event + ) + yield (modified_event if modified_event else event) + + # Step 4: Run the after_run callbacks to perform global cleanup tasks or + # finalizing logs and metrics data. + # This does NOT emit any event. + await plugin_manager.run_after_run_callback( + invocation_context=invocation_context + ) + + async def _append_new_message_to_session( + self, + *, + session: Session, + new_message: types.Content, + invocation_context: InvocationContext, + save_input_blobs_as_artifacts: bool = False, + state_delta: Optional[dict[str, Any]] = None, + ): + """Appends a new message to the session. + + Args: + session: The session to append the message to. + new_message: The new message to append. + invocation_context: The invocation context for the message. + save_input_blobs_as_artifacts: Whether to save input blobs as artifacts. + state_delta: Optional state changes to apply to the session. + """ + if not new_message.parts: + raise ValueError("No parts in the new_message.") + + if self.artifact_service and save_input_blobs_as_artifacts: + # Issue deprecation warning + warnings.warn( + "The 'save_input_blobs_as_artifacts' parameter is deprecated. Use" + " SaveFilesAsArtifactsPlugin instead for better control and" + " flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for" + " migration guidance.", + DeprecationWarning, + stacklevel=3, + ) + # The runner directly saves the artifacts (if applicable) in the + # user message and replaces the artifact data with a file name + # placeholder. + for i, part in enumerate(new_message.parts): + if part.inline_data is None: + continue + file_name = f"artifact_{invocation_context.invocation_id}_{i}" + await self.artifact_service.save_artifact( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + filename=file_name, + artifact=part, + ) + new_message.parts[i] = types.Part( + text=f"Uploaded file: {file_name}. It is saved into artifacts" + ) + # Appends only. We do not yield the event because it's not from the model. + if state_delta: + event = Event( + invocation_id=invocation_context.invocation_id, + author="user", + actions=EventActions(state_delta=state_delta), + content=new_message, ) - invocation_context = await self._setup_context_for_resumed_invocation( - session=session, - new_message=new_message, - invocation_id=invocation_id, - run_config=run_config, - state_delta=state_delta, - ) - if invocation_context.end_of_agents.get( - invocation_context.agent.name - ): - # Directly return if the current agent in invocation context is - # already final. - return else: - invocation_context = await self._setup_context_for_new_invocation( - session=session, - new_message=new_message, # new_message is not None. - run_config=run_config, - state_delta=state_delta, - ) + event = Event( + invocation_id=invocation_context.invocation_id, + author="user", + content=new_message, + ) + # If new_message is a function response, find the matching function call + # and use its branch as the new event's branch. + if function_call := invocation_context._find_matching_function_call(event): + event.branch = function_call.branch + + await self.session_service.append_event(session=session, event=event) + + async def run_live( + self, + *, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + live_request_queue: LiveRequestQueue, + run_config: Optional[RunConfig] = None, + session: Optional[Session] = None, + ) -> AsyncGenerator[Event, None]: + """Runs the agent in live mode (experimental feature). + + The `run_live` method yields a stream of `Event` objects, but not all + yielded events are saved to the session. Here's a breakdown: + + **Events Yielded to Callers:** + * **Live Model Audio Events with Inline Data:** Events containing raw + audio `Blob` data(`inline_data`). + * **Live Model Audio Events with File Data:** Both input and ouput audio + data are aggregated into a audio file saved into artifacts. The + reference to the file is saved in the event as `file_data`. + * **Usage Metadata:** Events containing token usage. + * **Transcription Events:** Both partial and non-partial transcription + events are yielded. + * **Function Call and Response Events:** Always saved. + * **Other Control Events:** Most control events are saved. + + **Events Saved to the Session:** + * **Live Model Audio Events with File Data:** Both input and ouput audio + data are aggregated into a audio file saved into artifacts. The + reference to the file is saved as event in the `file_data` to session + if RunConfig.save_live_model_audio_to_session is True. + * **Usage Metadata Events:** Saved to the session. + * **Non-Partial Transcription Events:** Non-partial transcription events + are saved. + * **Function Call and Response Events:** Always saved. + * **Other Control Events:** Most control events are saved. + + **Events Not Saved to the Session:** + * **Live Model Audio Events with Inline Data:** Events containing raw + audio `Blob` data are **not** saved to the session. + + Args: + user_id: The user ID for the session. Required if `session` is None. + session_id: The session ID for the session. Required if `session` is + None. + live_request_queue: The queue for live requests. + run_config: The run config for the agent. + session: The session to use. This parameter is deprecated, please use + `user_id` and `session_id` instead. + + Yields: + AsyncGenerator[Event, None]: An asynchronous generator that yields + `Event` + objects as they are produced by the agent during its live execution. + + .. warning:: + This feature is **experimental** and its API or behavior may change + in future releases. + + .. NOTE:: + Either `session` or both `user_id` and `session_id` must be provided. + """ + run_config = run_config or RunConfig() + # Some native audio models requires the modality to be set. So we set it to + # AUDIO by default. + if run_config.response_modalities is None: + run_config.response_modalities = ["AUDIO"] + if session is None and (user_id is None or session_id is None): + raise ValueError( + "Either session or user_id and session_id must be provided." + ) + if session is not None: + warnings.warn( + "The `session` parameter is deprecated. Please use `user_id` and" + " `session_id` instead.", + DeprecationWarning, + stacklevel=2, + ) + if not session: + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise ValueError(f"Session not found: {session_id}") + invocation_context = self._new_invocation_context_for_live( + session, + live_request_queue=live_request_queue, + run_config=run_config, + ) + + root_agent = self.agent + invocation_context.agent = self._find_agent_to_run(session, root_agent) + + # Pre-processing for live streaming tools + # Inspect the tool's parameters to find if it uses LiveRequestQueue + invocation_context.active_streaming_tools = {} + # TODO(hangfei): switch to use canonical_tools. + # for shell agents, there is no tools associated with it so we should skip. + if hasattr(invocation_context.agent, "tools"): + import inspect + + for tool in invocation_context.agent.tools: + # We use `inspect.signature()` to examine the tool's underlying function (`tool.func`). + # This approach is deliberately chosen over `typing.get_type_hints()` for robustness. + # + # The Problem with `get_type_hints()`: + # `get_type_hints()` attempts to resolve forward-referenced (string-based) type + # annotations. This resolution can easily fail with a `NameError` (e.g., "Union not found") + # if the type isn't available in the scope where `get_type_hints()` is called. + # This is a common and brittle issue in framework code that inspects functions + # defined in separate user modules. + # + # Why `inspect.signature()` is Better Here: + # `inspect.signature()` does NOT resolve the annotations; it retrieves the raw + # annotation object as it was defined on the function. This allows us to + # perform a direct and reliable identity check (`param.annotation is LiveRequestQueue`) + # without risking a `NameError`. + callable_to_inspect = tool.func if hasattr(tool, "func") else tool + # Ensure the target is actually callable before inspecting to avoid errors. + if not callable(callable_to_inspect): + continue + for param in inspect.signature(callable_to_inspect).parameters.values(): + if param.annotation is LiveRequestQueue: + if not invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools = {} + active_streaming_tool = ActiveStreamingTool( + stream=LiveRequestQueue() + ) + invocation_context.active_streaming_tools[tool.__name__] = ( + active_streaming_tool + ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async with Aclosing(ctx.agent.run_async(ctx)) as agen: - async for event in agen: - yield event + async with Aclosing(ctx.agent.run_live(ctx)) as agen: + async for event in agen: + yield event async with Aclosing( self._exec_with_plugin( invocation_context=invocation_context, session=session, execute_fn=execute, - is_live_call=False, + is_live_call=True, ) ) as agen: - async for event in agen: - yield event - # Run compaction after all events are yielded from the agent. - # (We don't compact in the middle of an invocation, we only compact at - # the end of an invocation.) - if self.app and self.app.events_compaction_config: - logger.debug('Running event compactor.') - await _run_compaction_for_sliding_window( - self.app, session, self.session_service - ) - - async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: - async for event in agen: - yield event - - async def rewind_async( - self, - *, - user_id: str, - session_id: str, - rewind_before_invocation_id: str, - ) -> None: - """Rewinds the session to before the specified invocation.""" - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise ValueError(f'Session not found: {session_id}') - - rewind_event_index = -1 - for i, event in enumerate(session.events): - if event.invocation_id == rewind_before_invocation_id: - rewind_event_index = i - break - - if rewind_event_index == -1: - raise ValueError( - f'Invocation ID not found: {rewind_before_invocation_id}' - ) - - # Compute state delta to reverse changes - state_delta = await self._compute_state_delta_for_rewind( - session, rewind_event_index - ) + async for event in agen: + yield event + + def _find_agent_to_run(self, session: Session, root_agent: BaseAgent) -> BaseAgent: + """Finds the agent to run to continue the session. + + A qualified agent must be either of: + + - The agent that returned a function call and the last user message is a + function response to this function call. + - The root agent. + - An LlmAgent who replied last and is capable to transfer to any other agent + in the agent hierarchy. + + Args: + session: The session to find the agent for. + root_agent: The root agent of the runner. + + Returns: + The agent to run. (the active agent that should reply to the latest user + message) + """ + # If the last event is a function response, should send this response to + # the agent that returned the corresponding function call regardless the + # type of the agent. e.g. a remote a2a agent may surface a credential + # request as a special long running function tool call. + event = find_matching_function_call(session.events) + if event and event.author: + return root_agent.find_agent(event.author) + + def _event_filter(event: Event) -> bool: + """Filters out user-authored events and agent state change events.""" + if event.author == "user": + return False + if event.actions.agent_state is not None or event.actions.end_of_agent: + return False + return True + + for event in filter(_event_filter, reversed(session.events)): + if event.author == root_agent.name: + # Found root agent. + return root_agent + if not (agent := root_agent.find_sub_agent(event.author)): + # Agent not found, continue looking. + logger.warning( + "Event from an unknown agent: %s, event id: %s", + event.author, + event.id, + ) + continue + if self._is_transferable_across_agent_tree(agent): + return agent + # Falls back to root agent if no suitable agents are found in the session. + return root_agent - # Compute artifact delta to reverse changes - artifact_delta = await self._compute_artifact_delta_for_rewind( - session, rewind_event_index - ) + def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool: + """Whether the agent to run can transfer to any other agent in the agent tree. + + This typically means all agent_to_run's ancestor can transfer to their + parent_agent all the way to the root_agent. + + Args: + agent_to_run: The agent to check for transferability. + + Returns: + True if the agent can transfer, False otherwise. + """ + agent = agent_to_run + while agent: + if not isinstance(agent, LlmAgent): + # Only LLM-based Agent can provide agent transfer capability. + return False + if agent.disallow_transfer_to_parent: + return False + agent = agent.parent_agent + return True + + async def run_debug( + self, + user_messages: str | list[str], + *, + user_id: str = "debug_user_id", + session_id: str = "debug_session_id", + run_config: RunConfig | None = None, + quiet: bool = False, + verbose: bool = False, + ) -> list[Event]: + """Debug helper for quick agent experimentation and testing. + + This convenience method is designed for developers getting started with ADK + who want to quickly test agents without dealing with session management, + content formatting, or event streaming. It automatically handles common + boilerplate while hiding complexity. + + IMPORTANT: This is for debugging and experimentation only. For production + use, please use the standard run_async() method which provides full control + over session management, event streaming, and error handling. + + Args: + user_messages: Message(s) to send to the agent. Can be: - Single string: + "What is 2+2?" - List of strings: ["Hello!", "What's my name?"] + user_id: User identifier. Defaults to "debug_user_id". + session_id: Session identifier for conversation persistence. Defaults to + "debug_session_id". Reuse the same ID to continue a conversation. + run_config: Optional configuration for the agent execution. + quiet: If True, suppresses console output. Defaults to False (output + shown). + verbose: If True, shows detailed tool calls and responses. Defaults to + False for cleaner output showing only final agent responses. + + Returns: + list[Event]: All events from all messages. + + Raises: + ValueError: If session creation/retrieval fails. + + Examples: + Quick debugging: + >>> runner = InMemoryRunner(agent=my_agent) + >>> await runner.run_debug("What is 2+2?") + + Multiple queries in conversation: + >>> await runner.run_debug(["Hello!", "What's my name?"]) + + Continue a debug session: + >>> await runner.run_debug("What did we discuss?") # Continues default + session + + Separate debug sessions: + >>> await runner.run_debug("Hi", user_id="alice", session_id="debug1") + >>> await runner.run_debug("Hi", user_id="bob", session_id="debug2") + + Capture events for inspection: + >>> events = await runner.run_debug("Analyze this") + >>> for event in events: + ... inspect_event(event) + + Note: + For production applications requiring: + - Custom session/memory services (Spanner, Cloud SQL, etc.) + - Fine-grained event processing and streaming + - Error recovery and resumability + - Performance optimization + Please use run_async() with proper configuration. + """ + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + session = await self.session_service.create_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not quiet: + print(f"\n ### Created new session: {session_id}") + elif not quiet: + print(f"\n ### Continue session: {session_id}") - # Create rewind event - rewind_event = Event( - invocation_id=new_invocation_context_id(), - author='user', - actions=EventActions( - rewind_before_invocation_id=rewind_before_invocation_id, - state_delta=state_delta, - artifact_delta=artifact_delta, - ), - ) + collected_events: list[Event] = [] - logger.info('Rewinding session to invocation: %s', rewind_event) - - await self.session_service.append_event(session=session, event=rewind_event) - - async def _compute_state_delta_for_rewind( - self, session: Session, rewind_event_index: int - ) -> dict[str, Any]: - """Computes the state delta to reverse changes.""" - state_at_rewind_point: dict[str, Any] = {} - for i in range(rewind_event_index): - if session.events[i].actions.state_delta: - for k, v in session.events[i].actions.state_delta.items(): - if k.startswith('app:') or k.startswith('user:'): - continue - if v is None: - state_at_rewind_point.pop(k, None) - else: - state_at_rewind_point[k] = v - - current_state = session.state - rewind_state_delta = {} - - # 1. Add/update keys in rewind_state_delta to match state_at_rewind_point. - for key, value_at_rewind in state_at_rewind_point.items(): - if key not in current_state or current_state[key] != value_at_rewind: - rewind_state_delta[key] = value_at_rewind - - # 2. Set keys to None in rewind_state_delta if they are in current_state - # but not in state_at_rewind_point. These keys were added after the - # rewind point and need to be removed. - for key in current_state: - if key.startswith('app:') or key.startswith('user:'): - continue - if key not in state_at_rewind_point: - rewind_state_delta[key] = None - - return rewind_state_delta - - async def _compute_artifact_delta_for_rewind( - self, session: Session, rewind_event_index: int - ) -> dict[str, int]: - """Computes the artifact delta to reverse changes.""" - if not self.artifact_service: - return {} - - versions_at_rewind_point: dict[str, int] = {} - for i in range(rewind_event_index): - event = session.events[i] - if event.actions.artifact_delta: - versions_at_rewind_point.update(event.actions.artifact_delta) - - current_versions: dict[str, int] = {} - for event in session.events: - if event.actions.artifact_delta: - current_versions.update(event.actions.artifact_delta) - - rewind_artifact_delta = {} - for filename, vn in current_versions.items(): - if filename.startswith('user:'): - # User artifacts are not restored on rewind. - continue - vt = versions_at_rewind_point.get(filename) - if vt == vn: - continue - - rewind_artifact_delta[filename] = vn + 1 - if vt is None: - # Artifact did not exist at rewind point. Mark it as inaccessible. - artifact = types.Part( - inline_data=types.Blob( - mime_type='application/octet-stream', data=b'' - ) - ) - else: - # Artifact version changed after rewind point. Restore to version at - # rewind point. - artifact_uri = artifact_util.get_artifact_uri( - app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, - filename=filename, - version=vt, - ) - artifact = types.Part(file_data=types.FileData(file_uri=artifact_uri)) - await self.artifact_service.save_artifact( - app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, - filename=filename, - artifact=artifact, - ) - - return rewind_artifact_delta - - def _should_append_event(self, event: Event, is_live_call: bool) -> bool: - """Checks if an event should be appended to the session.""" - # Don't append audio response from model in live mode to session. - # The data is appended to artifacts with a reference in file_data in the - # event. - # We should append non-partial events only.For example, non-finished(partial) - # transcription events should not be appended. - # Function call and function response events should be appended. - # Other control events should be appended. - if is_live_call and contents._is_live_model_audio_event_with_inline_data( - event - ): - # We don't append live model audio events with inline data to avoid - # storing large blobs in the session. However, events with file_data - # (references to artifacts) should be appended. - return False - return True - - async def _exec_with_plugin( - self, - invocation_context: InvocationContext, - session: Session, - execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]], - is_live_call: bool = False, - ) -> AsyncGenerator[Event, None]: - """Wraps execution with plugin callbacks. - - Args: - invocation_context: The invocation context - session: The current session - execute_fn: A callable that returns an AsyncGenerator of Events - is_live_call: Whether this is a live call - - Yields: - Events from the execution, including any generated by plugins - """ + if isinstance(user_messages, str): + user_messages = [user_messages] - plugin_manager = invocation_context.plugin_manager + for message in user_messages: + if not quiet: + print(f"\nUser > {message}") - # Step 1: Run the before_run callbacks to see if we should early exit. - early_exit_result = await plugin_manager.run_before_run_callback( - invocation_context=invocation_context - ) - if isinstance(early_exit_result, types.Content): - early_exit_event = Event( - invocation_id=invocation_context.invocation_id, - author='model', - content=early_exit_result, - ) - if self._should_append_event(early_exit_event, is_live_call): - await self.session_service.append_event( - session=session, - event=early_exit_event, + async for event in self.run_async( + user_id=user_id, + session_id=session.id, + new_message=types.UserContent(parts=[types.Part(text=message)]), + run_config=run_config, + ): + if not quiet: + print_event(event, verbose=verbose) + + collected_events.append(event) + + return collected_events + + async def _setup_context_for_new_invocation( + self, + *, + session: Session, + new_message: types.Content, + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> InvocationContext: + """Sets up the context for a new invocation. + + Args: + session: The session to set up the invocation context for. + new_message: The new message to process and append to the session. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + + Returns: + The invocation context for the new invocation. + """ + # Step 1: Create invocation context in memory. + invocation_context = self._new_invocation_context( + session, + new_message=new_message, + run_config=run_config, ) - yield early_exit_event - else: - # Step 2: Otherwise continue with normal execution - # Note for live/bidi: - # the transcription may arrive later then the action(function call - # event and thus function response event). In this case, the order of - # transcription and function call event will be wrong if we just - # append as it arrives. To address this, we should check if there is - # transcription going on. If there is transcription going on, we - # should hold on appending the function call event until the - # transcription is finished. The transcription in progress can be - # identified by checking if the transcription event is partial. When - # the next transcription event is not partial, it means the previous - # transcription is finished. Then if there is any buffered function - # call event, we should append them after this finished(non-parital) - # transcription event. - buffered_events: list[Event] = [] - is_transcribing: bool = False - - async with Aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-paritla event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription - ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event - ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=session, event=event - ) - - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=session, event=buffered_event - ) - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=session, event=event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=session, event=event - ) - - # Step 3: Run the on_event callbacks to optionally modify the event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - yield (modified_event if modified_event else event) - - # Step 4: Run the after_run callbacks to perform global cleanup tasks or - # finalizing logs and metrics data. - # This does NOT emit any event. - await plugin_manager.run_after_run_callback( - invocation_context=invocation_context - ) - - async def _append_new_message_to_session( - self, - *, - session: Session, - new_message: types.Content, - invocation_context: InvocationContext, - save_input_blobs_as_artifacts: bool = False, - state_delta: Optional[dict[str, Any]] = None, - ): - """Appends a new message to the session. - - Args: - session: The session to append the message to. - new_message: The new message to append. - invocation_context: The invocation context for the message. - save_input_blobs_as_artifacts: Whether to save input blobs as artifacts. - state_delta: Optional state changes to apply to the session. - """ - if not new_message.parts: - raise ValueError('No parts in the new_message.') - - if self.artifact_service and save_input_blobs_as_artifacts: - # Issue deprecation warning - warnings.warn( - "The 'save_input_blobs_as_artifacts' parameter is deprecated. Use" - ' SaveFilesAsArtifactsPlugin instead for better control and' - ' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for' - ' migration guidance.', - DeprecationWarning, - stacklevel=3, - ) - # The runner directly saves the artifacts (if applicable) in the - # user message and replaces the artifact data with a file name - # placeholder. - for i, part in enumerate(new_message.parts): - if part.inline_data is None: - continue - file_name = f'artifact_{invocation_context.invocation_id}_{i}' - await self.artifact_service.save_artifact( - app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, - filename=file_name, - artifact=part, + # Step 2: Handle new message, by running callbacks and appending to + # session. + await self._handle_new_message( + session=session, + new_message=new_message, + invocation_context=invocation_context, + run_config=run_config, + state_delta=state_delta, ) - new_message.parts[i] = types.Part( - text=f'Uploaded file: {file_name}. It is saved into artifacts' + # Step 3: Set agent to run for the invocation. + invocation_context.agent = self._find_agent_to_run(session, self.agent) + return invocation_context + + async def _setup_context_for_resumed_invocation( + self, + *, + session: Session, + new_message: Optional[types.Content], + invocation_id: Optional[str], + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> InvocationContext: + """Sets up the context for a resumed invocation. + + Args: + session: The session to set up the invocation context for. + new_message: The new message to process and append to the session. + invocation_id: The invocation id to resume. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + + Returns: + The invocation context for the resumed invocation. + + Raises: + ValueError: If the session has no events to resume; If no user message is + available for resuming the invocation; Or if the app is not resumable. + """ + if not session.events: + raise ValueError(f"Session {session.id} has no events to resume.") + + # Step 1: Maybe retrieve a previous user message for the invocation. + user_message = new_message or self._find_user_message_for_invocation( + session.events, invocation_id ) - # Appends only. We do not yield the event because it's not from the model. - if state_delta: - event = Event( - invocation_id=invocation_context.invocation_id, - author='user', - actions=EventActions(state_delta=state_delta), - content=new_message, - ) - else: - event = Event( - invocation_id=invocation_context.invocation_id, - author='user', - content=new_message, - ) - # If new_message is a function response, find the matching function call - # and use its branch as the new event's branch. - if function_call := invocation_context._find_matching_function_call(event): - event.branch = function_call.branch - - await self.session_service.append_event(session=session, event=event) - - async def run_live( - self, - *, - user_id: Optional[str] = None, - session_id: Optional[str] = None, - live_request_queue: LiveRequestQueue, - run_config: Optional[RunConfig] = None, - session: Optional[Session] = None, - ) -> AsyncGenerator[Event, None]: - """Runs the agent in live mode (experimental feature). - - The `run_live` method yields a stream of `Event` objects, but not all - yielded events are saved to the session. Here's a breakdown: - - **Events Yielded to Callers:** - * **Live Model Audio Events with Inline Data:** Events containing raw - audio `Blob` data(`inline_data`). - * **Live Model Audio Events with File Data:** Both input and ouput audio - data are aggregated into a audio file saved into artifacts. The - reference to the file is saved in the event as `file_data`. - * **Usage Metadata:** Events containing token usage. - * **Transcription Events:** Both partial and non-partial transcription - events are yielded. - * **Function Call and Response Events:** Always saved. - * **Other Control Events:** Most control events are saved. - - **Events Saved to the Session:** - * **Live Model Audio Events with File Data:** Both input and ouput audio - data are aggregated into a audio file saved into artifacts. The - reference to the file is saved as event in the `file_data` to session - if RunConfig.save_live_model_audio_to_session is True. - * **Usage Metadata Events:** Saved to the session. - * **Non-Partial Transcription Events:** Non-partial transcription events - are saved. - * **Function Call and Response Events:** Always saved. - * **Other Control Events:** Most control events are saved. - - **Events Not Saved to the Session:** - * **Live Model Audio Events with Inline Data:** Events containing raw - audio `Blob` data are **not** saved to the session. - - Args: - user_id: The user ID for the session. Required if `session` is None. - session_id: The session ID for the session. Required if `session` is - None. - live_request_queue: The queue for live requests. - run_config: The run config for the agent. - session: The session to use. This parameter is deprecated, please use - `user_id` and `session_id` instead. - - Yields: - AsyncGenerator[Event, None]: An asynchronous generator that yields - `Event` - objects as they are produced by the agent during its live execution. - - .. warning:: - This feature is **experimental** and its API or behavior may change - in future releases. - - .. NOTE:: - Either `session` or both `user_id` and `session_id` must be provided. - """ - run_config = run_config or RunConfig() - # Some native audio models requires the modality to be set. So we set it to - # AUDIO by default. - if run_config.response_modalities is None: - run_config.response_modalities = ['AUDIO'] - if session is None and (user_id is None or session_id is None): - raise ValueError( - 'Either session or user_id and session_id must be provided.' - ) - if session is not None: - warnings.warn( - 'The `session` parameter is deprecated. Please use `user_id` and' - ' `session_id` instead.', - DeprecationWarning, - stacklevel=2, - ) - if not session: - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise ValueError(f'Session not found: {session_id}') - invocation_context = self._new_invocation_context_for_live( - session, - live_request_queue=live_request_queue, - run_config=run_config, - ) - - root_agent = self.agent - invocation_context.agent = self._find_agent_to_run(session, root_agent) - - # Pre-processing for live streaming tools - # Inspect the tool's parameters to find if it uses LiveRequestQueue - invocation_context.active_streaming_tools = {} - # TODO(hangfei): switch to use canonical_tools. - # for shell agents, there is no tools associated with it so we should skip. - if hasattr(invocation_context.agent, 'tools'): - import inspect - - for tool in invocation_context.agent.tools: - # We use `inspect.signature()` to examine the tool's underlying function (`tool.func`). - # This approach is deliberately chosen over `typing.get_type_hints()` for robustness. - # - # The Problem with `get_type_hints()`: - # `get_type_hints()` attempts to resolve forward-referenced (string-based) type - # annotations. This resolution can easily fail with a `NameError` (e.g., "Union not found") - # if the type isn't available in the scope where `get_type_hints()` is called. - # This is a common and brittle issue in framework code that inspects functions - # defined in separate user modules. - # - # Why `inspect.signature()` is Better Here: - # `inspect.signature()` does NOT resolve the annotations; it retrieves the raw - # annotation object as it was defined on the function. This allows us to - # perform a direct and reliable identity check (`param.annotation is LiveRequestQueue`) - # without risking a `NameError`. - callable_to_inspect = tool.func if hasattr(tool, 'func') else tool - # Ensure the target is actually callable before inspecting to avoid errors. - if not callable(callable_to_inspect): - continue - for param in inspect.signature(callable_to_inspect).parameters.values(): - if param.annotation is LiveRequestQueue: - if not invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools = {} - active_streaming_tool = ActiveStreamingTool( - stream=LiveRequestQueue() + if not user_message: + raise ValueError( + f"No user message available for resuming invocation: {invocation_id}" ) - invocation_context.active_streaming_tools[tool.__name__] = ( - active_streaming_tool + # Step 2: Create invocation context. + invocation_context = self._new_invocation_context( + session, + new_message=user_message, + run_config=run_config, + invocation_id=invocation_id, + ) + # Step 3: Maybe handle new message. + if new_message: + await self._handle_new_message( + session=session, + new_message=user_message, + invocation_context=invocation_context, + run_config=run_config, + state_delta=state_delta, ) - - async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async with Aclosing(ctx.agent.run_live(ctx)) as agen: - async for event in agen: - yield event - - async with Aclosing( - self._exec_with_plugin( - invocation_context=invocation_context, + # Step 4: Populate agent states for the current invocation. + invocation_context.populate_invocation_agent_states() + # Step 5: Set agent to run for the invocation. + # + # If the root agent is not found in end_of_agents, it means the invocation + # started from a sub-agent and paused on a sub-agent. + # We should find the appropriate agent to run to continue the invocation. + if self.agent.name not in invocation_context.end_of_agents: + invocation_context.agent = self._find_agent_to_run(session, self.agent) + return invocation_context + + def _find_user_message_for_invocation( + self, events: list[Event], invocation_id: str + ) -> Optional[types.Content]: + """Finds the user message that started a specific invocation.""" + for event in events: + if ( + event.invocation_id == invocation_id + and event.author == "user" + and event.content + and event.content.parts + and event.content.parts[0].text + ): + return event.content + return None + + def _new_invocation_context( + self, + session: Session, + *, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + live_request_queue: Optional[LiveRequestQueue] = None, + run_config: Optional[RunConfig] = None, + ) -> InvocationContext: + """Creates a new invocation context. + + Args: + session: The session for the context. + invocation_id: The invocation id for the context. + new_message: The new message for the context. + live_request_queue: The live request queue for the context. + run_config: The run config for the context. + + Returns: + The new invocation context. + """ + run_config = run_config or RunConfig() + invocation_id = invocation_id or new_invocation_context_id() + + if run_config.support_cfc and isinstance(self.agent, LlmAgent): + model_name = self.agent.canonical_model.model + if not model_name.startswith("gemini-2"): + raise ValueError( + f"CFC is not supported for model: {model_name} in agent:" + f" {self.agent.name}" + ) + if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): + self.agent.code_executor = BuiltInCodeExecutor() + + return InvocationContext( + artifact_service=self.artifact_service, + session_service=self.session_service, + memory_service=self.memory_service, + credential_service=self.credential_service, + plugin_manager=self.plugin_manager, + context_cache_config=self.context_cache_config, + invocation_id=invocation_id, + agent=self.agent, session=session, - execute_fn=execute, - is_live_call=True, + user_content=new_message, + live_request_queue=live_request_queue, + run_config=run_config, + resumability_config=self.resumability_config, ) - ) as agen: - async for event in agen: - yield event - - def _find_agent_to_run( - self, session: Session, root_agent: BaseAgent - ) -> BaseAgent: - """Finds the agent to run to continue the session. - - A qualified agent must be either of: - - - The agent that returned a function call and the last user message is a - function response to this function call. - - The root agent. - - An LlmAgent who replied last and is capable to transfer to any other agent - in the agent hierarchy. - - Args: - session: The session to find the agent for. - root_agent: The root agent of the runner. - - Returns: - The agent to run. (the active agent that should reply to the latest user - message) - """ - # If the last event is a function response, should send this response to - # the agent that returned the corresponding function call regardless the - # type of the agent. e.g. a remote a2a agent may surface a credential - # request as a special long running function tool call. - event = find_matching_function_call(session.events) - if event and event.author: - return root_agent.find_agent(event.author) - - def _event_filter(event: Event) -> bool: - """Filters out user-authored events and agent state change events.""" - if event.author == 'user': - return False - if event.actions.agent_state is not None or event.actions.end_of_agent: - return False - return True - - for event in filter(_event_filter, reversed(session.events)): - if event.author == root_agent.name: - # Found root agent. - return root_agent - if not (agent := root_agent.find_sub_agent(event.author)): - # Agent not found, continue looking. - logger.warning( - 'Event from an unknown agent: %s, event id: %s', - event.author, - event.id, - ) - continue - if self._is_transferable_across_agent_tree(agent): - return agent - # Falls back to root agent if no suitable agents are found in the session. - return root_agent - def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool: - """Whether the agent to run can transfer to any other agent in the agent tree. + def _new_invocation_context_for_live( + self, + session: Session, + *, + live_request_queue: Optional[LiveRequestQueue] = None, + run_config: Optional[RunConfig] = None, + ) -> InvocationContext: + """Creates a new invocation context for live multi-agent.""" + run_config = run_config or RunConfig() + + # For live multi-agent, we need model's text transcription as context for + # next agent. + if self.agent.sub_agents and live_request_queue: + if not run_config.response_modalities: + # default + run_config.response_modalities = ["AUDIO"] + if not run_config.output_audio_transcription: + run_config.output_audio_transcription = ( + types.AudioTranscriptionConfig() + ) + elif "TEXT" not in run_config.response_modalities: + if not run_config.output_audio_transcription: + run_config.output_audio_transcription = ( + types.AudioTranscriptionConfig() + ) + if not run_config.input_audio_transcription: + # need this input transcription for agent transferring in live mode. + run_config.input_audio_transcription = types.AudioTranscriptionConfig() + return self._new_invocation_context( + session, + live_request_queue=live_request_queue, + run_config=run_config, + ) - This typically means all agent_to_run's ancestor can transfer to their - parent_agent all the way to the root_agent. + async def _handle_new_message( + self, + *, + session: Session, + new_message: types.Content, + invocation_context: InvocationContext, + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> None: + """Handles a new message by running callbacks and appending to session. + + Args: + session: The session of the new message. + new_message: The new message to process and append to the session. + invocation_context: The invocation context to use for the message + handling. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + """ + modified_user_message = ( + await invocation_context.plugin_manager.run_on_user_message_callback( + invocation_context=invocation_context, user_message=new_message + ) + ) + if modified_user_message is not None: + new_message = modified_user_message + invocation_context.user_content = new_message - Args: - agent_to_run: The agent to check for transferability. + if new_message: + await self._append_new_message_to_session( + session=session, + new_message=new_message, + invocation_context=invocation_context, + save_input_blobs_as_artifacts=run_config.save_input_blobs_as_artifacts, + state_delta=state_delta, + ) - Returns: - True if the agent can transfer, False otherwise. - """ - agent = agent_to_run - while agent: - if not isinstance(agent, LlmAgent): - # Only LLM-based Agent can provide agent transfer capability. - return False - if agent.disallow_transfer_to_parent: - return False - agent = agent.parent_agent - return True - - async def run_debug( - self, - user_messages: str | list[str], - *, - user_id: str = 'debug_user_id', - session_id: str = 'debug_session_id', - run_config: RunConfig | None = None, - quiet: bool = False, - verbose: bool = False, - ) -> list[Event]: - """Debug helper for quick agent experimentation and testing. - - This convenience method is designed for developers getting started with ADK - who want to quickly test agents without dealing with session management, - content formatting, or event streaming. It automatically handles common - boilerplate while hiding complexity. - - IMPORTANT: This is for debugging and experimentation only. For production - use, please use the standard run_async() method which provides full control - over session management, event streaming, and error handling. - - Args: - user_messages: Message(s) to send to the agent. Can be: - Single string: - "What is 2+2?" - List of strings: ["Hello!", "What's my name?"] - user_id: User identifier. Defaults to "debug_user_id". - session_id: Session identifier for conversation persistence. Defaults to - "debug_session_id". Reuse the same ID to continue a conversation. - run_config: Optional configuration for the agent execution. - quiet: If True, suppresses console output. Defaults to False (output - shown). - verbose: If True, shows detailed tool calls and responses. Defaults to - False for cleaner output showing only final agent responses. - - Returns: - list[Event]: All events from all messages. - - Raises: - ValueError: If session creation/retrieval fails. - - Examples: - Quick debugging: - >>> runner = InMemoryRunner(agent=my_agent) - >>> await runner.run_debug("What is 2+2?") - - Multiple queries in conversation: - >>> await runner.run_debug(["Hello!", "What's my name?"]) - - Continue a debug session: - >>> await runner.run_debug("What did we discuss?") # Continues default - session - - Separate debug sessions: - >>> await runner.run_debug("Hi", user_id="alice", session_id="debug1") - >>> await runner.run_debug("Hi", user_id="bob", session_id="debug2") - - Capture events for inspection: - >>> events = await runner.run_debug("Analyze this") - >>> for event in events: - ... inspect_event(event) - - Note: - For production applications requiring: - - Custom session/memory services (Spanner, Cloud SQL, etc.) - - Fine-grained event processing and streaming - - Error recovery and resumability - - Performance optimization - Please use run_async() with proper configuration. - """ - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - session = await self.session_service.create_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not quiet: - print(f'\n ### Created new session: {session_id}') - elif not quiet: - print(f'\n ### Continue session: {session_id}') - - collected_events: list[Event] = [] - - if isinstance(user_messages, str): - user_messages = [user_messages] - - for message in user_messages: - if not quiet: - print(f'\nUser > {message}') - - async for event in self.run_async( - user_id=user_id, - session_id=session.id, - new_message=types.UserContent(parts=[types.Part(text=message)]), - run_config=run_config, - ): - if not quiet: - print_event(event, verbose=verbose) - - collected_events.append(event) - - return collected_events - - async def _setup_context_for_new_invocation( - self, - *, - session: Session, - new_message: types.Content, - run_config: RunConfig, - state_delta: Optional[dict[str, Any]], - ) -> InvocationContext: - """Sets up the context for a new invocation. - - Args: - session: The session to set up the invocation context for. - new_message: The new message to process and append to the session. - run_config: The run config of the agent. - state_delta: Optional state changes to apply to the session. - - Returns: - The invocation context for the new invocation. - """ - # Step 1: Create invocation context in memory. - invocation_context = self._new_invocation_context( - session, - new_message=new_message, - run_config=run_config, - ) - # Step 2: Handle new message, by running callbacks and appending to - # session. - await self._handle_new_message( - session=session, - new_message=new_message, - invocation_context=invocation_context, - run_config=run_config, - state_delta=state_delta, - ) - # Step 3: Set agent to run for the invocation. - invocation_context.agent = self._find_agent_to_run(session, self.agent) - return invocation_context - - async def _setup_context_for_resumed_invocation( - self, - *, - session: Session, - new_message: Optional[types.Content], - invocation_id: Optional[str], - run_config: RunConfig, - state_delta: Optional[dict[str, Any]], - ) -> InvocationContext: - """Sets up the context for a resumed invocation. - - Args: - session: The session to set up the invocation context for. - new_message: The new message to process and append to the session. - invocation_id: The invocation id to resume. - run_config: The run config of the agent. - state_delta: Optional state changes to apply to the session. - - Returns: - The invocation context for the resumed invocation. - - Raises: - ValueError: If the session has no events to resume; If no user message is - available for resuming the invocation; Or if the app is not resumable. - """ - if not session.events: - raise ValueError(f'Session {session.id} has no events to resume.') + def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: + toolsets = set() + if isinstance(agent, LlmAgent): + for tool_union in agent.tools: + if isinstance(tool_union, BaseToolset): + toolsets.add(tool_union) + for sub_agent in agent.sub_agents: + toolsets.update(self._collect_toolset(sub_agent)) + return toolsets + + async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): + """Clean up toolsets with proper task context management.""" + if not toolsets_to_close: + return - # Step 1: Maybe retrieve a previous user message for the invocation. - user_message = new_message or self._find_user_message_for_invocation( - session.events, invocation_id - ) - if not user_message: - raise ValueError( - f'No user message available for resuming invocation: {invocation_id}' - ) - # Step 2: Create invocation context. - invocation_context = self._new_invocation_context( - session, - new_message=user_message, - run_config=run_config, - invocation_id=invocation_id, - ) - # Step 3: Maybe handle new message. - if new_message: - await self._handle_new_message( - session=session, - new_message=user_message, - invocation_context=invocation_context, - run_config=run_config, - state_delta=state_delta, - ) - # Step 4: Populate agent states for the current invocation. - invocation_context.populate_invocation_agent_states() - # Step 5: Set agent to run for the invocation. - # - # If the root agent is not found in end_of_agents, it means the invocation - # started from a sub-agent and paused on a sub-agent. - # We should find the appropriate agent to run to continue the invocation. - if self.agent.name not in invocation_context.end_of_agents: - invocation_context.agent = self._find_agent_to_run(session, self.agent) - return invocation_context - - def _find_user_message_for_invocation( - self, events: list[Event], invocation_id: str - ) -> Optional[types.Content]: - """Finds the user message that started a specific invocation.""" - for event in events: - if ( - event.invocation_id == invocation_id - and event.author == 'user' - and event.content - and event.content.parts - and event.content.parts[0].text - ): - return event.content - return None - - def _new_invocation_context( - self, - session: Session, - *, - invocation_id: Optional[str] = None, - new_message: Optional[types.Content] = None, - live_request_queue: Optional[LiveRequestQueue] = None, - run_config: Optional[RunConfig] = None, - ) -> InvocationContext: - """Creates a new invocation context. - - Args: - session: The session for the context. - invocation_id: The invocation id for the context. - new_message: The new message for the context. - live_request_queue: The live request queue for the context. - run_config: The run config for the context. - - Returns: - The new invocation context. - """ - run_config = run_config or RunConfig() - invocation_id = invocation_id or new_invocation_context_id() - - if run_config.support_cfc and isinstance(self.agent, LlmAgent): - model_name = self.agent.canonical_model.model - if not model_name.startswith('gemini-2'): - raise ValueError( - f'CFC is not supported for model: {model_name} in agent:' - f' {self.agent.name}' - ) - if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): - self.agent.code_executor = BuiltInCodeExecutor() - - return InvocationContext( - artifact_service=self.artifact_service, - session_service=self.session_service, - memory_service=self.memory_service, - credential_service=self.credential_service, - plugin_manager=self.plugin_manager, - context_cache_config=self.context_cache_config, - invocation_id=invocation_id, - agent=self.agent, - session=session, - user_content=new_message, - live_request_queue=live_request_queue, - run_config=run_config, - resumability_config=self.resumability_config, - ) + # This maintains the same task context throughout cleanup + for toolset in toolsets_to_close: + try: + logger.info("Closing toolset: %s", type(toolset).__name__) + # Use asyncio.wait_for to add timeout protection + await asyncio.wait_for(toolset.close(), timeout=10.0) + logger.info("Successfully closed toolset: %s", type(toolset).__name__) + except asyncio.TimeoutError: + logger.warning("Toolset %s cleanup timed out", type(toolset).__name__) + except asyncio.CancelledError as e: + # Handle cancel scope issues in Python 3.10 and 3.11 with anyio + # + # Root cause: MCP library uses anyio.CancelScope() in RequestResponder.__enter__() + # and __exit__() methods. When asyncio.wait_for() creates a new task for cleanup, + # the cancel scope is entered in one task context but exited in another. + # + # Python 3.12+ fixes: Enhanced task context management (Task.get_context()), + # improved context propagation across task boundaries, and better cancellation + # handling prevent the cross-task cancel scope violation. + logger.warning( + "Toolset %s cleanup cancelled: %s", type(toolset).__name__, e + ) + except Exception as e: + logger.error("Error closing toolset %s: %s", type(toolset).__name__, e) - def _new_invocation_context_for_live( - self, - session: Session, - *, - live_request_queue: Optional[LiveRequestQueue] = None, - run_config: Optional[RunConfig] = None, - ) -> InvocationContext: - """Creates a new invocation context for live multi-agent.""" - run_config = run_config or RunConfig() - - # For live multi-agent, we need model's text transcription as context for - # next agent. - if self.agent.sub_agents and live_request_queue: - if not run_config.response_modalities: - # default - run_config.response_modalities = ['AUDIO'] - if not run_config.output_audio_transcription: - run_config.output_audio_transcription = ( - types.AudioTranscriptionConfig() - ) - elif 'TEXT' not in run_config.response_modalities: - if not run_config.output_audio_transcription: - run_config.output_audio_transcription = ( - types.AudioTranscriptionConfig() - ) - if not run_config.input_audio_transcription: - # need this input transcription for agent transferring in live mode. - run_config.input_audio_transcription = types.AudioTranscriptionConfig() - return self._new_invocation_context( - session, - live_request_queue=live_request_queue, - run_config=run_config, - ) + async def close(self): + """Closes the runner.""" + logger.info("Closing runner...") + # Close Toolsets + await self._cleanup_toolsets(self._collect_toolset(self.agent)) - async def _handle_new_message( - self, - *, - session: Session, - new_message: types.Content, - invocation_context: InvocationContext, - run_config: RunConfig, - state_delta: Optional[dict[str, Any]], - ) -> None: - """Handles a new message by running callbacks and appending to session. - - Args: - session: The session of the new message. - new_message: The new message to process and append to the session. - invocation_context: The invocation context to use for the message - handling. - run_config: The run config of the agent. - state_delta: Optional state changes to apply to the session. - """ - modified_user_message = ( - await invocation_context.plugin_manager.run_on_user_message_callback( - invocation_context=invocation_context, user_message=new_message - ) - ) - if modified_user_message is not None: - new_message = modified_user_message - invocation_context.user_content = new_message - - if new_message: - await self._append_new_message_to_session( - session=session, - new_message=new_message, - invocation_context=invocation_context, - save_input_blobs_as_artifacts=run_config.save_input_blobs_as_artifacts, - state_delta=state_delta, - ) - - def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: - toolsets = set() - if isinstance(agent, LlmAgent): - for tool_union in agent.tools: - if isinstance(tool_union, BaseToolset): - toolsets.add(tool_union) - for sub_agent in agent.sub_agents: - toolsets.update(self._collect_toolset(sub_agent)) - return toolsets - - async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): - """Clean up toolsets with proper task context management.""" - if not toolsets_to_close: - return - - # This maintains the same task context throughout cleanup - for toolset in toolsets_to_close: - try: - logger.info('Closing toolset: %s', type(toolset).__name__) - # Use asyncio.wait_for to add timeout protection - await asyncio.wait_for(toolset.close(), timeout=10.0) - logger.info('Successfully closed toolset: %s', type(toolset).__name__) - except asyncio.TimeoutError: - logger.warning('Toolset %s cleanup timed out', type(toolset).__name__) - except asyncio.CancelledError as e: - # Handle cancel scope issues in Python 3.10 and 3.11 with anyio - # - # Root cause: MCP library uses anyio.CancelScope() in RequestResponder.__enter__() - # and __exit__() methods. When asyncio.wait_for() creates a new task for cleanup, - # the cancel scope is entered in one task context but exited in another. - # - # Python 3.12+ fixes: Enhanced task context management (Task.get_context()), - # improved context propagation across task boundaries, and better cancellation - # handling prevent the cross-task cancel scope violation. - logger.warning( - 'Toolset %s cleanup cancelled: %s', type(toolset).__name__, e - ) - except Exception as e: - logger.error('Error closing toolset %s: %s', type(toolset).__name__, e) + # Close Plugins + if self.plugin_manager: + await self.plugin_manager.close() - async def close(self): - """Closes the runner.""" - logger.info('Closing runner...') - # Close Toolsets - await self._cleanup_toolsets(self._collect_toolset(self.agent)) + logger.info("Runner closed.") - # Close Plugins - if self.plugin_manager: - await self.plugin_manager.close() + if sys.version_info < (3, 11): + Self = "Runner" # pylint: disable=invalid-name + else: + from typing import Self # pylint: disable=g-import-not-at-top - logger.info('Runner closed.') + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return self - if sys.version_info < (3, 11): - Self = 'Runner' # pylint: disable=invalid-name - else: - from typing import Self # pylint: disable=g-import-not-at-top + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + return False # Don't suppress exceptions from the async with block - async def __aenter__(self) -> Self: - """Async context manager entry.""" - return self - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit.""" - await self.close() - return False # Don't suppress exceptions from the async with block +class InMemoryRunner(Runner): + """An in-memory Runner for testing and development. + This runner uses in-memory implementations for artifact, session, and memory + services, providing a lightweight and self-contained environment for agent + execution. -class InMemoryRunner(Runner): - """An in-memory Runner for testing and development. - - This runner uses in-memory implementations for artifact, session, and memory - services, providing a lightweight and self-contained environment for agent - execution. - - Attributes: - agent: The root agent to run. - app_name: The application name of the runner. Defaults to - 'InMemoryRunner'. - """ - - def __init__( - self, - agent: Optional[BaseAgent] = None, - *, - app_name: Optional[str] = None, - plugins: Optional[list[BasePlugin]] = None, - app: Optional[App] = None, - plugin_close_timeout: float = 5.0, - ): - """Initializes the InMemoryRunner. - - Args: + Attributes: agent: The root agent to run. app_name: The application name of the runner. Defaults to 'InMemoryRunner'. - plugins: Optional list of plugins for the runner. - app: Optional App instance. - plugin_close_timeout: The timeout in seconds for plugin close methods. """ - if app is None and app_name is None: - app_name = 'InMemoryRunner' - super().__init__( - app_name=app_name, - agent=agent, - artifact_service=InMemoryArtifactService(), - plugins=plugins, - app=app, - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), - plugin_close_timeout=plugin_close_timeout, - ) + + def __init__( + self, + agent: Optional[BaseAgent] = None, + *, + app_name: Optional[str] = None, + plugins: Optional[list[BasePlugin]] = None, + app: Optional[App] = None, + plugin_close_timeout: float = 5.0, + ): + """Initializes the InMemoryRunner. + + Args: + agent: The root agent to run. + app_name: The application name of the runner. Defaults to + 'InMemoryRunner'. + plugins: Optional list of plugins for the runner. + app: Optional App instance. + plugin_close_timeout: The timeout in seconds for plugin close methods. + """ + if app is None and app_name is None: + app_name = "InMemoryRunner" + super().__init__( + app_name=app_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + plugins=plugins, + app=app, + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + plugin_close_timeout=plugin_close_timeout, + ) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index f03cdc8010..3b57c4328b 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -37,83 +37,83 @@ # By default some ADK spans include attributes with potential PII data. # This env, when set to false, allows to disable populating those attributes. -ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = 'ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS' +ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = "ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS" # TODO: Replace with constant from opentelemetry.semconv when it reaches version 1.37 in g3. -GEN_AI_AGENT_DESCRIPTION = 'gen_ai.agent.description' -GEN_AI_AGENT_NAME = 'gen_ai.agent.name' -GEN_AI_CONVERSATION_ID = 'gen_ai.conversation.id' -GEN_AI_OPERATION_NAME = 'gen_ai.operation.name' -GEN_AI_TOOL_CALL_ID = 'gen_ai.tool.call.id' -GEN_AI_TOOL_DESCRIPTION = 'gen_ai.tool.description' -GEN_AI_TOOL_NAME = 'gen_ai.tool.name' -GEN_AI_TOOL_TYPE = 'gen_ai.tool.type' +GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description" +GEN_AI_AGENT_NAME = "gen_ai.agent.name" +GEN_AI_CONVERSATION_ID = "gen_ai.conversation.id" +GEN_AI_OPERATION_NAME = "gen_ai.operation.name" +GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id" +GEN_AI_TOOL_DESCRIPTION = "gen_ai.tool.description" +GEN_AI_TOOL_NAME = "gen_ai.tool.name" +GEN_AI_TOOL_TYPE = "gen_ai.tool.type" # Needed to avoid circular imports if TYPE_CHECKING: - from ..agents.base_agent import BaseAgent - from ..agents.invocation_context import InvocationContext - from ..models.llm_request import LlmRequest - from ..models.llm_response import LlmResponse - from ..tools.base_tool import BaseTool + from ..agents.base_agent import BaseAgent + from ..agents.invocation_context import InvocationContext + from ..models.llm_request import LlmRequest + from ..models.llm_response import LlmResponse + from ..tools.base_tool import BaseTool tracer = trace.get_tracer( - instrumenting_module_name='gcp.vertex.agent', + instrumenting_module_name="gcp.vertex.agent", instrumenting_library_version=version.__version__, # TODO: Replace with constant from opentelemetry.semconv when it reaches version 1.37 in g3. - schema_url='https://opentelemetry.io/schemas/1.37.0', + schema_url="https://opentelemetry.io/schemas/1.37.0", ) def _safe_json_serialize(obj) -> str: - """Convert any Python object to a JSON-serializable type or string. + """Convert any Python object to a JSON-serializable type or string. - Args: - obj: The object to serialize. + Args: + obj: The object to serialize. - Returns: - The JSON-serialized object string or if the object cannot be serialized. - """ + Returns: + The JSON-serialized object string or if the object cannot be serialized. + """ - try: - # Try direct JSON serialization first - return json.dumps( - obj, ensure_ascii=False, default=lambda o: '' - ) - except (TypeError, OverflowError): - return '' + try: + # Try direct JSON serialization first + return json.dumps( + obj, ensure_ascii=False, default=lambda o: "" + ) + except (TypeError, OverflowError): + return "" def trace_agent_invocation( span: trace.Span, agent: BaseAgent, ctx: InvocationContext ) -> None: - """Sets span attributes immediately available on agent invocation according to OTEL semconv version 1.37. + """Sets span attributes immediately available on agent invocation according to OTEL semconv version 1.37. - Args: - span: Span on which attributes are set. - agent: Agent from which attributes are gathered. - ctx: InvocationContext from which attributes are gathered. + Args: + span: Span on which attributes are set. + agent: Agent from which attributes are gathered. + ctx: InvocationContext from which attributes are gathered. - Inference related fields are not set, due to their planned removal from invoke_agent span: - https://github.com/open-telemetry/semantic-conventions/issues/2632 + Inference related fields are not set, due to their planned removal from invoke_agent span: + https://github.com/open-telemetry/semantic-conventions/issues/2632 - `gen_ai.agent.id` is not set because currently it's unclear what attributes this field should have, specifically: - - In which scope should it be unique (globally, given project, given agentic flow, given deployment). - - Should it be unchanging between deployments, and how this should this be achieved. + `gen_ai.agent.id` is not set because currently it's unclear what attributes this field should have, specifically: + - In which scope should it be unique (globally, given project, given agentic flow, given deployment). + - Should it be unchanging between deployments, and how this should this be achieved. - `gen_ai.data_source.id` is not set because it's not available. - Closest type which could contain this information is types.GroundingMetadata, which does not have an ID. + `gen_ai.data_source.id` is not set because it's not available. + Closest type which could contain this information is types.GroundingMetadata, which does not have an ID. - `server.*` attributes are not set pending confirmation from aabmass. - """ + `server.*` attributes are not set pending confirmation from aabmass. + """ - # Required - span.set_attribute(GEN_AI_OPERATION_NAME, 'invoke_agent') + # Required + span.set_attribute(GEN_AI_OPERATION_NAME, "invoke_agent") - # Conditionally Required - span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) + # Conditionally Required + span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) - span.set_attribute(GEN_AI_AGENT_NAME, agent.name) - span.set_attribute(GEN_AI_CONVERSATION_ID, ctx.session.id) + span.set_attribute(GEN_AI_AGENT_NAME, agent.name) + span.set_attribute(GEN_AI_CONVERSATION_ID, ctx.session.id) def trace_tool_call( @@ -121,112 +121,112 @@ def trace_tool_call( args: dict[str, Any], function_response_event: Optional[Event], ): - """Traces tool call. + """Traces tool call. - Args: - tool: The tool that was called. - args: The arguments to the tool call. - function_response_event: The event with the function response details. - """ - span = trace.get_current_span() + Args: + tool: The tool that was called. + args: The arguments to the tool call. + function_response_event: The event with the function response details. + """ + span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) - span.set_attribute(GEN_AI_TOOL_NAME, tool.name) + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) + span.set_attribute(GEN_AI_TOOL_NAME, tool.name) - # e.g. FunctionTool - span.set_attribute(GEN_AI_TOOL_TYPE, tool.__class__.__name__) + # e.g. FunctionTool + span.set_attribute(GEN_AI_TOOL_TYPE, tool.__class__.__name__) - # Setting empty llm request and response (as UI expect these) while not - # applicable for tool_response. - span.set_attribute('gcp.vertex.agent.llm_request', '{}') - span.set_attribute('gcp.vertex.agent.llm_response', '{}') + # Setting empty llm request and response (as UI expect these) while not + # applicable for tool_response. + span.set_attribute("gcp.vertex.agent.llm_request", "{}") + span.set_attribute("gcp.vertex.agent.llm_response", "{}") - if _should_add_request_response_to_spans(): - span.set_attribute( - 'gcp.vertex.agent.tool_call_args', - _safe_json_serialize(args), - ) - else: - span.set_attribute('gcp.vertex.agent.tool_call_args', {}) - - # Tracing tool response - tool_call_id = '' - tool_response = '' - if ( - function_response_event is not None - and function_response_event.content is not None - and function_response_event.content.parts - ): - response_parts = function_response_event.content.parts - function_response = response_parts[0].function_response - if function_response is not None: - if function_response.id is not None: - tool_call_id = function_response.id - if function_response.response is not None: - tool_response = function_response.response - - span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) - - if not isinstance(tool_response, dict): - tool_response = {'result': tool_response} - if function_response_event is not None: - span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id) - if _should_add_request_response_to_spans(): - span.set_attribute( - 'gcp.vertex.agent.tool_response', - _safe_json_serialize(tool_response), - ) - else: - span.set_attribute('gcp.vertex.agent.tool_response', {}) + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.tool_call_args", + _safe_json_serialize(args), + ) + else: + span.set_attribute("gcp.vertex.agent.tool_call_args", {}) + + # Tracing tool response + tool_call_id = "" + tool_response = "" + if ( + function_response_event is not None + and function_response_event.content is not None + and function_response_event.content.parts + ): + response_parts = function_response_event.content.parts + function_response = response_parts[0].function_response + if function_response is not None: + if function_response.id is not None: + tool_call_id = function_response.id + if function_response.response is not None: + tool_response = function_response.response + + span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) + + if not isinstance(tool_response, dict): + tool_response = {"result": tool_response} + if function_response_event is not None: + span.set_attribute("gcp.vertex.agent.event_id", function_response_event.id) + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.tool_response", + _safe_json_serialize(tool_response), + ) + else: + span.set_attribute("gcp.vertex.agent.tool_response", {}) def trace_merged_tool_calls( response_event_id: str, function_response_event: Event, ): - """Traces merged tool call events. + """Traces merged tool call events. - Calling this function is not needed for telemetry purposes. This is provided - for preventing /debug/trace requests (typically sent by web UI). + Calling this function is not needed for telemetry purposes. This is provided + for preventing /debug/trace requests (typically sent by web UI). - Args: - response_event_id: The ID of the response event. - function_response_event: The merged response event. - """ + Args: + response_event_id: The ID of the response event. + function_response_event: The merged response event. + """ - span = trace.get_current_span() + span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') - span.set_attribute(GEN_AI_TOOL_NAME, '(merged tools)') - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, '(merged tools)') - span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") + span.set_attribute(GEN_AI_TOOL_NAME, "(merged tools)") + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)") + span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) - # TODO(b/441461932): See if these are still necessary - span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A') - span.set_attribute('gcp.vertex.agent.event_id', response_event_id) - try: - function_response_event_json = function_response_event.model_dumps_json( - exclude_none=True - ) - except Exception: # pylint: disable=broad-exception-caught - function_response_event_json = '' + # TODO(b/441461932): See if these are still necessary + span.set_attribute("gcp.vertex.agent.tool_call_args", "N/A") + span.set_attribute("gcp.vertex.agent.event_id", response_event_id) + try: + function_response_event_json = function_response_event.model_dumps_json( + exclude_none=True + ) + except Exception: # pylint: disable=broad-exception-caught + function_response_event_json = "" - if _should_add_request_response_to_spans(): + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.tool_response", + function_response_event_json, + ) + else: + span.set_attribute("gcp.vertex.agent.tool_response", {}) + # Setting empty llm request and response (as UI expect these) while not + # applicable for tool_response. + span.set_attribute("gcp.vertex.agent.llm_request", "{}") span.set_attribute( - 'gcp.vertex.agent.tool_response', - function_response_event_json, + "gcp.vertex.agent.llm_response", + "{}", ) - else: - span.set_attribute('gcp.vertex.agent.tool_response', {}) - # Setting empty llm request and response (as UI expect these) while not - # applicable for tool_response. - span.set_attribute('gcp.vertex.agent.llm_request', '{}') - span.set_attribute( - 'gcp.vertex.agent.llm_response', - '{}', - ) def trace_call_llm( @@ -235,82 +235,80 @@ def trace_call_llm( llm_request: LlmRequest, llm_response: LlmResponse, ): - """Traces a call to the LLM. - - This function records details about the LLM request and response as - attributes on the current OpenTelemetry span. - - Args: - invocation_context: The invocation context for the current agent run. - event_id: The ID of the event. - llm_request: The LLM request object. - llm_response: The LLM response object. - """ - span = trace.get_current_span() - # Special standard Open Telemetry GenaI attributes that indicate - # that this is a span related to a Generative AI system. - span.set_attribute('gen_ai.system', 'gcp.vertex.agent') - span.set_attribute('gen_ai.request.model', llm_request.model) - span.set_attribute( - 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id - ) - span.set_attribute( - 'gcp.vertex.agent.session_id', invocation_context.session.id - ) - span.set_attribute('gcp.vertex.agent.event_id', event_id) - # Consider removing once GenAI SDK provides a way to record this info. - if _should_add_request_response_to_spans(): - span.set_attribute( - 'gcp.vertex.agent.llm_request', - _safe_json_serialize(_build_llm_request_for_trace(llm_request)), - ) - else: - span.set_attribute('gcp.vertex.agent.llm_request', {}) - # Consider removing once GenAI SDK provides a way to record this info. - if llm_request.config: - if llm_request.config.top_p: - span.set_attribute( - 'gen_ai.request.top_p', - llm_request.config.top_p, - ) - if llm_request.config.max_output_tokens: - span.set_attribute( - 'gen_ai.request.max_tokens', - llm_request.config.max_output_tokens, - ) - - try: - llm_response_json = llm_response.model_dump_json(exclude_none=True) - except Exception: # pylint: disable=broad-exception-caught - llm_response_json = '' - - if _should_add_request_response_to_spans(): + """Traces a call to the LLM. + + This function records details about the LLM request and response as + attributes on the current OpenTelemetry span. + + Args: + invocation_context: The invocation context for the current agent run. + event_id: The ID of the event. + llm_request: The LLM request object. + llm_response: The LLM response object. + """ + span = trace.get_current_span() + # Special standard Open Telemetry GenaI attributes that indicate + # that this is a span related to a Generative AI system. + span.set_attribute("gen_ai.system", "gcp.vertex.agent") + span.set_attribute("gen_ai.request.model", llm_request.model) span.set_attribute( - 'gcp.vertex.agent.llm_response', - llm_response_json, + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id ) - else: - span.set_attribute('gcp.vertex.agent.llm_response', {}) + span.set_attribute("gcp.vertex.agent.session_id", invocation_context.session.id) + span.set_attribute("gcp.vertex.agent.event_id", event_id) + # Consider removing once GenAI SDK provides a way to record this info. + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.llm_request", + _safe_json_serialize(_build_llm_request_for_trace(llm_request)), + ) + else: + span.set_attribute("gcp.vertex.agent.llm_request", {}) + # Consider removing once GenAI SDK provides a way to record this info. + if llm_request.config: + if llm_request.config.top_p: + span.set_attribute( + "gen_ai.request.top_p", + llm_request.config.top_p, + ) + if llm_request.config.max_output_tokens: + span.set_attribute( + "gen_ai.request.max_tokens", + llm_request.config.max_output_tokens, + ) - if llm_response.usage_metadata is not None: - span.set_attribute( - 'gen_ai.usage.input_tokens', - llm_response.usage_metadata.prompt_token_count, - ) - if llm_response.usage_metadata.candidates_token_count is not None: - span.set_attribute( - 'gen_ai.usage.output_tokens', - llm_response.usage_metadata.candidates_token_count, - ) - if llm_response.finish_reason: try: - finish_reason_str = llm_response.finish_reason.value.lower() - except AttributeError: - finish_reason_str = str(llm_response.finish_reason).lower() - span.set_attribute( - 'gen_ai.response.finish_reasons', - [finish_reason_str], - ) + llm_response_json = llm_response.model_dump_json(exclude_none=True) + except Exception: # pylint: disable=broad-exception-caught + llm_response_json = "" + + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.llm_response", + llm_response_json, + ) + else: + span.set_attribute("gcp.vertex.agent.llm_response", {}) + + if llm_response.usage_metadata is not None: + span.set_attribute( + "gen_ai.usage.input_tokens", + llm_response.usage_metadata.prompt_token_count, + ) + if llm_response.usage_metadata.candidates_token_count is not None: + span.set_attribute( + "gen_ai.usage.output_tokens", + llm_response.usage_metadata.candidates_token_count, + ) + if llm_response.finish_reason: + try: + finish_reason_str = llm_response.finish_reason.value.lower() + except AttributeError: + finish_reason_str = str(llm_response.finish_reason).lower() + span.set_attribute( + "gen_ai.response.finish_reasons", + [finish_reason_str], + ) def trace_send_data( @@ -318,67 +316,67 @@ def trace_send_data( event_id: str, data: list[types.Content], ): - """Traces the sending of data to the agent. - - This function records details about the data sent to the agent as - attributes on the current OpenTelemetry span. - - Args: - invocation_context: The invocation context for the current agent run. - event_id: The ID of the event. - data: A list of content objects. - """ - span = trace.get_current_span() - span.set_attribute( - 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id - ) - span.set_attribute('gcp.vertex.agent.event_id', event_id) - # Once instrumentation is added to the GenAI SDK, consider whether this - # information still needs to be recorded by the Agent Development Kit. - if _should_add_request_response_to_spans(): + """Traces the sending of data to the agent. + + This function records details about the data sent to the agent as + attributes on the current OpenTelemetry span. + + Args: + invocation_context: The invocation context for the current agent run. + event_id: The ID of the event. + data: A list of content objects. + """ + span = trace.get_current_span() span.set_attribute( - 'gcp.vertex.agent.data', - _safe_json_serialize([ - types.Content(role=content.role, parts=content.parts).model_dump( - exclude_none=True - ) - for content in data - ]), + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id ) - else: - span.set_attribute('gcp.vertex.agent.data', {}) + span.set_attribute("gcp.vertex.agent.event_id", event_id) + # Once instrumentation is added to the GenAI SDK, consider whether this + # information still needs to be recorded by the Agent Development Kit. + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.data", + _safe_json_serialize( + [ + types.Content(role=content.role, parts=content.parts).model_dump( + exclude_none=True + ) + for content in data + ] + ), + ) + else: + span.set_attribute("gcp.vertex.agent.data", {}) def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: - """Builds a dictionary representation of the LLM request for tracing. - - This function prepares a dictionary representation of the LlmRequest - object, suitable for inclusion in a trace. It excludes fields that cannot - be serialized (e.g., function pointers) and avoids sending bytes data. - - Args: - llm_request: The LlmRequest object. - - Returns: - A dictionary representation of the LLM request. - """ - # Some fields in LlmRequest are function pointers and cannot be serialized. - result = { - 'model': llm_request.model, - 'config': llm_request.config.model_dump( - exclude_none=True, exclude='response_schema' - ), - 'contents': [], - } - # We do not want to send bytes data to the trace. - for content in llm_request.contents: - parts = [part for part in content.parts if not part.inline_data] - result['contents'].append( - types.Content(role=content.role, parts=parts).model_dump( - exclude_none=True + """Builds a dictionary representation of the LLM request for tracing. + + This function prepares a dictionary representation of the LlmRequest + object, suitable for inclusion in a trace. It excludes fields that cannot + be serialized (e.g., function pointers) and avoids sending bytes data. + + Args: + llm_request: The LlmRequest object. + + Returns: + A dictionary representation of the LLM request. + """ + # Some fields in LlmRequest are function pointers and cannot be serialized. + result = { + "model": llm_request.model, + "config": llm_request.config.model_dump( + exclude_none=True, exclude="response_schema" + ), + "contents": [], + } + # We do not want to send bytes data to the trace. + for content in llm_request.contents: + parts = [part for part in content.parts if not part.inline_data] + result["contents"].append( + types.Content(role=content.role, parts=parts).model_dump(exclude_none=True) ) - ) - return result + return result # Defaults to true for now to preserve backward compatibility. @@ -386,7 +384,7 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: # a deprecation of request/response content in spans by switching the default # to false. def _should_add_request_response_to_spans() -> bool: - disabled_via_env_var = os.getenv( - ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'true' - ).lower() in ('false', '0') - return not disabled_via_env_var + disabled_via_env_var = os.getenv( + ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, "true" + ).lower() in ("false", "0") + return not disabled_via_env_var diff --git a/src/google/adk/utils/telemetry_utils.py b/src/google/adk/utils/telemetry_utils.py new file mode 100644 index 0000000000..0388fc8cfd --- /dev/null +++ b/src/google/adk/utils/telemetry_utils.py @@ -0,0 +1,63 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for telemetry. + +This module is for ADK internal use only. +Please do not rely on the implementation details. +""" + +from .env_utils import is_env_enabled +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..agents.base_agent import BaseAgent + + +def is_telemetry_enabled(agent: "BaseAgent") -> bool: + """Check if telemetry is enabled for the given agent. + + By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. + + Args: + agent: The agent to check if telemetry is enabled for. + + Returns: + False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. + + Examples: + >>> os.environ['OTEL_SDK_DISABLED'] = 'true' + >>> is_telemetry_disabled(my_agent) + False + + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 + >>> is_telemetry_disabled(my_agent) + False + + >>> my_agent.disable_telemetry = True + >>> is_telemetry_disabled(my_agent) + False + + >>> os.environ['OTEL_SDK_DISABLED'] = 0 + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' + >>> my_agent.disable_telemetry = False + >>> is_telemetry_disabled(my_agent) + True + """ + telemetry_disabled = ( + is_env_enabled("OTEL_SDK_DISABLED") + or is_env_enabled("ADK_TELEMETRY_DISABLED") + or getattr(agent, "disable_telemetry", False) + ) + return not telemetry_disabled From 8aa526cda863c45b7b27da6d7eb3ca896f6423f1 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sat, 27 Dec 2025 23:30:54 +0000 Subject: [PATCH 02/24] test(telemetry): cover disable flags Add unit tests for agent and Gemini telemetry disable flags including env var gating. Add integration test asserting spans are emitted when enabled and omitted when ADK_TELEMETRY_DISABLED is set. --- tests/integration/telemetry/__init__.py | 13 ++ .../telemetry/test_telemetry_disable.py | 83 +++++++++++++ .../telemetry/test_telemetry_disable_agent.py | 117 ++++++++++++++++++ .../test_telemetry_disable_google_llm.py | 114 +++++++++++++++++ 4 files changed, 327 insertions(+) create mode 100644 tests/integration/telemetry/__init__.py create mode 100644 tests/integration/telemetry/test_telemetry_disable.py create mode 100644 tests/unittests/telemetry/test_telemetry_disable_agent.py create mode 100644 tests/unittests/telemetry/test_telemetry_disable_google_llm.py diff --git a/tests/integration/telemetry/__init__.py b/tests/integration/telemetry/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/integration/telemetry/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/telemetry/test_telemetry_disable.py b/tests/integration/telemetry/test_telemetry_disable.py new file mode 100644 index 0000000000..a5d70d1af1 --- /dev/null +++ b/tests/integration/telemetry/test_telemetry_disable.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) +import pytest + +from google.adk.agents.llm_agent import Agent +from google.adk.telemetry import tracing +from google.adk.utils.context_utils import Aclosing +from google.genai.types import Part + +from tests.unittests.testing_utils import MockModel +from tests.unittests.testing_utils import TestInMemoryRunner + + +@pytest.fixture +def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter: + tracer_provider = TracerProvider() + exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = tracer_provider.get_tracer(__name__) + + monkeypatch.setattr( + tracing.tracer, + "start_as_current_span", + real_tracer.start_as_current_span, + ) + return exporter + + +@pytest.mark.asyncio +async def test_telemetry_enabled_records_spans(monkeypatch, span_exporter): + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert spans + + +@pytest.mark.asyncio +async def test_env_var_disables_telemetry(monkeypatch, span_exporter): + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "true") + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans diff --git a/tests/unittests/telemetry/test_telemetry_disable_agent.py b/tests/unittests/telemetry/test_telemetry_disable_agent.py new file mode 100644 index 0000000000..4a91778511 --- /dev/null +++ b/tests/unittests/telemetry/test_telemetry_disable_agent.py @@ -0,0 +1,117 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest import mock + +from google.adk.agents.llm_agent import Agent +from google.adk.telemetry import tracing +from google.adk.utils.context_utils import Aclosing +from google.genai.types import Part + + +from ..testing_utils import MockModel, TestInMemoryRunner + + +@pytest.mark.asyncio +async def test_disable_telemetry_prevents_span_creation(monkeypatch): + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=True, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count == 0 + + +@pytest.mark.asyncio +async def test_enabled_telemetry_causes_span_creation(monkeypatch): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "env_var,env_value", + [ + ("OTEL_SDK_DISABLED", "true"), + ("OTEL_SDK_DISABLED", "1"), + ("ADK_TELEMETRY_DISABLED", "true"), + ("ADK_TELEMETRY_DISABLED", "1"), + ], +) +async def test_env_flag_disables_telemetry(monkeypatch, env_var, env_value): + monkeypatch.setenv(env_var, env_value) + monkeypatch.delenv( + "ADK_TELEMETRY_DISABLED" if env_var == "OTEL_SDK_DISABLED" else "OTEL_SDK_DISABLED", + raising=False, + ) + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count == 0 diff --git a/tests/unittests/telemetry/test_telemetry_disable_google_llm.py b/tests/unittests/telemetry/test_telemetry_disable_google_llm.py new file mode 100644 index 0000000000..df4cb16752 --- /dev/null +++ b/tests/unittests/telemetry/test_telemetry_disable_google_llm.py @@ -0,0 +1,114 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock +import pytest + +from google.adk.models.google_llm import Gemini +from google.adk.models import llm_response as llm_response_mod +from google.adk.models import gemini_context_cache_manager as cache_mod + + +@pytest.mark.asyncio +async def test_disable_google_llm_telemetry(monkeypatch, context_manager_with_span): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + start_span = mock.Mock(return_value=context_manager_with_span) + monkeypatch.setattr( + "google.adk.telemetry.tracing.tracer.start_as_current_span", + start_span, + ) + + gemini = Gemini(disable_telemetry=True) + + # Avoid real Client construction + fake_client = mock.MagicMock() + fake_client.vertexai = False + fake_client.aio.models.generate_content = mock.AsyncMock( + return_value=mock.MagicMock() + ) + gemini.__dict__["api_client"] = fake_client + + # Prevent cache validation code running (the bit that touches expire_time) + monkeypatch.setattr( + cache_mod.GeminiContextCacheManager, + "handle_context_caching", + mock.AsyncMock(return_value=None), + ) + + req = mock.MagicMock() + req.cache_config = object() # force the cache path + req.model = "gemini-2.5-flash" + req.contents = [] + req.config = mock.MagicMock() + req.config.tools = None + req.config.system_instruction = "" + req.config.model_dump = mock.Mock(return_value={}) + req.config.http_options = None + + monkeypatch.setattr( + llm_response_mod.LlmResponse, "create", mock.Mock(return_value=mock.MagicMock()) + ) + + async for _ in gemini.generate_content_async(req, stream=False): + break + + assert start_span.call_count == 0 + + +@pytest.mark.asyncio +async def test_enable_google_llm_telemetry(monkeypatch, context_manager_with_span): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + start_span = mock.Mock(return_value=context_manager_with_span) + monkeypatch.setattr( + "google.adk.telemetry.tracing.tracer.start_as_current_span", + start_span, + ) + + gemini = Gemini(disable_telemetry=False) + + # Avoid real Client construction + fake_client = mock.MagicMock() + fake_client.vertexai = False + fake_client.aio.models.generate_content = mock.AsyncMock( + return_value=mock.MagicMock() + ) + gemini.__dict__["api_client"] = fake_client + + # Prevent cache validation code running (the bit that touches expire_time) + monkeypatch.setattr( + cache_mod.GeminiContextCacheManager, + "handle_context_caching", + mock.AsyncMock(return_value=None), + ) + + req = mock.MagicMock() + req.cache_config = object() # force the cache path + req.model = "gemini-2.5-flash" + req.contents = [] + req.config = mock.MagicMock() + req.config.tools = None + req.config.system_instruction = "" + req.config.model_dump = mock.Mock(return_value={}) + req.config.http_options = None + + monkeypatch.setattr( + llm_response_mod.LlmResponse, "create", mock.Mock(return_value=mock.MagicMock()) + ) + + async for _ in gemini.generate_content_async(req, stream=False): + break + + assert start_span.call_count > 0 From 541be5f1b97cf30f5232ee30a9f233b44f050879 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sat, 27 Dec 2025 23:30:54 +0000 Subject: [PATCH 03/24] test(telemetry): cover disable paths Add integration coverage for OTEL_SDK_DISABLED, ADK_TELEMETRY_DISABLED, and agent.disable_telemetry. --- .../telemetry/test_telemetry_disable.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/integration/telemetry/test_telemetry_disable.py b/tests/integration/telemetry/test_telemetry_disable.py index a5d70d1af1..5f953362d0 100644 --- a/tests/integration/telemetry/test_telemetry_disable.py +++ b/tests/integration/telemetry/test_telemetry_disable.py @@ -64,7 +64,7 @@ async def test_telemetry_enabled_records_spans(monkeypatch, span_exporter): @pytest.mark.asyncio -async def test_env_var_disables_telemetry(monkeypatch, span_exporter): +async def test_adk_telemetry_disabled_env_var_disables(monkeypatch, span_exporter): monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "true") monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) @@ -81,3 +81,43 @@ async def test_env_var_disables_telemetry(monkeypatch, span_exporter): spans = span_exporter.get_finished_spans() assert not spans + + +@pytest.mark.asyncio +async def test_otel_sdk_env_var_disables_telemetry(monkeypatch, span_exporter): + monkeypatch.setenv("OTEL_SDK_DISABLED", "true") + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans + + +@pytest.mark.asyncio +async def test_agent_flag_disables_telemetry(monkeypatch, span_exporter): + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=True, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans From 696588023b9eee0c23706f23ea5347a900464eed Mon Sep 17 00:00:00 2001 From: mportdata Date: Sat, 27 Dec 2025 23:30:54 +0000 Subject: [PATCH 04/24] fix(telemetry): restore cache helper Add back _create_gemini_cache wrapper for compatibility after telemetry refactor. Apply autoformat across touched files. --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - src/google/adk/agents/base_agent.py | 1136 +++---- src/google/adk/agents/llm_agent.py | 1204 +++---- .../adk/flows/llm_flows/base_llm_flow.py | 1908 ++++++----- src/google/adk/flows/llm_flows/functions.py | 1328 ++++---- .../models/gemini_context_cache_manager.py | 875 ++--- src/google/adk/models/google_llm.py | 867 ++--- src/google/adk/runners.py | 2813 +++++++++-------- src/google/adk/telemetry/tracing.py | 508 +-- src/google/adk/utils/telemetry_utils.py | 75 +- .../telemetry/test_telemetry_disable.py | 145 +- .../test_gemini_context_cache_manager.py | 4 +- tests/unittests/conftest.py | 33 +- .../telemetry/test_telemetry_disable_agent.py | 132 +- .../test_telemetry_disable_google_llm.py | 190 +- 16 files changed, 5651 insertions(+), 5569 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 11cab7ea04..5207897c54 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -49,7 +49,7 @@ from .callback_context import CallbackContext if TYPE_CHECKING: - from .invocation_context import InvocationContext + from .invocation_context import InvocationContext logger = logging.getLogger("google_adk." + __name__) @@ -73,27 +73,27 @@ @experimental class BaseAgentState(BaseModel): - """Base class for all agent states.""" + """Base class for all agent states.""" - model_config = ConfigDict( - extra="forbid", - ) + model_config = ConfigDict( + extra="forbid", + ) AgentState = TypeVar("AgentState", bound=BaseAgentState) class BaseAgent(BaseModel): - """Base class for all agents in Agent Development Kit.""" + """Base class for all agents in Agent Development Kit.""" - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - """The pydantic model config.""" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + """The pydantic model config.""" - config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig - """The config type for this agent. + config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig + """The config type for this agent. Sub-classes should override this to specify their own config type. @@ -108,22 +108,22 @@ class MyAgent(BaseAgent): ``` """ - name: str - """The agent's name. + name: str + """The agent's name. Agent name must be a Python identifier and unique within the agent tree. Agent name cannot be "user", since it's reserved for end-user's input. """ - description: str = "" - """Description about the agent's capability. + description: str = "" + """Description about the agent's capability. The model uses this to determine whether to delegate control to the agent. One-line description is enough and preferred. """ - parent_agent: Optional[BaseAgent] = Field(default=None, init=False) - """The parent agent of this agent. + parent_agent: Optional[BaseAgent] = Field(default=None, init=False) + """The parent agent of this agent. Note that an agent can ONLY be added as sub-agent once. @@ -131,13 +131,13 @@ class MyAgent(BaseAgent): instances with identical config, but with different name and add them to the agent tree. """ - sub_agents: list[BaseAgent] = Field(default_factory=list) - """The sub-agents of this agent.""" - disable_telemetry: bool = False - """Whether to disable telemetry for this agent. + sub_agents: list[BaseAgent] = Field(default_factory=list) + """The sub-agents of this agent.""" + disable_telemetry: bool = False + """Whether to disable telemetry for this agent. """ - before_agent_callback: Optional[BeforeAgentCallback] = None - """Callback or list of callbacks to be invoked before the agent run. + before_agent_callback: Optional[BeforeAgentCallback] = None + """Callback or list of callbacks to be invoked before the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -150,8 +150,8 @@ class MyAgent(BaseAgent): When the content is present, the agent run will be skipped and the provided content will be returned to user. """ - after_agent_callback: Optional[AfterAgentCallback] = None - """Callback or list of callbacks to be invoked after the agent run. + after_agent_callback: Optional[AfterAgentCallback] = None + """Callback or list of callbacks to be invoked after the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -165,550 +165,562 @@ class MyAgent(BaseAgent): will be appended to event history as an additional agent response. """ - def _load_agent_state( - self, - ctx: InvocationContext, - state_type: Type[AgentState], - ) -> Optional[AgentState]: - """Loads the agent state from the invocation context. - - Args: - ctx: The invocation context. - state_type: The type of the agent state. - - Returns: - The current state if exists; otherwise, None. - """ - if ctx.agent_states is None or self.name not in ctx.agent_states: - return None - else: - return state_type.model_validate(ctx.agent_states.get(self.name)) - - def _create_agent_state_event( - self, - ctx: InvocationContext, - ) -> Event: - """Returns an event with current agent state set in the invocation context. - - Args: - ctx: The invocation context. - - Returns: - An event with the current agent state set in the invocation context. - """ - event_actions = EventActions() - if (agent_state := ctx.agent_states.get(self.name)) is not None: - event_actions.agent_state = agent_state - if ctx.end_of_agents.get(self.name): - event_actions.end_of_agent = True - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=event_actions, - ) + def _load_agent_state( + self, + ctx: InvocationContext, + state_type: Type[AgentState], + ) -> Optional[AgentState]: + """Loads the agent state from the invocation context. + + Args: + ctx: The invocation context. + state_type: The type of the agent state. + + Returns: + The current state if exists; otherwise, None. + """ + if ctx.agent_states is None or self.name not in ctx.agent_states: + return None + else: + return state_type.model_validate(ctx.agent_states.get(self.name)) + + def _create_agent_state_event( + self, + ctx: InvocationContext, + ) -> Event: + """Returns an event with current agent state set in the invocation context. + + Args: + ctx: The invocation context. + + Returns: + An event with the current agent state set in the invocation context. + """ + event_actions = EventActions() + if (agent_state := ctx.agent_states.get(self.name)) is not None: + event_actions.agent_state = agent_state + if ctx.end_of_agents.get(self.name): + event_actions.end_of_agent = True + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=event_actions, + ) - def clone(self: SelfAgent, update: Mapping[str, Any] | None = None) -> SelfAgent: - """Creates a copy of this agent instance. - - Args: - update: Optional mapping of new values for the fields of the cloned agent. - The keys of the mapping are the names of the fields to be updated, and - the values are the new values for those fields. - For example: {"name": "cloned_agent"} - - Returns: - A new agent instance with identical configuration as the original - agent except for the fields specified in the update. - """ - if update is not None and "parent_agent" in update: - raise ValueError( - "Cannot update `parent_agent` field in clone. Parent agent is set" - " only when the parent agent is instantiated with the sub-agents." - ) - - # Only allow updating fields that are defined in the agent class. - allowed_fields = set(self.__class__.model_fields) - if update is not None: - invalid_fields = set(update) - allowed_fields - if invalid_fields: - raise ValueError( - f"Cannot update nonexistent fields in {self.__class__.__name__}:" - f" {invalid_fields}" - ) - - cloned_agent = self.model_copy(update=update) - - # If any field is stored as list and not provided in the update, need to - # shallow copy it for the cloned agent to avoid sharing the same list object - # with the original agent. - for field_name in cloned_agent.__class__.model_fields: - if field_name == "sub_agents": - continue - if update is not None and field_name in update: - continue - field = getattr(cloned_agent, field_name) - if isinstance(field, list): - setattr(cloned_agent, field_name, field.copy()) - - if update is None or "sub_agents" not in update: - # If `sub_agents` is not provided in the update, need to recursively clone - # the sub-agents to avoid sharing the sub-agents with the original agent. - cloned_agent.sub_agents = [] - for sub_agent in self.sub_agents: - cloned_sub_agent = sub_agent.clone() - cloned_sub_agent.parent_agent = cloned_agent - cloned_agent.sub_agents.append(cloned_sub_agent) - else: - for sub_agent in cloned_agent.sub_agents: - sub_agent.parent_agent = cloned_agent - - # Remove the parent agent from the cloned agent to avoid sharing the parent - # agent with the cloned agent. - cloned_agent.parent_agent = None - return cloned_agent - - @final - async def run_async( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via text-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - ctx = self._create_invocation_context(parent_context) - if is_telemetry_enabled(self): - with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: - tracing.trace_agent_invocation(span, self, ctx) - async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") - ) as agen: - async for event in agen: - yield event - else: - async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") - ) as agen: - async for event in agen: - yield event - - @final - async def run_live( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via video/audio-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - ctx = self._create_invocation_context(parent_context) - if is_telemetry_enabled(self): - with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: - tracing.trace_agent_invocation(span, self, ctx) - async for event in self._run_callbacks_and_impl(ctx, mode="live"): - yield event - else: - async for event in self._run_callbacks_and_impl(ctx, mode="live"): - yield event - - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via text-based conversation. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f"_run_async_impl for {type(self)} is not implemented." + def clone( + self: SelfAgent, update: Mapping[str, Any] | None = None + ) -> SelfAgent: + """Creates a copy of this agent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + The keys of the mapping are the names of the fields to be updated, and + the values are the new values for those fields. + For example: {"name": "cloned_agent"} + + Returns: + A new agent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + if update is not None and "parent_agent" in update: + raise ValueError( + "Cannot update `parent_agent` field in clone. Parent agent is set" + " only when the parent agent is instantiated with the sub-agents." + ) + + # Only allow updating fields that are defined in the agent class. + allowed_fields = set(self.__class__.model_fields) + if update is not None: + invalid_fields = set(update) - allowed_fields + if invalid_fields: + raise ValueError( + f"Cannot update nonexistent fields in {self.__class__.__name__}:" + f" {invalid_fields}" ) - yield # AsyncGenerator requires having at least one yield statement - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via video/audio-based conversation. + cloned_agent = self.model_copy(update=update) + + # If any field is stored as list and not provided in the update, need to + # shallow copy it for the cloned agent to avoid sharing the same list object + # with the original agent. + for field_name in cloned_agent.__class__.model_fields: + if field_name == "sub_agents": + continue + if update is not None and field_name in update: + continue + field = getattr(cloned_agent, field_name) + if isinstance(field, list): + setattr(cloned_agent, field_name, field.copy()) + + if update is None or "sub_agents" not in update: + # If `sub_agents` is not provided in the update, need to recursively clone + # the sub-agents to avoid sharing the sub-agents with the original agent. + cloned_agent.sub_agents = [] + for sub_agent in self.sub_agents: + cloned_sub_agent = sub_agent.clone() + cloned_sub_agent.parent_agent = cloned_agent + cloned_agent.sub_agents.append(cloned_sub_agent) + else: + for sub_agent in cloned_agent.sub_agents: + sub_agent.parent_agent = cloned_agent + + # Remove the parent agent from the cloned agent to avoid sharing the parent + # agent with the cloned agent. + cloned_agent.parent_agent = None + return cloned_agent + + @final + async def run_async( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via text-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + if is_telemetry_enabled(self): + with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + tracing.trace_agent_invocation(span, self, ctx) + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + else: + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + + @final + async def run_live( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via video/audio-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + if is_telemetry_enabled(self): + with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + tracing.trace_agent_invocation(span, self, ctx) + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + else: + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via text-based conversation. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_async_impl for {type(self)} is not implemented." + ) + yield # AsyncGenerator requires having at least one yield statement - Args: - ctx: InvocationContext, the invocation context for this agent. + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via video/audio-based conversation. - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f"_run_live_impl for {type(self)} is not implemented." - ) - yield # AsyncGenerator requires having at least one yield statement - - async def _run_callbacks_and_impl( - self, ctx: InvocationContext, mode: str = "async" - ) -> AsyncGenerator[Event, None]: - """Runs the before and after agent callbacks around the core agent logic. - Args: - ctx: InvocationContext, the invocation context for this agent. - mode: str, either 'async' or 'live', indicating which core agent logic to run. - Yields: - Event: the events generated by the agent. - """ - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - if mode.lower() == "async": - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event - elif mode.lower() == "live": - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event - else: - raise ValueError(f"Invalid mode: {mode}. Must be either 'async' or 'live'.") - - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event + Args: + ctx: InvocationContext, the invocation context for this agent. - @property - def root_agent(self) -> BaseAgent: - """Gets the root agent of this agent.""" - root_agent = self - while root_agent.parent_agent is not None: - root_agent = root_agent.parent_agent - return root_agent - - def find_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent and its descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - if self.name == name: - return self - return self.find_sub_agent(name) - - def find_sub_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent's descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - for sub_agent in self.sub_agents: - if result := sub_agent.find_agent(name): - return result - return None - - def _create_invocation_context( - self, parent_context: InvocationContext - ) -> InvocationContext: - """Creates a new invocation context for this agent.""" - invocation_context = parent_context.model_copy(update={"agent": self}) - return invocation_context - - @property - def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_agent_callback: - return [] - if isinstance(self.before_agent_callback, list): - return self.before_agent_callback - return [self.before_agent_callback] - - @property - def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_agent_callback: - return [] - if isinstance(self.after_agent_callback, list): - return self.after_agent_callback - return [self.after_agent_callback] - - async def _handle_before_agent_callback( - self, ctx: InvocationContext - ) -> Optional[Event]: - """Runs the before_agent_callback if it exists. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - callback_context = CallbackContext(ctx) - - # Run callbacks from the plugins. - before_agent_callback_content = ( - await ctx.plugin_manager.run_before_agent_callback( - agent=self, callback_context=callback_context - ) + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_live_impl for {type(self)} is not implemented." + ) + yield # AsyncGenerator requires having at least one yield statement + + async def _run_callbacks_and_impl( + self, ctx: InvocationContext, mode: str = "async" + ) -> AsyncGenerator[Event, None]: + """Runs the before and after agent callbacks around the core agent logic. + Args: + ctx: InvocationContext, the invocation context for this agent. + mode: str, either 'async' or 'live', indicating which core agent logic to run. + Yields: + Event: the events generated by the agent. + """ + if event := await self._handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return + if mode.lower() == "async": + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + elif mode.lower() == "live": + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + else: + raise ValueError( + f"Invalid mode: {mode}. Must be either 'async' or 'live'." + ) + + if ctx.end_invocation: + return + + if event := await self._handle_after_agent_callback(ctx): + yield event + + @property + def root_agent(self) -> BaseAgent: + """Gets the root agent of this agent.""" + root_agent = self + while root_agent.parent_agent is not None: + root_agent = root_agent.parent_agent + return root_agent + + def find_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent and its descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + if self.name == name: + return self + return self.find_sub_agent(name) + + def find_sub_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent's descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + for sub_agent in self.sub_agents: + if result := sub_agent.find_agent(name): + return result + return None + + def _create_invocation_context( + self, parent_context: InvocationContext + ) -> InvocationContext: + """Creates a new invocation context for this agent.""" + invocation_context = parent_context.model_copy(update={"agent": self}) + return invocation_context + + @property + def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_agent_callback: + return [] + if isinstance(self.before_agent_callback, list): + return self.before_agent_callback + return [self.before_agent_callback] + + @property + def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_agent_callback: + return [] + if isinstance(self.after_agent_callback, list): + return self.after_agent_callback + return [self.after_agent_callback] + + async def _handle_before_agent_callback( + self, ctx: InvocationContext + ) -> Optional[Event]: + """Runs the before_agent_callback if it exists. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + callback_context = CallbackContext(ctx) + + # Run callbacks from the plugins. + before_agent_callback_content = ( + await ctx.plugin_manager.run_before_agent_callback( + agent=self, callback_context=callback_context ) + ) - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not before_agent_callback_content and self.canonical_before_agent_callbacks: - for callback in self.canonical_before_agent_callbacks: - before_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(before_agent_callback_content): - before_agent_callback_content = await before_agent_callback_content - if before_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if ( + not before_agent_callback_content + and self.canonical_before_agent_callbacks + ): + for callback in self.canonical_before_agent_callbacks: + before_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(before_agent_callback_content): + before_agent_callback_content = await before_agent_callback_content if before_agent_callback_content: - ret_event = Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=before_agent_callback_content, - actions=callback_context._event_actions, - ) - ctx.end_invocation = True - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=callback_context._event_actions, - ) - - return None - - async def _handle_after_agent_callback( - self, invocation_context: InvocationContext - ) -> Optional[Event]: - """Runs the after_agent_callback if it exists. - - Args: - invocation_context: InvocationContext, the invocation context for this - agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - - callback_context = CallbackContext(invocation_context) - - # Run callbacks from the plugins. - after_agent_callback_content = ( - await invocation_context.plugin_manager.run_after_agent_callback( - agent=self, callback_context=callback_context - ) + break + + # Process the override content if exists, and further process the state + # change if exists. + if before_agent_callback_content: + ret_event = Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=before_agent_callback_content, + actions=callback_context._event_actions, + ) + ctx.end_invocation = True + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=callback_context._event_actions, + ) + + return None + + async def _handle_after_agent_callback( + self, invocation_context: InvocationContext + ) -> Optional[Event]: + """Runs the after_agent_callback if it exists. + + Args: + invocation_context: InvocationContext, the invocation context for this + agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + + callback_context = CallbackContext(invocation_context) + + # Run callbacks from the plugins. + after_agent_callback_content = ( + await invocation_context.plugin_manager.run_after_agent_callback( + agent=self, callback_context=callback_context ) + ) - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not after_agent_callback_content and self.canonical_after_agent_callbacks: - for callback in self.canonical_after_agent_callbacks: - after_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(after_agent_callback_content): - after_agent_callback_content = await after_agent_callback_content - if after_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if ( + not after_agent_callback_content + and self.canonical_after_agent_callbacks + ): + for callback in self.canonical_after_agent_callbacks: + after_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(after_agent_callback_content): + after_agent_callback_content = await after_agent_callback_content if after_agent_callback_content: - ret_event = Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return None - - @override - def model_post_init(self, __context: Any) -> None: - self.__set_parent_agent_for_sub_agents() - - @field_validator("name", mode="after") - @classmethod - def validate_name(cls, value: str): - if not value.isidentifier(): - raise ValueError( - f"Found invalid agent name: `{value}`." - " Agent name must be a valid identifier. It should start with a" - " letter (a-z, A-Z) or an underscore (_), and can only contain" - " letters, digits (0-9), and underscores." - ) - if value == "user": - raise ValueError( - "Agent name cannot be `user`. `user` is reserved for end-user's" - " input." - ) - return value - - @field_validator("sub_agents", mode="after") - @classmethod - def validate_sub_agents_unique_names( - cls, value: list[BaseAgent] - ) -> list[BaseAgent]: - """Validates that all sub-agents have unique names. - - Args: - value: The list of sub-agents to validate. - - Returns: - The validated list of sub-agents. - - """ - if not value: - return value - - seen_names: set[str] = set() - duplicates: set[str] = set() - - for sub_agent in value: - name = sub_agent.name - if name in seen_names: - duplicates.add(name) - else: - seen_names.add(name) - - if duplicates: - duplicate_names_str = ", ".join(f"`{name}`" for name in sorted(duplicates)) - logger.warning( - "Found duplicate sub-agent names: %s. " - "All sub-agents must have unique names.", - duplicate_names_str, - ) - - return value - - def __set_parent_agent_for_sub_agents(self) -> BaseAgent: - for sub_agent in self.sub_agents: - if sub_agent.parent_agent is not None: - raise ValueError( - f"Agent `{sub_agent.name}` already has a parent agent, current" - f" parent: `{sub_agent.parent_agent.name}`, trying to add:" - f" `{self.name}`" - ) - sub_agent.parent_agent = self - return self - - @final - @classmethod - @experimental - def from_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - ) -> SelfAgent: - """Creates an agent from a config. - - If sub-classes uses a custom agent config, override `_from_config_kwargs` - method to return an updated kwargs for agent constructor. - - Args: - config: The config to create the agent from. - config_abs_path: The absolute path to the config file that contains the - agent config. - - Returns: - The created agent. - """ - kwargs = cls.__create_kwargs(config, config_abs_path) - kwargs = cls._parse_config(config, config_abs_path, kwargs) - return cls(**kwargs) - - @classmethod - @experimental - def _parse_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - kwargs: Dict[str, Any], - ) -> Dict[str, Any]: - """Parses the config and returns updated kwargs to construct the agent. - - Sub-classes should override this method to use a custom agent config class. - - Args: - config: The config to parse. - config_abs_path: The absolute path to the config file that contains the - agent config. - kwargs: The keyword arguments used for agent constructor. - - Returns: - The updated keyword arguments used for agent constructor. - """ - return kwargs - - @classmethod - def __create_kwargs( - cls, - config: BaseAgentConfig, - config_abs_path: str, - ) -> Dict[str, Any]: - """Creates kwargs for the fields of BaseAgent.""" - - from .config_agent_utils import resolve_agent_reference - from .config_agent_utils import resolve_callbacks - - kwargs: Dict[str, Any] = { - "name": config.name, - "description": config.description, - } - if config.sub_agents: - sub_agents = [] - for sub_agent_config in config.sub_agents: - sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) - sub_agents.append(sub_agent) - kwargs["sub_agents"] = sub_agents - - if config.before_agent_callbacks: - kwargs["before_agent_callback"] = resolve_callbacks( - config.before_agent_callbacks - ) - if config.after_agent_callbacks: - kwargs["after_agent_callback"] = resolve_callbacks( - config.after_agent_callbacks - ) - return kwargs + break + + # Process the override content if exists, and further process the state + # change if exists. + if after_agent_callback_content: + ret_event = Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return None + + @override + def model_post_init(self, __context: Any) -> None: + self.__set_parent_agent_for_sub_agents() + + @field_validator("name", mode="after") + @classmethod + def validate_name(cls, value: str): + if not value.isidentifier(): + raise ValueError( + f"Found invalid agent name: `{value}`." + " Agent name must be a valid identifier. It should start with a" + " letter (a-z, A-Z) or an underscore (_), and can only contain" + " letters, digits (0-9), and underscores." + ) + if value == "user": + raise ValueError( + "Agent name cannot be `user`. `user` is reserved for end-user's" + " input." + ) + return value + + @field_validator("sub_agents", mode="after") + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ", ".join( + f"`{name}`" for name in sorted(duplicates) + ) + logger.warning( + "Found duplicate sub-agent names: %s. " + "All sub-agents must have unique names.", + duplicate_names_str, + ) + + return value + + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: + for sub_agent in self.sub_agents: + if sub_agent.parent_agent is not None: + raise ValueError( + f"Agent `{sub_agent.name}` already has a parent agent, current" + f" parent: `{sub_agent.parent_agent.name}`, trying to add:" + f" `{self.name}`" + ) + sub_agent.parent_agent = self + return self + + @final + @classmethod + @experimental + def from_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + ) -> SelfAgent: + """Creates an agent from a config. + + If sub-classes uses a custom agent config, override `_from_config_kwargs` + method to return an updated kwargs for agent constructor. + + Args: + config: The config to create the agent from. + config_abs_path: The absolute path to the config file that contains the + agent config. + + Returns: + The created agent. + """ + kwargs = cls.__create_kwargs(config, config_abs_path) + kwargs = cls._parse_config(config, config_abs_path, kwargs) + return cls(**kwargs) + + @classmethod + @experimental + def _parse_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Parses the config and returns updated kwargs to construct the agent. + + Sub-classes should override this method to use a custom agent config class. + + Args: + config: The config to parse. + config_abs_path: The absolute path to the config file that contains the + agent config. + kwargs: The keyword arguments used for agent constructor. + + Returns: + The updated keyword arguments used for agent constructor. + """ + return kwargs + + @classmethod + def __create_kwargs( + cls, + config: BaseAgentConfig, + config_abs_path: str, + ) -> Dict[str, Any]: + """Creates kwargs for the fields of BaseAgent.""" + + from .config_agent_utils import resolve_agent_reference + from .config_agent_utils import resolve_callbacks + + kwargs: Dict[str, Any] = { + "name": config.name, + "description": config.description, + } + if config.sub_agents: + sub_agents = [] + for sub_agent_config in config.sub_agents: + sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) + sub_agents.append(sub_agent) + kwargs["sub_agents"] = sub_agents + + if config.before_agent_callbacks: + kwargs["before_agent_callback"] = resolve_callbacks( + config.before_agent_callbacks + ) + if config.after_agent_callbacks: + kwargs["after_agent_callback"] = resolve_callbacks( + config.after_agent_callbacks + ) + return kwargs diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 82542b0a9d..5862397f58 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -126,7 +126,9 @@ list[_SingleOnToolErrorCallback], ] -InstructionProvider: TypeAlias = Callable[[ReadonlyContext], Union[str, Awaitable[str]]] +InstructionProvider: TypeAlias = Callable[ + [ReadonlyContext], Union[str, Awaitable[str]] +] ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset] @@ -137,62 +139,62 @@ async def _convert_tool_union_to_tools( model: Union[str, BaseLlm], multiple_tools: bool = False, ) -> list[BaseTool]: - from ..tools.google_search_tool import GoogleSearchTool - from ..tools.vertex_ai_search_tool import VertexAiSearchTool + from ..tools.google_search_tool import GoogleSearchTool + from ..tools.vertex_ai_search_tool import VertexAiSearchTool + + # Wrap google_search tool with AgentTool if there are multiple tools because + # the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + if multiple_tools and isinstance(tool_union, GoogleSearchTool): + from ..tools.google_search_agent_tool import create_google_search_agent + from ..tools.google_search_agent_tool import GoogleSearchAgentTool + + search_tool = cast(GoogleSearchTool, tool_union) + if search_tool.bypass_multi_tools_limit: + return [GoogleSearchAgentTool(create_google_search_agent(model))] + + # Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are + # multiple tools because the built-in tools cannot be used together with + # other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + if multiple_tools and isinstance(tool_union, VertexAiSearchTool): + from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool + + vais_tool = cast(VertexAiSearchTool, tool_union) + if vais_tool.bypass_multi_tools_limit: + return [ + DiscoveryEngineSearchTool( + data_store_id=vais_tool.data_store_id, + data_store_specs=vais_tool.data_store_specs, + search_engine_id=vais_tool.search_engine_id, + filter=vais_tool.filter, + max_results=vais_tool.max_results, + ) + ] - # Wrap google_search tool with AgentTool if there are multiple tools because - # the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, GoogleSearchTool): - from ..tools.google_search_agent_tool import create_google_search_agent - from ..tools.google_search_agent_tool import GoogleSearchAgentTool + if isinstance(tool_union, BaseTool): + return [tool_union] + if callable(tool_union): + return [FunctionTool(func=tool_union)] - search_tool = cast(GoogleSearchTool, tool_union) - if search_tool.bypass_multi_tools_limit: - return [GoogleSearchAgentTool(create_google_search_agent(model))] - - # Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are - # multiple tools because the built-in tools cannot be used together with - # other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, VertexAiSearchTool): - from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool - - vais_tool = cast(VertexAiSearchTool, tool_union) - if vais_tool.bypass_multi_tools_limit: - return [ - DiscoveryEngineSearchTool( - data_store_id=vais_tool.data_store_id, - data_store_specs=vais_tool.data_store_specs, - search_engine_id=vais_tool.search_engine_id, - filter=vais_tool.filter, - max_results=vais_tool.max_results, - ) - ] - - if isinstance(tool_union, BaseTool): - return [tool_union] - if callable(tool_union): - return [FunctionTool(func=tool_union)] - - # At this point, tool_union must be a BaseToolset - return await tool_union.get_tools_with_prefix(ctx) + # At this point, tool_union must be a BaseToolset + return await tool_union.get_tools_with_prefix(ctx) class LlmAgent(BaseAgent): - """LLM-based Agent.""" + """LLM-based Agent.""" - model: Union[str, BaseLlm] = "" - """The model to use for the agent. + model: Union[str, BaseLlm] = "" + """The model to use for the agent. When not set, the agent will inherit the model from its ancestor. """ - config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig - """The config type for this agent.""" + config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig + """The config type for this agent.""" - instruction: Union[str, InstructionProvider] = "" - """Dynamic instructions for the LLM model, guiding the agent's behavior. + instruction: Union[str, InstructionProvider] = "" + """Dynamic instructions for the LLM model, guiding the agent's behavior. These instructions can contain placeholders like {variable_name} that will be resolved at runtime using session state and context. @@ -205,8 +207,8 @@ class LlmAgent(BaseAgent): comes first in the prompt, followed by dynamic content (instruction). """ - global_instruction: Union[str, InstructionProvider] = "" - """Instructions for all the agents in the entire agent tree. + global_instruction: Union[str, InstructionProvider] = "" + """Instructions for all the agents in the entire agent tree. DEPRECATED: This field is deprecated and will be removed in a future version. Use GlobalInstructionPlugin instead, which provides the same functionality @@ -218,8 +220,8 @@ class LlmAgent(BaseAgent): or personality. """ - static_instruction: Optional[types.ContentUnion] = None - """Static instruction content sent literally as system instruction at the beginning. + static_instruction: Optional[types.ContentUnion] = None + """Static instruction content sent literally as system instruction at the beginning. This field is for content that never changes and doesn't contain placeholders. It's sent directly to the model without any processing or variable substitution. @@ -269,11 +271,11 @@ class LlmAgent(BaseAgent): ``` """ - tools: list[ToolUnion] = Field(default_factory=list) - """Tools available to this agent.""" + tools: list[ToolUnion] = Field(default_factory=list) + """Tools available to this agent.""" - generate_content_config: Optional[types.GenerateContentConfig] = None - """The additional content generation configurations. + generate_content_config: Optional[types.GenerateContentConfig] = None + """The additional content generation configurations. NOTE: not all fields are usable, e.g. tools must be configured via `tools`, thinking_config must be configured via `planner` in LlmAgent. @@ -282,21 +284,21 @@ class LlmAgent(BaseAgent): settings, etc. """ - # LLM-based agent transfer configs - Start - disallow_transfer_to_parent: bool = False - """Disallows LLM-controlled transferring to the parent agent. + # LLM-based agent transfer configs - Start + disallow_transfer_to_parent: bool = False + """Disallows LLM-controlled transferring to the parent agent. NOTE: Setting this as True also prevents this agent from continuing to reply to the end-user, and will transfer control back to the parent agent in the next turn. This behavior prevents one-way transfer, in which end-user may be stuck with one agent that cannot transfer to other agents in the agent tree. """ - disallow_transfer_to_peers: bool = False - """Disallows LLM-controlled transferring to the peer agents.""" - # LLM-based agent transfer configs - End + disallow_transfer_to_peers: bool = False + """Disallows LLM-controlled transferring to the peer agents.""" + # LLM-based agent transfer configs - End - include_contents: Literal["default", "none"] = "default" - """Controls content inclusion in model requests. + include_contents: Literal["default", "none"] = "default" + """Controls content inclusion in model requests. Options: default: Model receives relevant conversation history @@ -304,36 +306,36 @@ class LlmAgent(BaseAgent): instruction and input """ - # Controlled input/output configurations - Start - input_schema: Optional[type[BaseModel]] = None - """The input schema when agent is used as a tool.""" - output_schema: Optional[type[BaseModel]] = None - """The output schema when agent replies. + # Controlled input/output configurations - Start + input_schema: Optional[type[BaseModel]] = None + """The input schema when agent is used as a tool.""" + output_schema: Optional[type[BaseModel]] = None + """The output schema when agent replies. NOTE: When this is set, agent can ONLY reply and CANNOT use any tools, such as function tools, RAGs, agent transfer, etc. """ - output_key: Optional[str] = None - """The key in session state to store the output of the agent. + output_key: Optional[str] = None + """The key in session state to store the output of the agent. Typically use cases: - Extracts agent reply for later use, such as in tools, callbacks, etc. - Connects agents to coordinate with each other. """ - # Controlled input/output configurations - End + # Controlled input/output configurations - End - # Advance features - Start - planner: Optional[BasePlanner] = None - """Instructs the agent to make a plan and execute it step by step. + # Advance features - Start + planner: Optional[BasePlanner] = None + """Instructs the agent to make a plan and execute it step by step. NOTE: To use model's built-in thinking features, set the `thinking_config` field in `google.adk.planners.built_in_planner`. """ - code_executor: Optional[BaseCodeExecutor] = None - """Allow agent to execute code blocks from model responses using the provided + code_executor: Optional[BaseCodeExecutor] = None + """Allow agent to execute code blocks from model responses using the provided CodeExecutor. Check out available code executions in `google.adk.code_executor` package. @@ -341,11 +343,11 @@ class LlmAgent(BaseAgent): NOTE: To use model's built-in code executor, use the `BuiltInCodeExecutor`. """ - # Advance features - End + # Advance features - End - # Callbacks - Start - before_model_callback: Optional[BeforeModelCallback] = None - """Callback or list of callbacks to be called before calling the LLM. + # Callbacks - Start + before_model_callback: Optional[BeforeModelCallback] = None + """Callback or list of callbacks to be called before calling the LLM. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -359,8 +361,8 @@ class LlmAgent(BaseAgent): The content to return to the user. When present, the model call will be skipped and the provided content will be returned to user. """ - after_model_callback: Optional[AfterModelCallback] = None - """Callback or list of callbacks to be called after calling the LLM. + after_model_callback: Optional[AfterModelCallback] = None + """Callback or list of callbacks to be called after calling the LLM. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -373,8 +375,8 @@ class LlmAgent(BaseAgent): The content to return to the user. When present, the actual model response will be ignored and the provided content will be returned to user. """ - on_model_error_callback: Optional[OnModelErrorCallback] = None - """Callback or list of callbacks to be called when a model call encounters an error. + on_model_error_callback: Optional[OnModelErrorCallback] = None + """Callback or list of callbacks to be called when a model call encounters an error. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -388,8 +390,8 @@ class LlmAgent(BaseAgent): The content to return to the user. When present, the error will be ignored and the provided content will be returned to user. """ - before_tool_callback: Optional[BeforeToolCallback] = None - """Callback or list of callbacks to be called before calling the tool. + before_tool_callback: Optional[BeforeToolCallback] = None + """Callback or list of callbacks to be called before calling the tool. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -403,8 +405,8 @@ class LlmAgent(BaseAgent): The tool response. When present, the returned tool response will be used and the framework will skip calling the actual tool. """ - after_tool_callback: Optional[AfterToolCallback] = None - """Callback or list of callbacks to be called after calling the tool. + after_tool_callback: Optional[AfterToolCallback] = None + """Callback or list of callbacks to be called after calling the tool. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -418,8 +420,8 @@ class LlmAgent(BaseAgent): Returns: When present, the returned dict will be used as tool result. """ - on_tool_error_callback: Optional[OnToolErrorCallback] = None - """Callback or list of callbacks to be called when a tool call encounters an error. + on_tool_error_callback: Optional[OnToolErrorCallback] = None + """Callback or list of callbacks to be called when a tool call encounters an error. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -433,515 +435,525 @@ class LlmAgent(BaseAgent): Returns: When present, the returned dict will be used as tool result. """ - # Callbacks - End - - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - agent_state = self._load_agent_state(ctx, BaseAgentState) - - # If there is a sub-agent to resume, run it and then end the current - # agent. - if agent_state is not None and ( - agent_to_transfer := self._get_subagent_to_resume(ctx) - ): - async with Aclosing(agent_to_transfer.run_async(ctx)) as agen: - async for event in agen: - yield event - - ctx.set_agent_state(self.name, end_of_agent=True) - yield self._create_agent_state_event(ctx) - return - - should_pause = False - async with Aclosing(self._llm_flow.run_async(ctx)) as agen: - async for event in agen: - self.__maybe_save_output_to_state(event) - yield event - if ctx.should_pause_invocation(event): - # Do not pause immediately, wait until the long running tool call is - # executed. - should_pause = True - if should_pause: - return - - if ctx.is_resumable: - events = ctx._get_events(current_invocation=True, current_branch=True) - if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): - return - # Only yield an end state if the last event is no longer a long running - # tool call. - ctx.set_agent_state(self.name, end_of_agent=True) - yield self._create_agent_state_event(ctx) - - @override - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - async with Aclosing(self._llm_flow.run_live(ctx)) as agen: - async for event in agen: - self.__maybe_save_output_to_state(event) - yield event - if ctx.end_invocation: - return - - @property - def canonical_model(self) -> BaseLlm: - """The resolved self.model field as BaseLlm. - - This method is only for use by Agent Development Kit. - """ - if isinstance(self.model, BaseLlm): - return self.model - elif self.model: # model is non-empty str - return LLMRegistry.new_llm(self.model) - else: # find model from ancestors. - ancestor_agent = self.parent_agent - while ancestor_agent is not None: - if isinstance(ancestor_agent, LlmAgent): - return ancestor_agent.canonical_model - ancestor_agent = ancestor_agent.parent_agent - raise ValueError(f"No model found for {self.name}.") - - async def canonical_instruction(self, ctx: ReadonlyContext) -> tuple[str, bool]: - """The resolved self.instruction field to construct instruction for this agent. - - This method is only for use by Agent Development Kit. - - Args: - ctx: The context to retrieve the session state. - - Returns: - A tuple of (instruction, bypass_state_injection). - instruction: The resolved self.instruction field. - bypass_state_injection: Whether the instruction is based on - InstructionProvider. - """ - if isinstance(self.instruction, str): - return self.instruction, False - else: - instruction = self.instruction(ctx) - if inspect.isawaitable(instruction): - instruction = await instruction - return instruction, True - - async def canonical_global_instruction( - self, ctx: ReadonlyContext - ) -> tuple[str, bool]: - """The resolved self.instruction field to construct global instruction. - - This method is only for use by Agent Development Kit. - - Args: - ctx: The context to retrieve the session state. - - Returns: - A tuple of (instruction, bypass_state_injection). - instruction: The resolved self.global_instruction field. - bypass_state_injection: Whether the instruction is based on - InstructionProvider. - """ - # Issue deprecation warning if global_instruction is being used - if self.global_instruction: - warnings.warn( - "global_instruction field is deprecated and will be removed in a" - " future version. Use GlobalInstructionPlugin instead for the same" - " functionality at the App level. See migration guide for details.", - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(self.global_instruction, str): - return self.global_instruction, False - else: - global_instruction = self.global_instruction(ctx) - if inspect.isawaitable(global_instruction): - global_instruction = await global_instruction - return global_instruction, True - - async def canonical_tools(self, ctx: ReadonlyContext = None) -> list[BaseTool]: - """The resolved self.tools field as a list of BaseTool based on the context. - - This method is only for use by Agent Development Kit. - """ - resolved_tools = [] - # We may need to wrap some built-in tools if there are other tools - # because the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - multiple_tools = len(self.tools) > 1 - for tool_union in self.tools: - resolved_tools.extend( - await _convert_tool_union_to_tools( - tool_union, ctx, self.model, multiple_tools - ) - ) - return resolved_tools - - @property - def canonical_before_model_callbacks( - self, - ) -> list[_SingleBeforeModelCallback]: - """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_model_callback: - return [] - if isinstance(self.before_model_callback, list): - return self.before_model_callback - return [self.before_model_callback] - - @property - def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]: - """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_model_callback: - return [] - if isinstance(self.after_model_callback, list): - return self.after_model_callback - return [self.after_model_callback] - - @property - def canonical_on_model_error_callbacks( - self, - ) -> list[_SingleOnModelErrorCallback]: - """The resolved self.on_model_error_callback field as a list of _SingleOnModelErrorCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.on_model_error_callback: - return [] - if isinstance(self.on_model_error_callback, list): - return self.on_model_error_callback - return [self.on_model_error_callback] - - @property - def canonical_before_tool_callbacks( - self, - ) -> list[BeforeToolCallback]: - """The resolved self.before_tool_callback field as a list of BeforeToolCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_tool_callback: - return [] - if isinstance(self.before_tool_callback, list): - return self.before_tool_callback - return [self.before_tool_callback] - - @property - def canonical_after_tool_callbacks( - self, - ) -> list[AfterToolCallback]: - """The resolved self.after_tool_callback field as a list of AfterToolCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_tool_callback: - return [] - if isinstance(self.after_tool_callback, list): - return self.after_tool_callback - return [self.after_tool_callback] - - @property - def canonical_on_tool_error_callbacks( - self, - ) -> list[OnToolErrorCallback]: - """The resolved self.on_tool_error_callback field as a list of OnToolErrorCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.on_tool_error_callback: - return [] - if isinstance(self.on_tool_error_callback, list): - return self.on_tool_error_callback - return [self.on_tool_error_callback] - - @property - def _llm_flow(self) -> BaseLlmFlow: - if ( - self.disallow_transfer_to_parent - and self.disallow_transfer_to_peers - and not self.sub_agents - ): - return SingleFlow() - else: - return AutoFlow() - - def _get_subagent_to_resume(self, ctx: InvocationContext) -> Optional[BaseAgent]: - """Returns the sub-agent in the llm tree to resume if it exists. - - There are 2 cases where we need to transfer to and resume a sub-agent: - 1. The last event is a transfer to agent response from the current agent. - In this case, we need to return the agent specified in the response. - - 2. The last event's author isn't the current agent, or the user is - responding to another agent's tool call. - In this case, we need to return the LAST agent being transferred to - from the current agent. - """ - events = ctx._get_events(current_invocation=True, current_branch=True) - if not events: - return None - - last_event = events[-1] - if last_event.author == self.name: - # Last event is from current agent. Return transfer_to_agent in the event - # if it exists, or None. - return self.__get_transfer_to_agent_or_none(last_event, self.name) - - # Last event is from user or another agent. - if last_event.author == "user": - function_call_event = ctx._find_matching_function_call(last_event) - if not function_call_event: - raise ValueError( - "No agent to transfer to for resuming agent from function response" - f" {self.name}" - ) - if function_call_event.author == self.name: - # User is responding to a tool call from the current agent. - # Current agent should continue, so no sub-agent to resume. - return None - - # Last event is from another agent, or from user for another agent's tool - # call. We need to find the last agent we transferred to. - for event in reversed(events): - if agent := self.__get_transfer_to_agent_or_none(event, self.name): - return agent - - return None - - def __get_agent_to_run(self, agent_name: str) -> BaseAgent: - """Find the agent to run under the root agent by name.""" - agent_to_run = self.root_agent.find_agent(agent_name) - if not agent_to_run: - available = self._get_available_agent_names() - error_msg = ( - f"Agent '{agent_name}' not found.\n" - f"Available agents: {', '.join(available)}\n\n" - "Possible causes:\n" - " 1. Agent not registered before being referenced\n" - " 2. Agent name mismatch (typo or case sensitivity)\n" - " 3. Timing issue (agent referenced before creation)\n\n" - "Suggested fixes:\n" - " - Verify agent is registered with root agent\n" - " - Check agent name spelling and case\n" - " - Ensure agents are created before being referenced" - ) - raise ValueError(error_msg) - return agent_to_run - - def _get_available_agent_names(self) -> list[str]: - """Helper to get all agent names in the tree for error reporting. - - This is a private helper method used only for error message formatting. - Traverses the agent tree starting from root_agent and collects all - agent names for display in error messages. - - Returns: - List of all agent names in the agent tree. - """ - agents = [] - - def collect_agents(agent): - agents.append(agent.name) - if hasattr(agent, "sub_agents") and agent.sub_agents: - for sub_agent in agent.sub_agents: - collect_agents(sub_agent) - - collect_agents(self.root_agent) - return agents - - def __get_transfer_to_agent_or_none( - self, event: Event, from_agent: str - ) -> Optional[BaseAgent]: - """Returns the agent to run if the event is a transfer to agent response.""" - function_responses = event.get_function_responses() - if not function_responses: - return None - for function_response in function_responses: - if ( - function_response.name == "transfer_to_agent" - and event.author == from_agent - and event.actions.transfer_to_agent != from_agent - ): - return self.__get_agent_to_run(event.actions.transfer_to_agent) + # Callbacks - End + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + agent_state = self._load_agent_state(ctx, BaseAgentState) + + # If there is a sub-agent to resume, run it and then end the current + # agent. + if agent_state is not None and ( + agent_to_transfer := self._get_subagent_to_resume(ctx) + ): + async with Aclosing(agent_to_transfer.run_async(ctx)) as agen: + async for event in agen: + yield event + + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + return + + should_pause = False + async with Aclosing(self._llm_flow.run_async(ctx)) as agen: + async for event in agen: + self.__maybe_save_output_to_state(event) + yield event + if ctx.should_pause_invocation(event): + # Do not pause immediately, wait until the long running tool call is + # executed. + should_pause = True + if should_pause: + return + + if ctx.is_resumable: + events = ctx._get_events(current_invocation=True, current_branch=True) + if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): + return + # Only yield an end state if the last event is no longer a long running + # tool call. + ctx.set_agent_state(self.name, end_of_agent=True) + yield self._create_agent_state_event(ctx) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + async with Aclosing(self._llm_flow.run_live(ctx)) as agen: + async for event in agen: + self.__maybe_save_output_to_state(event) + yield event + if ctx.end_invocation: + return + + @property + def canonical_model(self) -> BaseLlm: + """The resolved self.model field as BaseLlm. + + This method is only for use by Agent Development Kit. + """ + if isinstance(self.model, BaseLlm): + return self.model + elif self.model: # model is non-empty str + return LLMRegistry.new_llm(self.model) + else: # find model from ancestors. + ancestor_agent = self.parent_agent + while ancestor_agent is not None: + if isinstance(ancestor_agent, LlmAgent): + return ancestor_agent.canonical_model + ancestor_agent = ancestor_agent.parent_agent + raise ValueError(f"No model found for {self.name}.") + + async def canonical_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: + """The resolved self.instruction field to construct instruction for this agent. + + This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. + """ + if isinstance(self.instruction, str): + return self.instruction, False + else: + instruction = self.instruction(ctx) + if inspect.isawaitable(instruction): + instruction = await instruction + return instruction, True + + async def canonical_global_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: + """The resolved self.instruction field to construct global instruction. + + This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.global_instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. + """ + # Issue deprecation warning if global_instruction is being used + if self.global_instruction: + warnings.warn( + "global_instruction field is deprecated and will be removed in a" + " future version. Use GlobalInstructionPlugin instead for the same" + " functionality at the App level. See migration guide for details.", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(self.global_instruction, str): + return self.global_instruction, False + else: + global_instruction = self.global_instruction(ctx) + if inspect.isawaitable(global_instruction): + global_instruction = await global_instruction + return global_instruction, True + + async def canonical_tools( + self, ctx: ReadonlyContext = None + ) -> list[BaseTool]: + """The resolved self.tools field as a list of BaseTool based on the context. + + This method is only for use by Agent Development Kit. + """ + resolved_tools = [] + # We may need to wrap some built-in tools if there are other tools + # because the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + multiple_tools = len(self.tools) > 1 + for tool_union in self.tools: + resolved_tools.extend( + await _convert_tool_union_to_tools( + tool_union, ctx, self.model, multiple_tools + ) + ) + return resolved_tools + + @property + def canonical_before_model_callbacks( + self, + ) -> list[_SingleBeforeModelCallback]: + """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_model_callback: + return [] + if isinstance(self.before_model_callback, list): + return self.before_model_callback + return [self.before_model_callback] + + @property + def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]: + """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_model_callback: + return [] + if isinstance(self.after_model_callback, list): + return self.after_model_callback + return [self.after_model_callback] + + @property + def canonical_on_model_error_callbacks( + self, + ) -> list[_SingleOnModelErrorCallback]: + """The resolved self.on_model_error_callback field as a list of _SingleOnModelErrorCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.on_model_error_callback: + return [] + if isinstance(self.on_model_error_callback, list): + return self.on_model_error_callback + return [self.on_model_error_callback] + + @property + def canonical_before_tool_callbacks( + self, + ) -> list[BeforeToolCallback]: + """The resolved self.before_tool_callback field as a list of BeforeToolCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_tool_callback: + return [] + if isinstance(self.before_tool_callback, list): + return self.before_tool_callback + return [self.before_tool_callback] + + @property + def canonical_after_tool_callbacks( + self, + ) -> list[AfterToolCallback]: + """The resolved self.after_tool_callback field as a list of AfterToolCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_tool_callback: + return [] + if isinstance(self.after_tool_callback, list): + return self.after_tool_callback + return [self.after_tool_callback] + + @property + def canonical_on_tool_error_callbacks( + self, + ) -> list[OnToolErrorCallback]: + """The resolved self.on_tool_error_callback field as a list of OnToolErrorCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.on_tool_error_callback: + return [] + if isinstance(self.on_tool_error_callback, list): + return self.on_tool_error_callback + return [self.on_tool_error_callback] + + @property + def _llm_flow(self) -> BaseLlmFlow: + if ( + self.disallow_transfer_to_parent + and self.disallow_transfer_to_peers + and not self.sub_agents + ): + return SingleFlow() + else: + return AutoFlow() + + def _get_subagent_to_resume( + self, ctx: InvocationContext + ) -> Optional[BaseAgent]: + """Returns the sub-agent in the llm tree to resume if it exists. + + There are 2 cases where we need to transfer to and resume a sub-agent: + 1. The last event is a transfer to agent response from the current agent. + In this case, we need to return the agent specified in the response. + + 2. The last event's author isn't the current agent, or the user is + responding to another agent's tool call. + In this case, we need to return the LAST agent being transferred to + from the current agent. + """ + events = ctx._get_events(current_invocation=True, current_branch=True) + if not events: + return None + + last_event = events[-1] + if last_event.author == self.name: + # Last event is from current agent. Return transfer_to_agent in the event + # if it exists, or None. + return self.__get_transfer_to_agent_or_none(last_event, self.name) + + # Last event is from user or another agent. + if last_event.author == "user": + function_call_event = ctx._find_matching_function_call(last_event) + if not function_call_event: + raise ValueError( + "No agent to transfer to for resuming agent from function response" + f" {self.name}" + ) + if function_call_event.author == self.name: + # User is responding to a tool call from the current agent. + # Current agent should continue, so no sub-agent to resume. return None - def __maybe_save_output_to_state(self, event: Event): - """Saves the model output to state if needed.""" - # skip if the event was authored by some other agent (e.g. current agent - # transferred to another agent) - if event.author != self.name: - logger.debug( - "Skipping output save for agent %s: event authored by %s", - self.name, - event.author, - ) - return - if ( - self.output_key - and event.is_final_response() - and event.content - and event.content.parts - ): - - result = "".join( - part.text - for part in event.content.parts - if part.text and not part.thought - ) - if self.output_schema: - # If the result from the final chunk is just whitespace or empty, - # it means this is an empty final chunk of a stream. - # Do not attempt to parse it as JSON. - if not result.strip(): - return - result = self.output_schema.model_validate_json(result).model_dump( - exclude_none=True - ) - event.actions.state_delta[self.output_key] = result - - @model_validator(mode="after") - def __model_validator_after(self) -> LlmAgent: - root_agent = getattr(self, "root_agent", None) or self - disable_telemetry: bool = not is_telemetry_enabled(root_agent) - if hasattr(self.model, "disable_telemetry"): - self.model.disable_telemetry = disable_telemetry - return self - - @field_validator("generate_content_config", mode="after") - @classmethod - def validate_generate_content_config( - cls, generate_content_config: Optional[types.GenerateContentConfig] - ) -> types.GenerateContentConfig: - if not generate_content_config: - return types.GenerateContentConfig() - if generate_content_config.thinking_config: - raise ValueError("Thinking config should be set via LlmAgent.planner.") - if generate_content_config.tools: - raise ValueError("All tools must be set via LlmAgent.tools.") - if generate_content_config.system_instruction: - raise ValueError("System instruction must be set via LlmAgent.instruction.") - if generate_content_config.response_schema: - raise ValueError("Response schema must be set via LlmAgent.output_schema.") - return generate_content_config - - @classmethod - @experimental - def _resolve_tools( - cls, tool_configs: list[ToolConfig], config_abs_path: str - ) -> list[Any]: - """Resolve tools from configuration. - - Args: - tool_configs: List of tool configurations (ToolConfig objects). - config_abs_path: The absolute path to the agent config file. - - Returns: - List of resolved tool objects. - """ - - resolved_tools = [] - for tool_config in tool_configs: - if "." not in tool_config.name: - # ADK built-in tools - module = importlib.import_module("google.adk.tools") - obj = getattr(module, tool_config.name) - else: - # User-defined tools - module_path, obj_name = tool_config.name.rsplit(".", 1) - module = importlib.import_module(module_path) - obj = getattr(module, obj_name) - - if isinstance(obj, BaseTool) or isinstance(obj, BaseToolset): - logger.debug( - "Tool %s is an instance of BaseTool/BaseToolset.", tool_config.name - ) - resolved_tools.append(obj) - elif inspect.isclass(obj) and ( - issubclass(obj, BaseTool) or issubclass(obj, BaseToolset) - ): - logger.debug( - "Tool %s is a sub-class of BaseTool/BaseToolset.", tool_config.name - ) - resolved_tools.append( - obj.from_config(tool_config.args, config_abs_path) - ) - elif callable(obj): - if tool_config.args: - logger.debug( - "Tool %s is a user-defined tool-generating function.", - tool_config.name, - ) - resolved_tools.append(obj(tool_config.args)) - else: - logger.debug( - "Tool %s is a user-defined function tool.", tool_config.name - ) - resolved_tools.append(obj) - else: - raise ValueError(f"Invalid tool YAML config: {tool_config}.") - - return resolved_tools - - @override - @classmethod - @experimental - def _parse_config( - cls: Type[LlmAgent], - config: LlmAgentConfig, - config_abs_path: str, - kwargs: Dict[str, Any], - ) -> Dict[str, Any]: - from .config_agent_utils import resolve_callbacks - from .config_agent_utils import resolve_code_reference - - if config.model_code: - kwargs["model"] = resolve_code_reference(config.model_code) - elif config.model: - kwargs["model"] = config.model - if config.instruction: - kwargs["instruction"] = config.instruction - if config.static_instruction: - kwargs["static_instruction"] = config.static_instruction - if config.disallow_transfer_to_parent: - kwargs["disallow_transfer_to_parent"] = config.disallow_transfer_to_parent - if config.disallow_transfer_to_peers: - kwargs["disallow_transfer_to_peers"] = config.disallow_transfer_to_peers - if config.include_contents != "default": - kwargs["include_contents"] = config.include_contents - if config.input_schema: - kwargs["input_schema"] = resolve_code_reference(config.input_schema) - if config.output_schema: - kwargs["output_schema"] = resolve_code_reference(config.output_schema) - if config.output_key: - kwargs["output_key"] = config.output_key - if config.tools: - kwargs["tools"] = cls._resolve_tools(config.tools, config_abs_path) - if config.before_model_callbacks: - kwargs["before_model_callback"] = resolve_callbacks( - config.before_model_callbacks - ) - if config.after_model_callbacks: - kwargs["after_model_callback"] = resolve_callbacks( - config.after_model_callbacks - ) - if config.before_tool_callbacks: - kwargs["before_tool_callback"] = resolve_callbacks( - config.before_tool_callbacks - ) - if config.after_tool_callbacks: - kwargs["after_tool_callback"] = resolve_callbacks( - config.after_tool_callbacks - ) - if config.generate_content_config: - kwargs["generate_content_config"] = config.generate_content_config - - return kwargs + # Last event is from another agent, or from user for another agent's tool + # call. We need to find the last agent we transferred to. + for event in reversed(events): + if agent := self.__get_transfer_to_agent_or_none(event, self.name): + return agent + + return None + + def __get_agent_to_run(self, agent_name: str) -> BaseAgent: + """Find the agent to run under the root agent by name.""" + agent_to_run = self.root_agent.find_agent(agent_name) + if not agent_to_run: + available = self._get_available_agent_names() + error_msg = ( + f"Agent '{agent_name}' not found.\n" + f"Available agents: {', '.join(available)}\n\n" + "Possible causes:\n" + " 1. Agent not registered before being referenced\n" + " 2. Agent name mismatch (typo or case sensitivity)\n" + " 3. Timing issue (agent referenced before creation)\n\n" + "Suggested fixes:\n" + " - Verify agent is registered with root agent\n" + " - Check agent name spelling and case\n" + " - Ensure agents are created before being referenced" + ) + raise ValueError(error_msg) + return agent_to_run + + def _get_available_agent_names(self) -> list[str]: + """Helper to get all agent names in the tree for error reporting. + + This is a private helper method used only for error message formatting. + Traverses the agent tree starting from root_agent and collects all + agent names for display in error messages. + + Returns: + List of all agent names in the agent tree. + """ + agents = [] + + def collect_agents(agent): + agents.append(agent.name) + if hasattr(agent, "sub_agents") and agent.sub_agents: + for sub_agent in agent.sub_agents: + collect_agents(sub_agent) + + collect_agents(self.root_agent) + return agents + + def __get_transfer_to_agent_or_none( + self, event: Event, from_agent: str + ) -> Optional[BaseAgent]: + """Returns the agent to run if the event is a transfer to agent response.""" + function_responses = event.get_function_responses() + if not function_responses: + return None + for function_response in function_responses: + if ( + function_response.name == "transfer_to_agent" + and event.author == from_agent + and event.actions.transfer_to_agent != from_agent + ): + return self.__get_agent_to_run(event.actions.transfer_to_agent) + return None + + def __maybe_save_output_to_state(self, event: Event): + """Saves the model output to state if needed.""" + # skip if the event was authored by some other agent (e.g. current agent + # transferred to another agent) + if event.author != self.name: + logger.debug( + "Skipping output save for agent %s: event authored by %s", + self.name, + event.author, + ) + return + if ( + self.output_key + and event.is_final_response() + and event.content + and event.content.parts + ): + + result = "".join( + part.text + for part in event.content.parts + if part.text and not part.thought + ) + if self.output_schema: + # If the result from the final chunk is just whitespace or empty, + # it means this is an empty final chunk of a stream. + # Do not attempt to parse it as JSON. + if not result.strip(): + return + result = self.output_schema.model_validate_json(result).model_dump( + exclude_none=True + ) + event.actions.state_delta[self.output_key] = result + + @model_validator(mode="after") + def __model_validator_after(self) -> LlmAgent: + root_agent = getattr(self, "root_agent", None) or self + disable_telemetry: bool = not is_telemetry_enabled(root_agent) + if hasattr(self.model, "disable_telemetry"): + self.model.disable_telemetry = disable_telemetry + return self + + @field_validator("generate_content_config", mode="after") + @classmethod + def validate_generate_content_config( + cls, generate_content_config: Optional[types.GenerateContentConfig] + ) -> types.GenerateContentConfig: + if not generate_content_config: + return types.GenerateContentConfig() + if generate_content_config.thinking_config: + raise ValueError("Thinking config should be set via LlmAgent.planner.") + if generate_content_config.tools: + raise ValueError("All tools must be set via LlmAgent.tools.") + if generate_content_config.system_instruction: + raise ValueError( + "System instruction must be set via LlmAgent.instruction." + ) + if generate_content_config.response_schema: + raise ValueError( + "Response schema must be set via LlmAgent.output_schema." + ) + return generate_content_config + + @classmethod + @experimental + def _resolve_tools( + cls, tool_configs: list[ToolConfig], config_abs_path: str + ) -> list[Any]: + """Resolve tools from configuration. + + Args: + tool_configs: List of tool configurations (ToolConfig objects). + config_abs_path: The absolute path to the agent config file. + + Returns: + List of resolved tool objects. + """ + + resolved_tools = [] + for tool_config in tool_configs: + if "." not in tool_config.name: + # ADK built-in tools + module = importlib.import_module("google.adk.tools") + obj = getattr(module, tool_config.name) + else: + # User-defined tools + module_path, obj_name = tool_config.name.rsplit(".", 1) + module = importlib.import_module(module_path) + obj = getattr(module, obj_name) + + if isinstance(obj, BaseTool) or isinstance(obj, BaseToolset): + logger.debug( + "Tool %s is an instance of BaseTool/BaseToolset.", tool_config.name + ) + resolved_tools.append(obj) + elif inspect.isclass(obj) and ( + issubclass(obj, BaseTool) or issubclass(obj, BaseToolset) + ): + logger.debug( + "Tool %s is a sub-class of BaseTool/BaseToolset.", tool_config.name + ) + resolved_tools.append( + obj.from_config(tool_config.args, config_abs_path) + ) + elif callable(obj): + if tool_config.args: + logger.debug( + "Tool %s is a user-defined tool-generating function.", + tool_config.name, + ) + resolved_tools.append(obj(tool_config.args)) + else: + logger.debug( + "Tool %s is a user-defined function tool.", tool_config.name + ) + resolved_tools.append(obj) + else: + raise ValueError(f"Invalid tool YAML config: {tool_config}.") + + return resolved_tools + + @override + @classmethod + @experimental + def _parse_config( + cls: Type[LlmAgent], + config: LlmAgentConfig, + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + from .config_agent_utils import resolve_callbacks + from .config_agent_utils import resolve_code_reference + + if config.model_code: + kwargs["model"] = resolve_code_reference(config.model_code) + elif config.model: + kwargs["model"] = config.model + if config.instruction: + kwargs["instruction"] = config.instruction + if config.static_instruction: + kwargs["static_instruction"] = config.static_instruction + if config.disallow_transfer_to_parent: + kwargs["disallow_transfer_to_parent"] = config.disallow_transfer_to_parent + if config.disallow_transfer_to_peers: + kwargs["disallow_transfer_to_peers"] = config.disallow_transfer_to_peers + if config.include_contents != "default": + kwargs["include_contents"] = config.include_contents + if config.input_schema: + kwargs["input_schema"] = resolve_code_reference(config.input_schema) + if config.output_schema: + kwargs["output_schema"] = resolve_code_reference(config.output_schema) + if config.output_key: + kwargs["output_key"] = config.output_key + if config.tools: + kwargs["tools"] = cls._resolve_tools(config.tools, config_abs_path) + if config.before_model_callbacks: + kwargs["before_model_callback"] = resolve_callbacks( + config.before_model_callbacks + ) + if config.after_model_callbacks: + kwargs["after_model_callback"] = resolve_callbacks( + config.after_model_callbacks + ) + if config.before_tool_callbacks: + kwargs["before_tool_callback"] = resolve_callbacks( + config.before_tool_callbacks + ) + if config.after_tool_callbacks: + kwargs["after_tool_callback"] = resolve_callbacks( + config.after_tool_callbacks + ) + if config.generate_content_config: + kwargs["generate_content_config"] = config.generate_content_config + + return kwargs Agent: TypeAlias = LlmAgent diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 90f5e1423d..4d37e66db5 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -54,10 +54,10 @@ from .audio_cache_manager import AudioCacheManager if TYPE_CHECKING: - from ...agents.llm_agent import LlmAgent - from ...models.base_llm import BaseLlm - from ._base_llm_processor import BaseLlmRequestProcessor - from ._base_llm_processor import BaseLlmResponseProcessor + from ...agents.llm_agent import LlmAgent + from ...models.base_llm import BaseLlm + from ._base_llm_processor import BaseLlmRequestProcessor + from ._base_llm_processor import BaseLlmResponseProcessor logger = logging.getLogger("google_adk." + __name__) @@ -73,1013 +73,1011 @@ class BaseLlmFlow(ABC): - """A basic flow that calls the LLM in a loop until a final response is generated. - - This flow ends when it transfer to another agent. - """ - - def __init__(self): - self.request_processors: list[BaseLlmRequestProcessor] = [] - self.response_processors: list[BaseLlmResponseProcessor] = [] - - # Initialize configuration and managers - self.audio_cache_manager = AudioCacheManager() - - async def run_live( - self, - invocation_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Runs the flow using live api.""" - llm_request = LlmRequest() - event_id = Event.new_id() - - # Preprocess before calling the LLM. - async with Aclosing( - self._preprocess_async(invocation_context, llm_request) - ) as agen: - async for event in agen: - yield event - if invocation_context.end_invocation: - return - - llm = self.__get_llm(invocation_context) - logger.debug( - "Establishing live connection for agent: %s with llm request: %s", + """A basic flow that calls the LLM in a loop until a final response is generated. + + This flow ends when it transfer to another agent. + """ + + def __init__(self): + self.request_processors: list[BaseLlmRequestProcessor] = [] + self.response_processors: list[BaseLlmResponseProcessor] = [] + + # Initialize configuration and managers + self.audio_cache_manager = AudioCacheManager() + + async def run_live( + self, + invocation_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Runs the flow using live api.""" + llm_request = LlmRequest() + event_id = Event.new_id() + + # Preprocess before calling the LLM. + async with Aclosing( + self._preprocess_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event + if invocation_context.end_invocation: + return + + llm = self.__get_llm(invocation_context) + logger.debug( + "Establishing live connection for agent: %s with llm request: %s", + invocation_context.agent.name, + llm_request, + ) + + attempt = 1 + while True: + try: + # On subsequent attempts, use the saved token to reconnect + if invocation_context.live_session_resumption_handle: + logger.info("Attempting to reconnect (Attempt %s)...", attempt) + attempt += 1 + if not llm_request.live_connect_config: + llm_request.live_connect_config = types.LiveConnectConfig() + llm_request.live_connect_config.session_resumption.handle = ( + invocation_context.live_session_resumption_handle + ) + llm_request.live_connect_config.session_resumption.transparent = True + + logger.info( + "Establishing live connection for agent: %s", invocation_context.agent.name, - llm_request, ) - - attempt = 1 - while True: - try: - # On subsequent attempts, use the saved token to reconnect - if invocation_context.live_session_resumption_handle: - logger.info("Attempting to reconnect (Attempt %s)...", attempt) - attempt += 1 - if not llm_request.live_connect_config: - llm_request.live_connect_config = types.LiveConnectConfig() - llm_request.live_connect_config.session_resumption.handle = ( - invocation_context.live_session_resumption_handle - ) - llm_request.live_connect_config.session_resumption.transparent = ( - True - ) - - logger.info( - "Establishing live connection for agent: %s", - invocation_context.agent.name, - ) - async with llm.connect(llm_request) as llm_connection: - if llm_request.contents: - # Sends the conversation history to the model. - if is_telemetry_enabled(invocation_context.agent): - with tracer.start_as_current_span("send_data"): - # Combine regular contents with audio/transcription from session - logger.debug( - "Sending history to model: %s", llm_request.contents - ) - await llm_connection.send_history(llm_request.contents) - trace_send_data( - invocation_context, event_id, llm_request.contents - ) - else: - logger.debug( - "Sending history to model: %s", llm_request.contents - ) - await llm_connection.send_history(llm_request.contents) - - send_task = asyncio.create_task( - self._send_to_model(llm_connection, invocation_context) - ) - - try: - async with Aclosing( - self._receive_from_model( - llm_connection, - event_id, - invocation_context, - llm_request, - ) - ) as agen: - async for event in agen: - # Empty event means the queue is closed. - if not event: - break - logger.debug("Receive new event: %s", event) - yield event - # send back the function response to models - if event.get_function_responses(): - logger.debug( - "Sending back last function response event: %s", - event, - ) - invocation_context.live_request_queue.send_content( - event.content - ) - # We handle agent transfer here in `run_live` rather than - # in `_postprocess_live` to prevent duplication of function - # response processing. If agent transfer were handled in - # `_postprocess_live`, events yielded from child agent's - # `run_live` would bubble up to parent agent's `run_live`, - # causing `event.get_function_responses()` to be true in both - # child and parent, and `send_content()` to be called twice for - # the same function response. By handling agent transfer here, - # we ensure that only child agent processes its own function - # responses after the transfer. - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == "transfer_to_agent" - ): - await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - logger.debug("Closing live connection") - await llm_connection.close() - logger.debug("Live connection closed.") - # transfer to the sub agent. - transfer_to_agent = event.actions.transfer_to_agent - if transfer_to_agent: - logger.debug( - "Transferring to agent: %s", - transfer_to_agent, - ) - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing( - agent_to_run.run_live(invocation_context) - ) as agen: - async for item in agen: - yield item - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == "task_completed" - ): - # this is used for sequential agent to signal the end of the agent. - await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - return - finally: - # Clean up - if not send_task.done(): - send_task.cancel() - try: - await send_task - except asyncio.CancelledError: - pass - except (ConnectionClosed, ConnectionClosedOK) as e: - # when the session timeout, it will just close and not throw exception. - # so this is for bad cases - logger.error("Connection closed: %s.", e) - raise - except Exception as e: - logger.error( - "An unexpected error occurred in live flow: %s", e, exc_info=True - ) - raise - - async def _send_to_model( - self, - llm_connection: BaseLlmConnection, - invocation_context: InvocationContext, - ): - """Sends data to model.""" - while True: - live_request_queue = invocation_context.live_request_queue - try: - # Streamlit's execution model doesn't preemptively yield to the event - # loop. Therefore, we must explicitly introduce timeouts to allow the - # event loop to process events. - # TODO: revert back(remove timeout) once we move off streamlit. - live_request = await asyncio.wait_for( - live_request_queue.get(), timeout=DEFAULT_REQUEST_QUEUE_TIMEOUT - ) - # duplicate the live_request to all the active streams + async with llm.connect(llm_request) as llm_connection: + if llm_request.contents: + # Sends the conversation history to the model. + if is_telemetry_enabled(invocation_context.agent): + with tracer.start_as_current_span("send_data"): + # Combine regular contents with audio/transcription from session logger.debug( - "Sending live request %s to active streams: %s", - live_request, - invocation_context.active_streaming_tools, + "Sending history to model: %s", llm_request.contents ) - if invocation_context.active_streaming_tools: - for active_streaming_tool in ( - invocation_context.active_streaming_tools - ).values(): - if active_streaming_tool.stream: - active_streaming_tool.stream.send(live_request) - await asyncio.sleep(0) - except asyncio.TimeoutError: - continue - if live_request.close: - await llm_connection.close() - return - - if live_request.activity_start: - await llm_connection.send_realtime(types.ActivityStart()) - elif live_request.activity_end: - await llm_connection.send_realtime(types.ActivityEnd()) - elif live_request.blob: - # Cache input audio chunks before flushing - self.audio_cache_manager.cache_audio( - invocation_context, live_request.blob, cache_type="input" + await llm_connection.send_history(llm_request.contents) + trace_send_data( + invocation_context, event_id, llm_request.contents ) + else: + logger.debug("Sending history to model: %s", llm_request.contents) + await llm_connection.send_history(llm_request.contents) - await llm_connection.send_realtime(live_request.blob) - - if live_request.content: - await llm_connection.send_content(live_request.content) - - async def _receive_from_model( - self, - llm_connection: BaseLlmConnection, - event_id: str, - invocation_context: InvocationContext, - llm_request: LlmRequest, - ) -> AsyncGenerator[Event, None]: - """Receive data from model and process events using BaseLlmConnection.""" - - def get_author_for_event(llm_response): - """Get the author of the event. - - When the model returns transcription, the author is "user". Otherwise, the - author is the agent name(not 'model'). + send_task = asyncio.create_task( + self._send_to_model(llm_connection, invocation_context) + ) - Args: - llm_response: The LLM response from the LLM call. - """ - if ( - llm_response - and llm_response.content - and llm_response.content.role == "user" - ): - return "user" - else: - return invocation_context.agent.name - - assert invocation_context.live_request_queue - try: - while True: - async with Aclosing(llm_connection.receive()) as agen: - async for llm_response in agen: - if llm_response.live_session_resumption_update: - logger.info( - "Update session resumption handle:" - f" {llm_response.live_session_resumption_update}." - ) - invocation_context.live_session_resumption_handle = ( - llm_response.live_session_resumption_update.new_handle - ) - model_response_event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=get_author_for_event(llm_response), - ) - - async with Aclosing( - self._postprocess_live( - invocation_context, - llm_request, - llm_response, - model_response_event, - ) - ) as agen: - async for event in agen: - # Cache output audio chunks from model responses - # TODO: support video data - if ( - invocation_context.run_config.save_live_blob - and event.content - and event.content.parts - and event.content.parts[0].inline_data - and event.content.parts[ - 0 - ].inline_data.mime_type.startswith("audio/") - ): - audio_blob = types.Blob( - data=event.content.parts[0].inline_data.data, - mime_type=event.content.parts[ - 0 - ].inline_data.mime_type, - ) - self.audio_cache_manager.cache_audio( - invocation_context, - audio_blob, - cache_type="output", - ) - - yield event - # Give opportunity for other tasks to run. - await asyncio.sleep(0) - except ConnectionClosedOK: - pass - - async def run_async( - self, invocation_context: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Runs the flow.""" - while True: - last_event = None - async with Aclosing(self._run_one_step_async(invocation_context)) as agen: - async for event in agen: - last_event = event - yield event - if not last_event or last_event.is_final_response() or last_event.partial: - if last_event and last_event.partial: - logger.warning("The last event is partial, which is not expected.") - break - - async def _run_one_step_async( - self, - invocation_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """One step means one LLM call.""" - llm_request = LlmRequest() - - # Preprocess before calling the LLM. - async with Aclosing( - self._preprocess_async(invocation_context, llm_request) - ) as agen: - async for event in agen: + try: + async with Aclosing( + self._receive_from_model( + llm_connection, + event_id, + invocation_context, + llm_request, + ) + ) as agen: + async for event in agen: + # Empty event means the queue is closed. + if not event: + break + logger.debug("Receive new event: %s", event) yield event - if invocation_context.end_invocation: - return - - # Resume the LLM agent based on the last event from the current branch. - # 1. User content: continue the normal flow - # 2. Function call: call the tool and get the response event. - events = invocation_context._get_events( - current_invocation=True, current_branch=True + # send back the function response to models + if event.get_function_responses(): + logger.debug( + "Sending back last function response event: %s", + event, + ) + invocation_context.live_request_queue.send_content( + event.content + ) + # We handle agent transfer here in `run_live` rather than + # in `_postprocess_live` to prevent duplication of function + # response processing. If agent transfer were handled in + # `_postprocess_live`, events yielded from child agent's + # `run_live` would bubble up to parent agent's `run_live`, + # causing `event.get_function_responses()` to be true in both + # child and parent, and `send_content()` to be called twice for + # the same function response. By handling agent transfer here, + # we ensure that only child agent processes its own function + # responses after the transfer. + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == "transfer_to_agent" + ): + await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) + # cancel the tasks that belongs to the closed connection. + send_task.cancel() + logger.debug("Closing live connection") + await llm_connection.close() + logger.debug("Live connection closed.") + # transfer to the sub agent. + transfer_to_agent = event.actions.transfer_to_agent + if transfer_to_agent: + logger.debug( + "Transferring to agent: %s", + transfer_to_agent, + ) + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing( + agent_to_run.run_live(invocation_context) + ) as agen: + async for item in agen: + yield item + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == "task_completed" + ): + # this is used for sequential agent to signal the end of the agent. + await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY) + # cancel the tasks that belongs to the closed connection. + send_task.cancel() + return + finally: + # Clean up + if not send_task.done(): + send_task.cancel() + try: + await send_task + except asyncio.CancelledError: + pass + except (ConnectionClosed, ConnectionClosedOK) as e: + # when the session timeout, it will just close and not throw exception. + # so this is for bad cases + logger.error("Connection closed: %s.", e) + raise + except Exception as e: + logger.error( + "An unexpected error occurred in live flow: %s", e, exc_info=True + ) + raise + + async def _send_to_model( + self, + llm_connection: BaseLlmConnection, + invocation_context: InvocationContext, + ): + """Sends data to model.""" + while True: + live_request_queue = invocation_context.live_request_queue + try: + # Streamlit's execution model doesn't preemptively yield to the event + # loop. Therefore, we must explicitly introduce timeouts to allow the + # event loop to process events. + # TODO: revert back(remove timeout) once we move off streamlit. + live_request = await asyncio.wait_for( + live_request_queue.get(), timeout=DEFAULT_REQUEST_QUEUE_TIMEOUT + ) + # duplicate the live_request to all the active streams + logger.debug( + "Sending live request %s to active streams: %s", + live_request, + invocation_context.active_streaming_tools, + ) + if invocation_context.active_streaming_tools: + for active_streaming_tool in ( + invocation_context.active_streaming_tools + ).values(): + if active_streaming_tool.stream: + active_streaming_tool.stream.send(live_request) + await asyncio.sleep(0) + except asyncio.TimeoutError: + continue + if live_request.close: + await llm_connection.close() + return + + if live_request.activity_start: + await llm_connection.send_realtime(types.ActivityStart()) + elif live_request.activity_end: + await llm_connection.send_realtime(types.ActivityEnd()) + elif live_request.blob: + # Cache input audio chunks before flushing + self.audio_cache_manager.cache_audio( + invocation_context, live_request.blob, cache_type="input" ) - # Long running tool calls should have been handled before this point. - # If there are still long running tool calls, it means the agent is paused - # before, and its branch hasn't been resumed yet. - if ( - invocation_context.is_resumable - and events - and len(events) > 1 - # TODO: here we are using the last 2 events to decide whether to pause - # the invocation. But this is just being optimistic, we should find a - # way to pause when the long running tool call is followed by more than - # one text responses. - and ( - invocation_context.should_pause_invocation(events[-1]) - or invocation_context.should_pause_invocation(events[-2]) + await llm_connection.send_realtime(live_request.blob) + + if live_request.content: + await llm_connection.send_content(live_request.content) + + async def _receive_from_model( + self, + llm_connection: BaseLlmConnection, + event_id: str, + invocation_context: InvocationContext, + llm_request: LlmRequest, + ) -> AsyncGenerator[Event, None]: + """Receive data from model and process events using BaseLlmConnection.""" + + def get_author_for_event(llm_response): + """Get the author of the event. + + When the model returns transcription, the author is "user". Otherwise, the + author is the agent name(not 'model'). + + Args: + llm_response: The LLM response from the LLM call. + """ + if ( + llm_response + and llm_response.content + and llm_response.content.role == "user" + ): + return "user" + else: + return invocation_context.agent.name + + assert invocation_context.live_request_queue + try: + while True: + async with Aclosing(llm_connection.receive()) as agen: + async for llm_response in agen: + if llm_response.live_session_resumption_update: + logger.info( + "Update session resumption handle:" + f" {llm_response.live_session_resumption_update}." + ) + invocation_context.live_session_resumption_handle = ( + llm_response.live_session_resumption_update.new_handle + ) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=get_author_for_event(llm_response), ) - ): - return - if ( - invocation_context.is_resumable - and events - and events[-1].get_function_calls() - ): - model_response_event = events[-1] async with Aclosing( - self._postprocess_handle_function_calls_async( - invocation_context, model_response_event, llm_request + self._postprocess_live( + invocation_context, + llm_request, + llm_response, + model_response_event, ) ) as agen: - async for event in agen: - event.id = Event.new_id() - yield event - return - - # Calls the LLM. - model_response_event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - branch=invocation_context.branch, - ) - async with Aclosing( - self._call_llm_async(invocation_context, llm_request, model_response_event) - ) as agen: - async for llm_response in agen: - # Postprocess after calling the LLM. - async with Aclosing( - self._postprocess_async( - invocation_context, - llm_request, - llm_response, - model_response_event, + async for event in agen: + # Cache output audio chunks from model responses + # TODO: support video data + if ( + invocation_context.run_config.save_live_blob + and event.content + and event.content.parts + and event.content.parts[0].inline_data + and event.content.parts[0].inline_data.mime_type.startswith( + "audio/" ) - ) as agen: - async for event in agen: - # Update the mutable event id to avoid conflict - model_response_event.id = Event.new_id() - model_response_event.timestamp = ( - datetime.datetime.now().timestamp() - ) - yield event - - async def _preprocess_async( - self, invocation_context: InvocationContext, llm_request: LlmRequest - ) -> AsyncGenerator[Event, None]: - from ...agents.llm_agent import LlmAgent - - agent = invocation_context.agent - if not isinstance(agent, LlmAgent): - raise TypeError(f"Expected agent to be an LlmAgent, but got {type(agent)}") - - # Runs processors. - for processor in self.request_processors: - async with Aclosing( - processor.run_async(invocation_context, llm_request) - ) as agen: - async for event in agen: - yield event - - # Run processors for tools. - - # We may need to wrap some built-in tools if there are other tools - # because the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - multiple_tools = len(agent.tools) > 1 - for tool_union in agent.tools: - tool_context = ToolContext(invocation_context) - - # If it's a toolset, process it first - if isinstance(tool_union, BaseToolset): - await tool_union.process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) - - from ...agents.llm_agent import _convert_tool_union_to_tools + ): + audio_blob = types.Blob( + data=event.content.parts[0].inline_data.data, + mime_type=event.content.parts[0].inline_data.mime_type, + ) + self.audio_cache_manager.cache_audio( + invocation_context, + audio_blob, + cache_type="output", + ) - # Then process all tools from this tool union - tools = await _convert_tool_union_to_tools( - tool_union, - ReadonlyContext(invocation_context), - agent.model, - multiple_tools, - ) - for tool in tools: - await tool.process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) + yield event + # Give opportunity for other tasks to run. + await asyncio.sleep(0) + except ConnectionClosedOK: + pass + + async def run_async( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Runs the flow.""" + while True: + last_event = None + async with Aclosing(self._run_one_step_async(invocation_context)) as agen: + async for event in agen: + last_event = event + yield event + if not last_event or last_event.is_final_response() or last_event.partial: + if last_event and last_event.partial: + logger.warning("The last event is partial, which is not expected.") + break + + async def _run_one_step_async( + self, + invocation_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """One step means one LLM call.""" + llm_request = LlmRequest() + + # Preprocess before calling the LLM. + async with Aclosing( + self._preprocess_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event + if invocation_context.end_invocation: + return + + # Resume the LLM agent based on the last event from the current branch. + # 1. User content: continue the normal flow + # 2. Function call: call the tool and get the response event. + events = invocation_context._get_events( + current_invocation=True, current_branch=True + ) + + # Long running tool calls should have been handled before this point. + # If there are still long running tool calls, it means the agent is paused + # before, and its branch hasn't been resumed yet. + if ( + invocation_context.is_resumable + and events + and len(events) > 1 + # TODO: here we are using the last 2 events to decide whether to pause + # the invocation. But this is just being optimistic, we should find a + # way to pause when the long running tool call is followed by more than + # one text responses. + and ( + invocation_context.should_pause_invocation(events[-1]) + or invocation_context.should_pause_invocation(events[-2]) + ) + ): + return - async def _postprocess_async( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - llm_response: LlmResponse, - model_response_event: Event, - ) -> AsyncGenerator[Event, None]: - """Postprocess after calling the LLM. - - Args: - invocation_context: The invocation context. - llm_request: The original LLM request. - llm_response: The LLM response from the LLM call. - model_response_event: A mutable event for the LLM response. - - Yields: - A generator of events. - """ - - # Runs processors. + if ( + invocation_context.is_resumable + and events + and events[-1].get_function_calls() + ): + model_response_event = events[-1] + async with Aclosing( + self._postprocess_handle_function_calls_async( + invocation_context, model_response_event, llm_request + ) + ) as agen: + async for event in agen: + event.id = Event.new_id() + yield event + return + + # Calls the LLM. + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + ) + async with Aclosing( + self._call_llm_async( + invocation_context, llm_request, model_response_event + ) + ) as agen: + async for llm_response in agen: + # Postprocess after calling the LLM. async with Aclosing( - self._postprocess_run_processors_async(invocation_context, llm_response) + self._postprocess_async( + invocation_context, + llm_request, + llm_response, + model_response_event, + ) ) as agen: - async for event in agen: - yield event + async for event in agen: + # Update the mutable event id to avoid conflict + model_response_event.id = Event.new_id() + model_response_event.timestamp = datetime.datetime.now().timestamp() + yield event + + async def _preprocess_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + from ...agents.llm_agent import LlmAgent - # Skip the model response event if there is no content and no error code. - # This is needed for the code executor to trigger another loop. - if ( - not llm_response.content - and not llm_response.error_code - and not llm_response.interrupted - ): - return + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + raise TypeError( + f"Expected agent to be an LlmAgent, but got {type(agent)}" + ) + + # Runs processors. + for processor in self.request_processors: + async with Aclosing( + processor.run_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event + + # Run processors for tools. + + # We may need to wrap some built-in tools if there are other tools + # because the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + multiple_tools = len(agent.tools) > 1 + for tool_union in agent.tools: + tool_context = ToolContext(invocation_context) + + # If it's a toolset, process it first + if isinstance(tool_union, BaseToolset): + await tool_union.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) - # Builds the event. - model_response_event = self._finalize_model_response_event( - llm_request, llm_response, model_response_event + from ...agents.llm_agent import _convert_tool_union_to_tools + + # Then process all tools from this tool union + tools = await _convert_tool_union_to_tools( + tool_union, + ReadonlyContext(invocation_context), + agent.model, + multiple_tools, + ) + for tool in tools: + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request ) - yield model_response_event - # Handles function calls. - if model_response_event.get_function_calls(): + async def _postprocess_async( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + llm_response: LlmResponse, + model_response_event: Event, + ) -> AsyncGenerator[Event, None]: + """Postprocess after calling the LLM. + + Args: + invocation_context: The invocation context. + llm_request: The original LLM request. + llm_response: The LLM response from the LLM call. + model_response_event: A mutable event for the LLM response. + + Yields: + A generator of events. + """ - if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): - # In progressive SSE streaming mode stage 1, we skip partial FC events - # Only execute FCs in the final aggregated event (partial=False) - if ( - invocation_context.run_config.streaming_mode == StreamingMode.SSE - and model_response_event.partial - ): - return + # Runs processors. + async with Aclosing( + self._postprocess_run_processors_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event + + # Skip the model response event if there is no content and no error code. + # This is needed for the code executor to trigger another loop. + if ( + not llm_response.content + and not llm_response.error_code + and not llm_response.interrupted + ): + return - async with Aclosing( - self._postprocess_handle_function_calls_async( - invocation_context, model_response_event, llm_request - ) - ) as agen: - async for event in agen: - yield event + # Builds the event. + model_response_event = self._finalize_model_response_event( + llm_request, llm_response, model_response_event + ) + yield model_response_event - async def _postprocess_live( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - llm_response: LlmResponse, - model_response_event: Event, - ) -> AsyncGenerator[Event, None]: - """Postprocess after calling the LLM asynchronously. - - Args: - invocation_context: The invocation context. - llm_request: The original LLM request. - llm_response: The LLM response from the LLM call. - model_response_event: A mutable event for the LLM response. - - Yields: - A generator of events. - """ - - # Runs processors. - async with Aclosing( - self._postprocess_run_processors_async(invocation_context, llm_response) - ) as agen: - async for event in agen: - yield event + # Handles function calls. + if model_response_event.get_function_calls(): - # Skip the model response event if there is no content and no error code. - # This is needed for the code executor to trigger another loop. - # But don't skip control events like turn_complete or transcription events. + if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): + # In progressive SSE streaming mode stage 1, we skip partial FC events + # Only execute FCs in the final aggregated event (partial=False) if ( - not llm_response.content - and not llm_response.error_code - and not llm_response.interrupted - and not llm_response.turn_complete - and not llm_response.input_transcription - and not llm_response.output_transcription - and not llm_response.usage_metadata + invocation_context.run_config.streaming_mode == StreamingMode.SSE + and model_response_event.partial ): - return - - # Handle transcription events ONCE per llm_response, outside the event loop - if llm_response.input_transcription: - model_response_event.input_transcription = llm_response.input_transcription - model_response_event.partial = llm_response.partial - yield model_response_event - return - - if llm_response.output_transcription: - model_response_event.output_transcription = ( - llm_response.output_transcription + return + + async with Aclosing( + self._postprocess_handle_function_calls_async( + invocation_context, model_response_event, llm_request + ) + ) as agen: + async for event in agen: + yield event + + async def _postprocess_live( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + llm_response: LlmResponse, + model_response_event: Event, + ) -> AsyncGenerator[Event, None]: + """Postprocess after calling the LLM asynchronously. + + Args: + invocation_context: The invocation context. + llm_request: The original LLM request. + llm_response: The LLM response from the LLM call. + model_response_event: A mutable event for the LLM response. + + Yields: + A generator of events. + """ + + # Runs processors. + async with Aclosing( + self._postprocess_run_processors_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event + + # Skip the model response event if there is no content and no error code. + # This is needed for the code executor to trigger another loop. + # But don't skip control events like turn_complete or transcription events. + if ( + not llm_response.content + and not llm_response.error_code + and not llm_response.interrupted + and not llm_response.turn_complete + and not llm_response.input_transcription + and not llm_response.output_transcription + and not llm_response.usage_metadata + ): + return + + # Handle transcription events ONCE per llm_response, outside the event loop + if llm_response.input_transcription: + model_response_event.input_transcription = ( + llm_response.input_transcription + ) + model_response_event.partial = llm_response.partial + yield model_response_event + return + + if llm_response.output_transcription: + model_response_event.output_transcription = ( + llm_response.output_transcription + ) + model_response_event.partial = llm_response.partial + yield model_response_event + return + + # Flush audio caches based on control events using configurable settings + if invocation_context.run_config.save_live_blob: + flushed_events = await self._handle_control_event_flush( + invocation_context, llm_response + ) + for event in flushed_events: + yield event + if flushed_events: + return + + # Builds the event. + model_response_event = self._finalize_model_response_event( + llm_request, llm_response, model_response_event + ) + yield model_response_event + + # Handles function calls. + if model_response_event.get_function_calls(): + function_response_event = await functions.handle_function_calls_live( + invocation_context, model_response_event, llm_request.tools_dict + ) + # Always yield the function response event first + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response ) - model_response_event.partial = llm_response.partial - yield model_response_event - return - - # Flush audio caches based on control events using configurable settings - if invocation_context.run_config.save_live_blob: - flushed_events = await self._handle_control_event_flush( - invocation_context, llm_response + ) + yield final_event + + async def _postprocess_run_processors_async( + self, invocation_context: InvocationContext, llm_response: LlmResponse + ) -> AsyncGenerator[Event, None]: + for processor in self.response_processors: + async with Aclosing( + processor.run_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event + + async def _postprocess_handle_function_calls_async( + self, + invocation_context: InvocationContext, + function_call_event: Event, + llm_request: LlmRequest, + ) -> AsyncGenerator[Event, None]: + if function_response_event := await functions.handle_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + auth_event = functions.generate_auth_event( + invocation_context, function_response_event + ) + if auth_event: + yield auth_event + + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, function_response_event + ) + if tool_confirmation_event: + yield tool_confirmation_event + + # Always yield the function response event first + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response ) - for event in flushed_events: - yield event - if flushed_events: - return - - # Builds the event. - model_response_event = self._finalize_model_response_event( - llm_request, llm_response, model_response_event ) - yield model_response_event - - # Handles function calls. - if model_response_event.get_function_calls(): - function_response_event = await functions.handle_function_calls_live( - invocation_context, model_response_event, llm_request.tools_dict + yield final_event + transfer_to_agent = function_response_event.actions.transfer_to_agent + if transfer_to_agent: + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event + + def _get_agent_to_run( + self, invocation_context: InvocationContext, agent_name: str + ) -> BaseAgent: + root_agent = invocation_context.agent.root_agent + agent_to_run = root_agent.find_agent(agent_name) + if not agent_to_run: + raise ValueError(f"Agent {agent_name} not found in the agent tree.") + return agent_to_run + + async def _call_llm_async( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> AsyncGenerator[LlmResponse, None]: + # Runs before_model_callback if it exists. + if response := await self._handle_before_model_callback( + invocation_context, llm_request, model_response_event + ): + yield response + return + + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.labels = llm_request.config.labels or {} + + # Add agent name as a label to the llm_request. This will help with slicing + # the billing reports on a per-agent basis. + if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: + llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( + invocation_context.agent.name + ) + + # Calls the LLM. + llm = self.__get_llm(invocation_context) + + async def _call_llm_body() -> AsyncGenerator[LlmResponse, None]: + if invocation_context.run_config.support_cfc: + invocation_context.live_request_queue = LiveRequestQueue() + responses_generator = self.run_live(invocation_context) + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, ) - # Always yield the function response event first - yield function_response_event - - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event + ) as agen: + async for llm_response in agen: + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) - ) - yield final_event - - async def _postprocess_run_processors_async( - self, invocation_context: InvocationContext, llm_response: LlmResponse - ) -> AsyncGenerator[Event, None]: - for processor in self.response_processors: - async with Aclosing( - processor.run_async(invocation_context, llm_response) - ) as agen: - async for event in agen: - yield event - - async def _postprocess_handle_function_calls_async( - self, - invocation_context: InvocationContext, - function_call_event: Event, - llm_request: LlmRequest, - ) -> AsyncGenerator[Event, None]: - if function_response_event := await functions.handle_function_calls_async( - invocation_context, function_call_event, llm_request.tools_dict - ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event + llm_response = altered_llm_response + # only yield partial response in SSE streaming mode + if ( + invocation_context.run_config.streaming_mode + == StreamingMode.SSE + or not llm_response.partial + ): + yield llm_response + if llm_response.turn_complete: + invocation_context.live_request_queue.close() + else: + # Check if we can make this llm call or not. If the current call + # pushes the counter beyond the max set value, then the execution is + # stopped right here, and exception is thrown. + invocation_context.increment_llm_call_count() + responses_generator = llm.generate_content_async( + llm_request, + stream=invocation_context.run_config.streaming_mode + == StreamingMode.SSE, + ) + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, ) - if auth_event: - yield auth_event - - tool_confirmation_event = functions.generate_request_confirmation_event( - invocation_context, function_call_event, function_response_event + ) as agen: + async for llm_response in agen: + trace_call_llm( + invocation_context, + model_response_event.id, + llm_request, + llm_response, ) - if tool_confirmation_event: - yield tool_confirmation_event - - # Always yield the function response event first - yield function_response_event - - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) - ) - yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: - async for event in agen: - yield event - - def _get_agent_to_run( - self, invocation_context: InvocationContext, agent_name: str - ) -> BaseAgent: - root_agent = invocation_context.agent.root_agent - agent_to_run = root_agent.find_agent(agent_name) - if not agent_to_run: - raise ValueError(f"Agent {agent_name} not found in the agent tree.") - return agent_to_run - - async def _call_llm_async( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> AsyncGenerator[LlmResponse, None]: - # Runs before_model_callback if it exists. - if response := await self._handle_before_model_callback( - invocation_context, llm_request, model_response_event - ): - yield response - return - - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.labels = llm_request.config.labels or {} + llm_response = altered_llm_response - # Add agent name as a label to the llm_request. This will help with slicing - # the billing reports on a per-agent basis. - if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: - llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( - invocation_context.agent.name - ) + yield llm_response - # Calls the LLM. - llm = self.__get_llm(invocation_context) - - async def _call_llm_body() -> AsyncGenerator[LlmResponse, None]: - if invocation_context.run_config.support_cfc: - invocation_context.live_request_queue = LiveRequestQueue() - responses_generator = self.run_live(invocation_context) - async with Aclosing( - self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ) - ) as agen: - async for llm_response in agen: - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - # only yield partial response in SSE streaming mode - if ( - invocation_context.run_config.streaming_mode - == StreamingMode.SSE - or not llm_response.partial - ): - yield llm_response - if llm_response.turn_complete: - invocation_context.live_request_queue.close() - else: - # Check if we can make this llm call or not. If the current call - # pushes the counter beyond the max set value, then the execution is - # stopped right here, and exception is thrown. - invocation_context.increment_llm_call_count() - responses_generator = llm.generate_content_async( - llm_request, - stream=invocation_context.run_config.streaming_mode - == StreamingMode.SSE, - ) - async with Aclosing( - self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ) - ) as agen: - async for llm_response in agen: - trace_call_llm( - invocation_context, - model_response_event.id, - llm_request, - llm_response, - ) - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - - yield llm_response - - async def _call_llm_with_optional_tracing() -> ( - AsyncGenerator[LlmResponse, None] - ): - if is_telemetry_enabled(invocation_context.agent): - with tracer.start_as_current_span("call_llm"): - async with Aclosing(_call_llm_body()) as agen: - async for r in agen: - yield r - else: - async with Aclosing(_call_llm_body()) as agen: - async for r in agen: - yield r - - async with Aclosing(_call_llm_with_optional_tracing()) as agen: - async for event in agen: - yield event - - async def _handle_before_model_callback( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> Optional[LlmResponse]: - from ...agents.llm_agent import LlmAgent + async def _call_llm_with_optional_tracing() -> ( + AsyncGenerator[LlmResponse, None] + ): + if is_telemetry_enabled(invocation_context.agent): + with tracer.start_as_current_span("call_llm"): + async with Aclosing(_call_llm_body()) as agen: + async for r in agen: + yield r + else: + async with Aclosing(_call_llm_body()) as agen: + async for r in agen: + yield r + + async with Aclosing(_call_llm_with_optional_tracing()) as agen: + async for event in agen: + yield event + + async def _handle_before_model_callback( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> Optional[LlmResponse]: + from ...agents.llm_agent import LlmAgent - agent = invocation_context.agent + agent = invocation_context.agent - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) - # First run callbacks from the plugins. - callback_response = ( - await invocation_context.plugin_manager.run_before_model_callback( - callback_context=callback_context, - llm_request=llm_request, - ) + # First run callbacks from the plugins. + callback_response = ( + await invocation_context.plugin_manager.run_before_model_callback( + callback_context=callback_context, + llm_request=llm_request, ) - if callback_response: - return callback_response - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not agent.canonical_before_model_callbacks: - return - for callback in agent.canonical_before_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_request=llm_request - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return callback_response - - async def _handle_after_model_callback( - self, - invocation_context: InvocationContext, - llm_response: LlmResponse, - model_response_event: Event, - ) -> Optional[LlmResponse]: - from ...agents.llm_agent import LlmAgent - - agent = invocation_context.agent - - # Add grounding metadata to the response if needed. - # TODO(b/448114567): Remove this function once the workaround is no longer needed. - async def _maybe_add_grounding_metadata( - response: Optional[LlmResponse] = None, - ) -> Optional[LlmResponse]: - readonly_context = ReadonlyContext(invocation_context) - if (tools := invocation_context.canonical_tools_cache) is None: - tools = await agent.canonical_tools(readonly_context) - invocation_context.canonical_tools_cache = tools - - if not any(tool.name == "google_search_agent" for tool in tools): - return response - ground_metadata = invocation_context.session.state.get( - "temp:_adk_grounding_metadata", None - ) - if not ground_metadata: - return response + ) + if callback_response: + return callback_response + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not agent.canonical_before_model_callbacks: + return + for callback in agent.canonical_before_model_callbacks: + callback_response = callback( + callback_context=callback_context, llm_request=llm_request + ) + if inspect.isawaitable(callback_response): + callback_response = await callback_response + if callback_response: + return callback_response + + async def _handle_after_model_callback( + self, + invocation_context: InvocationContext, + llm_response: LlmResponse, + model_response_event: Event, + ) -> Optional[LlmResponse]: + from ...agents.llm_agent import LlmAgent - if not response: - response = llm_response - response.grounding_metadata = ground_metadata - return response + agent = invocation_context.agent - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions + # Add grounding metadata to the response if needed. + # TODO(b/448114567): Remove this function once the workaround is no longer needed. + async def _maybe_add_grounding_metadata( + response: Optional[LlmResponse] = None, + ) -> Optional[LlmResponse]: + readonly_context = ReadonlyContext(invocation_context) + if (tools := invocation_context.canonical_tools_cache) is None: + tools = await agent.canonical_tools(readonly_context) + invocation_context.canonical_tools_cache = tools + + if not any(tool.name == "google_search_agent" for tool in tools): + return response + ground_metadata = invocation_context.session.state.get( + "temp:_adk_grounding_metadata", None + ) + if not ground_metadata: + return response + + if not response: + response = llm_response + response.grounding_metadata = ground_metadata + return response + + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + + # First run callbacks from the plugins. + callback_response = ( + await invocation_context.plugin_manager.run_after_model_callback( + callback_context=CallbackContext(invocation_context), + llm_response=llm_response, ) - - # First run callbacks from the plugins. - callback_response = ( - await invocation_context.plugin_manager.run_after_model_callback( - callback_context=CallbackContext(invocation_context), - llm_response=llm_response, + ) + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not agent.canonical_after_model_callbacks: + return await _maybe_add_grounding_metadata() + for callback in agent.canonical_after_model_callbacks: + callback_response = callback( + callback_context=callback_context, llm_response=llm_response + ) + if inspect.isawaitable(callback_response): + callback_response = await callback_response + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) + return await _maybe_add_grounding_metadata() + + def _finalize_model_response_event( + self, + llm_request: LlmRequest, + llm_response: LlmResponse, + model_response_event: Event, + ) -> Event: + model_response_event = Event.model_validate({ + **model_response_event.model_dump(exclude_none=True), + **llm_response.model_dump(exclude_none=True), + }) + + if model_response_event.content: + function_calls = model_response_event.get_function_calls() + if function_calls: + functions.populate_client_function_call_id(model_response_event) + model_response_event.long_running_tool_ids = ( + functions.get_long_running_function_calls( + function_calls, llm_request.tools_dict ) ) - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not agent.canonical_after_model_callbacks: - return await _maybe_add_grounding_metadata() - for callback in agent.canonical_after_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_response=llm_response - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) - return await _maybe_add_grounding_metadata() - - def _finalize_model_response_event( - self, - llm_request: LlmRequest, - llm_response: LlmResponse, - model_response_event: Event, - ) -> Event: - model_response_event = Event.model_validate( - { - **model_response_event.model_dump(exclude_none=True), - **llm_response.model_dump(exclude_none=True), - } - ) - if model_response_event.content: - function_calls = model_response_event.get_function_calls() - if function_calls: - functions.populate_client_function_call_id(model_response_event) - model_response_event.long_running_tool_ids = ( - functions.get_long_running_function_calls( - function_calls, llm_request.tools_dict - ) - ) + return model_response_event - return model_response_event + async def _handle_control_event_flush( + self, invocation_context: InvocationContext, llm_response: LlmResponse + ) -> list[Event]: + """Handle audio cache flushing based on control events. - async def _handle_control_event_flush( - self, invocation_context: InvocationContext, llm_response: LlmResponse - ) -> list[Event]: - """Handle audio cache flushing based on control events. + Args: + invocation_context: The invocation context containing audio caches. + llm_response: The LLM response containing control event information. - Args: - invocation_context: The invocation context containing audio caches. - llm_response: The LLM response containing control event information. + Returns: + A list of Event objects created from the flushed caches. + """ - Returns: - A list of Event objects created from the flushed caches. - """ + # Log cache statistics if enabled + if DEFAULT_ENABLE_CACHE_STATISTICS: + stats = self.audio_cache_manager.get_cache_stats(invocation_context) + logger.debug("Audio cache stats: %s", stats) + + if llm_response.interrupted: + # user interrupts so the model will stop. we can flush model audio here + return await self.audio_cache_manager.flush_caches( + invocation_context, + flush_user_audio=False, + flush_model_audio=True, + ) + elif llm_response.turn_complete: + # turn completes so we can flush both user and model + return await self.audio_cache_manager.flush_caches( + invocation_context, + flush_user_audio=True, + flush_model_audio=True, + ) + elif getattr(llm_response, "generation_complete", False): + # model generation complete so we can flush model audio + return await self.audio_cache_manager.flush_caches( + invocation_context, + flush_user_audio=False, + flush_model_audio=True, + ) + return [] + + async def _run_and_handle_error( + self, + response_generator: AsyncGenerator[LlmResponse, None], + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> AsyncGenerator[LlmResponse, None]: + """Runs the response generator and processes the error with plugins. + + Args: + response_generator: The response generator to run. + invocation_context: The invocation context. + llm_request: The LLM request. + model_response_event: The model response event. + + Yields: + A generator of LlmResponse. + """ - # Log cache statistics if enabled - if DEFAULT_ENABLE_CACHE_STATISTICS: - stats = self.audio_cache_manager.get_cache_stats(invocation_context) - logger.debug("Audio cache stats: %s", stats) + from ...agents.llm_agent import LlmAgent - if llm_response.interrupted: - # user interrupts so the model will stop. we can flush model audio here - return await self.audio_cache_manager.flush_caches( - invocation_context, - flush_user_audio=False, - flush_model_audio=True, - ) - elif llm_response.turn_complete: - # turn completes so we can flush both user and model - return await self.audio_cache_manager.flush_caches( - invocation_context, - flush_user_audio=True, - flush_model_audio=True, - ) - elif getattr(llm_response, "generation_complete", False): - # model generation complete so we can flush model audio - return await self.audio_cache_manager.flush_caches( - invocation_context, - flush_user_audio=False, - flush_model_audio=True, - ) - return [] + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + raise TypeError( + f"Expected agent to be an LlmAgent, but got {type(agent)}" + ) - async def _run_and_handle_error( - self, - response_generator: AsyncGenerator[LlmResponse, None], - invocation_context: InvocationContext, + async def _run_on_model_error_callbacks( + *, + callback_context: CallbackContext, llm_request: LlmRequest, - model_response_event: Event, - ) -> AsyncGenerator[LlmResponse, None]: - """Runs the response generator and processes the error with plugins. - - Args: - response_generator: The response generator to run. - invocation_context: The invocation context. - llm_request: The LLM request. - model_response_event: The model response event. - - Yields: - A generator of LlmResponse. - """ - - from ...agents.llm_agent import LlmAgent - - agent = invocation_context.agent - if not isinstance(agent, LlmAgent): - raise TypeError(f"Expected agent to be an LlmAgent, but got {type(agent)}") - - async def _run_on_model_error_callbacks( - *, - callback_context: CallbackContext, - llm_request: LlmRequest, - error: Exception, - ) -> Optional[LlmResponse]: - error_response = ( - await invocation_context.plugin_manager.run_on_model_error_callback( - callback_context=callback_context, - llm_request=llm_request, - error=error, - ) - ) - if error_response is not None: - return error_response - - for callback in agent.canonical_on_model_error_callbacks: - error_response = callback( - callback_context=callback_context, - llm_request=llm_request, - error=error, - ) - if inspect.isawaitable(error_response): - error_response = await error_response - if error_response is not None: - return error_response - - return None - - try: - async with Aclosing(response_generator) as agen: - async for response in agen: - yield response - except Exception as model_error: - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) - error_response = await _run_on_model_error_callbacks( - callback_context=callback_context, - llm_request=llm_request, - error=model_error, - ) - if error_response is not None: - yield error_response - else: - raise model_error - - def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm: - from ...agents.llm_agent import LlmAgent + error: Exception, + ) -> Optional[LlmResponse]: + error_response = ( + await invocation_context.plugin_manager.run_on_model_error_callback( + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_model_error_callbacks: + error_response = callback( + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response + + return None + + try: + async with Aclosing(response_generator) as agen: + async for response in agen: + yield response + except Exception as model_error: + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + error_response = await _run_on_model_error_callbacks( + callback_context=callback_context, + llm_request=llm_request, + error=model_error, + ) + if error_response is not None: + yield error_response + else: + raise model_error + + def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm: + from ...agents.llm_agent import LlmAgent - return cast(LlmAgent, invocation_context.agent).canonical_model + return cast(LlmAgent, invocation_context.agent).canonical_model diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 890be12d40..999e85293b 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -45,7 +45,7 @@ from ...utils.telemetry_utils import is_telemetry_enabled if TYPE_CHECKING: - from ...agents.llm_agent import LlmAgent + from ...agents.llm_agent import LlmAgent AF_FUNCTION_CALL_ID_PREFIX = "adk-" REQUEST_EUC_FUNCTION_CALL_NAME = "adk_request_credential" @@ -55,88 +55,90 @@ def generate_client_function_call_id() -> str: - return f"{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}" + return f"{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}" def populate_client_function_call_id(model_response_event: Event) -> None: - if not model_response_event.get_function_calls(): - return - for function_call in model_response_event.get_function_calls(): - if not function_call.id: - function_call.id = generate_client_function_call_id() + if not model_response_event.get_function_calls(): + return + for function_call in model_response_event.get_function_calls(): + if not function_call.id: + function_call.id = generate_client_function_call_id() def remove_client_function_call_id(content: Optional[types.Content]) -> None: - """Removes ADK-generated function call IDs from content before sending to LLM. - - Strips client-side function call/response IDs that start with 'adk-' prefix - to avoid sending internal tracking IDs to the model. - - Args: - content: Content containing function calls/responses to clean. - """ - if content and content.parts: - for part in content.parts: - if ( - part.function_call - and part.function_call.id - and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) - ): - part.function_call.id = None - if ( - part.function_response - and part.function_response.id - and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) - ): - part.function_response.id = None + """Removes ADK-generated function call IDs from content before sending to LLM. + + Strips client-side function call/response IDs that start with 'adk-' prefix + to avoid sending internal tracking IDs to the model. + + Args: + content: Content containing function calls/responses to clean. + """ + if content and content.parts: + for part in content.parts: + if ( + part.function_call + and part.function_call.id + and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) + ): + part.function_call.id = None + if ( + part.function_response + and part.function_response.id + and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX) + ): + part.function_response.id = None def get_long_running_function_calls( function_calls: list[types.FunctionCall], tools_dict: dict[str, BaseTool], ) -> set[str]: - long_running_tool_ids = set() - for function_call in function_calls: - if ( - function_call.name in tools_dict - and tools_dict[function_call.name].is_long_running - ): - long_running_tool_ids.add(function_call.id) + long_running_tool_ids = set() + for function_call in function_calls: + if ( + function_call.name in tools_dict + and tools_dict[function_call.name].is_long_running + ): + long_running_tool_ids.add(function_call.id) - return long_running_tool_ids + return long_running_tool_ids def generate_auth_event( invocation_context: InvocationContext, function_response_event: Event, ) -> Optional[Event]: - if not function_response_event.actions.requested_auth_configs: - return None - parts = [] - long_running_tool_ids = set() - for ( - function_call_id, - auth_config, - ) in function_response_event.actions.requested_auth_configs.items(): - - request_euc_function_call = types.FunctionCall( - name=REQUEST_EUC_FUNCTION_CALL_NAME, - args=AuthToolArguments( - function_call_id=function_call_id, - auth_config=auth_config, - ).model_dump(exclude_none=True, by_alias=True), - ) - request_euc_function_call.id = generate_client_function_call_id() - long_running_tool_ids.add(request_euc_function_call.id) - parts.append(types.Part(function_call=request_euc_function_call)) - - return Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - branch=invocation_context.branch, - content=types.Content(parts=parts, role=function_response_event.content.role), - long_running_tool_ids=long_running_tool_ids, + if not function_response_event.actions.requested_auth_configs: + return None + parts = [] + long_running_tool_ids = set() + for ( + function_call_id, + auth_config, + ) in function_response_event.actions.requested_auth_configs.items(): + + request_euc_function_call = types.FunctionCall( + name=REQUEST_EUC_FUNCTION_CALL_NAME, + args=AuthToolArguments( + function_call_id=function_call_id, + auth_config=auth_config, + ).model_dump(exclude_none=True, by_alias=True), ) + request_euc_function_call.id = generate_client_function_call_id() + long_running_tool_ids.add(request_euc_function_call.id) + parts.append(types.Part(function_call=request_euc_function_call)) + + return Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=types.Content( + parts=parts, role=function_response_event.content.role + ), + long_running_tool_ids=long_running_tool_ids, + ) def generate_request_confirmation_event( @@ -144,43 +146,45 @@ def generate_request_confirmation_event( function_call_event: Event, function_response_event: Event, ) -> Optional[Event]: - """Generates a request confirmation event from a function response event.""" - if not function_response_event.actions.requested_tool_confirmations: - return None - parts = [] - long_running_tool_ids = set() - function_calls = function_call_event.get_function_calls() - for ( - function_call_id, - tool_confirmation, - ) in function_response_event.actions.requested_tool_confirmations.items(): - original_function_call = next( - (fc for fc in function_calls if fc.id == function_call_id), None - ) - if not original_function_call: - continue - request_confirmation_function_call = types.FunctionCall( - name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, - args={ - "originalFunctionCall": original_function_call.model_dump( - exclude_none=True, by_alias=True - ), - "toolConfirmation": tool_confirmation.model_dump( - by_alias=True, exclude_none=True - ), - }, - ) - request_confirmation_function_call.id = generate_client_function_call_id() - long_running_tool_ids.add(request_confirmation_function_call.id) - parts.append(types.Part(function_call=request_confirmation_function_call)) - - return Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - branch=invocation_context.branch, - content=types.Content(parts=parts, role=function_response_event.content.role), - long_running_tool_ids=long_running_tool_ids, + """Generates a request confirmation event from a function response event.""" + if not function_response_event.actions.requested_tool_confirmations: + return None + parts = [] + long_running_tool_ids = set() + function_calls = function_call_event.get_function_calls() + for ( + function_call_id, + tool_confirmation, + ) in function_response_event.actions.requested_tool_confirmations.items(): + original_function_call = next( + (fc for fc in function_calls if fc.id == function_call_id), None ) + if not original_function_call: + continue + request_confirmation_function_call = types.FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + "originalFunctionCall": original_function_call.model_dump( + exclude_none=True, by_alias=True + ), + "toolConfirmation": tool_confirmation.model_dump( + by_alias=True, exclude_none=True + ), + }, + ) + request_confirmation_function_call.id = generate_client_function_call_id() + long_running_tool_ids.add(request_confirmation_function_call.id) + parts.append(types.Part(function_call=request_confirmation_function_call)) + + return Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + branch=invocation_context.branch, + content=types.Content( + parts=parts, role=function_response_event.content.role + ), + long_running_tool_ids=long_running_tool_ids, + ) async def handle_function_calls_async( @@ -190,15 +194,15 @@ async def handle_function_calls_async( filters: Optional[set[str]] = None, tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, ) -> Optional[Event]: - """Calls the functions and returns the function response event.""" - function_calls = function_call_event.get_function_calls() - return await handle_function_call_list_async( - invocation_context, - function_calls, - tools_dict, - filters, - tool_confirmation_dict, - ) + """Calls the functions and returns the function response event.""" + function_calls = function_call_event.get_function_calls() + return await handle_function_call_list_async( + invocation_context, + function_calls, + tools_dict, + filters, + tool_confirmation_dict, + ) async def handle_function_call_list_async( @@ -208,58 +212,62 @@ async def handle_function_call_list_async( filters: Optional[set[str]] = None, tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, ) -> Optional[Event]: - """Calls the functions and returns the function response event.""" - from ...agents.llm_agent import LlmAgent - - agent = invocation_context.agent - - # Filter function calls - filtered_calls = [fc for fc in function_calls if not filters or fc.id in filters] + """Calls the functions and returns the function response event.""" + from ...agents.llm_agent import LlmAgent - if not filtered_calls: - return None - - # Create tasks for parallel execution - tasks = [ - asyncio.create_task( - _execute_single_function_call_async( - invocation_context, - function_call, - tools_dict, - agent, - ( - tool_confirmation_dict[function_call.id] - if tool_confirmation_dict - else None - ), - ) - ) - for function_call in filtered_calls - ] + agent = invocation_context.agent - # Wait for all tasks to complete - function_response_events = await asyncio.gather(*tasks) + # Filter function calls + filtered_calls = [ + fc for fc in function_calls if not filters or fc.id in filters + ] - # Filter out None results - function_response_events = [ - event for event in function_response_events if event is not None - ] + if not filtered_calls: + return None - if not function_response_events: - return None + # Create tasks for parallel execution + tasks = [ + asyncio.create_task( + _execute_single_function_call_async( + invocation_context, + function_call, + tools_dict, + agent, + ( + tool_confirmation_dict[function_call.id] + if tool_confirmation_dict + else None + ), + ) + ) + for function_call in filtered_calls + ] + + # Wait for all tasks to complete + function_response_events = await asyncio.gather(*tasks) + + # Filter out None results + function_response_events = [ + event for event in function_response_events if event is not None + ] + + if not function_response_events: + return None - merged_event = merge_parallel_function_response_events(function_response_events) + merged_event = merge_parallel_function_response_events( + function_response_events + ) - if len(function_response_events) > 1 and is_telemetry_enabled(agent): - # this is needed for debug traces of parallel calls - # individual response with tool.name is traced in __build_response_event - # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span("execute_tool (merged)"): - trace_merged_tool_calls( - response_event_id=merged_event.id, - function_response_event=merged_event, - ) - return merged_event + if len(function_response_events) > 1 and is_telemetry_enabled(agent): + # this is needed for debug traces of parallel calls + # individual response with tool.name is traced in __build_response_event + # (we drop tool.name from span name here as this is merged event) + with tracer.start_as_current_span("execute_tool (merged)"): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + ) + return merged_event async def _execute_single_function_call_async( @@ -269,54 +277,99 @@ async def _execute_single_function_call_async( agent: LlmAgent, tool_confirmation: Optional[ToolConfirmation] = None, ) -> Optional[Event]: - """Execute a single function call with thread safety for state modifications.""" - - async def _run_on_tool_error_callbacks( - *, - tool: BaseTool, - tool_args: dict[str, Any], - tool_context: ToolContext, - error: Exception, - ) -> Optional[dict[str, Any]]: - """Runs the on_tool_error_callbacks for the given tool.""" - error_response = ( - await invocation_context.plugin_manager.run_on_tool_error_callback( - tool=tool, - tool_args=tool_args, - tool_context=tool_context, - error=error, - ) + """Execute a single function call with thread safety for state modifications.""" + + async def _run_on_tool_error_callbacks( + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict[str, Any]]: + """Runs the on_tool_error_callbacks for the given tool.""" + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, ) - if error_response is not None: - return error_response + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_tool_error_callbacks: + error_response = callback( + tool=tool, + args=tool_args, + tool_context=tool_context, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response - for callback in agent.canonical_on_tool_error_callbacks: - error_response = callback( - tool=tool, - args=tool_args, - tool_context=tool_context, - error=error, - ) - if inspect.isawaitable(error_response): - error_response = await error_response - if error_response is not None: - return error_response + return None - return None + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_args = copy.deepcopy(function_call.args) if function_call.args else {} + tool_context = _create_tool_context( + invocation_context, function_call, tool_confirmation + ) - tool_context = _create_tool_context( - invocation_context, function_call, tool_confirmation + try: + tool = _get_tool(function_call, tools_dict) + except ValueError as tool_error: + tool = BaseTool(name=function_call.name, description="Tool not found") + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + else: + raise tool_error + + async def _run_with_trace(): + nonlocal function_args + + # Step 1: Check if plugin before_tool_callback overrides the function + # response. + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) ) - try: - tool = _get_tool(function_call, tools_dict) - except ValueError as tool_error: - tool = BaseTool(name=function_call.name, description="Tool not found") + # Step 2: If no overrides are provided from the plugins, further run the + # canonical callback. + if function_response is None: + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + # Step 3: Otherwise, proceed calling the tool normally. + if function_response is None: + try: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + except Exception as tool_error: error_response = await _run_on_tool_error_callbacks( tool=tool, tool_args=function_args, @@ -324,117 +377,74 @@ async def _run_on_tool_error_callbacks( error=tool_error, ) if error_response is not None: - return __build_response_event( - tool, error_response, tool_context, invocation_context - ) + function_response = error_response else: - raise tool_error - - async def _run_with_trace(): - nonlocal function_args + raise tool_error - # Step 1: Check if plugin before_tool_callback overrides the function - # response. - function_response = ( - await invocation_context.plugin_manager.run_before_tool_callback( - tool=tool, tool_args=function_args, tool_context=tool_context - ) + # Step 4: Check if plugin after_tool_callback overrides the function + # response. + altered_function_response = ( + await invocation_context.plugin_manager.run_after_tool_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + result=function_response, ) + ) - # Step 2: If no overrides are provided from the plugins, further run the - # canonical callback. - if function_response is None: - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break - - # Step 3: Otherwise, proceed calling the tool normally. - if function_response is None: - try: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) - except Exception as tool_error: - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - function_response = error_response - else: - raise tool_error - - # Step 4: Check if plugin after_tool_callback overrides the function - # response. - altered_function_response = ( - await invocation_context.plugin_manager.run_after_tool_callback( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - result=function_response, - ) + # Step 5: If no overrides are provided from the plugins, further run the + # canonical after_tool_callbacks. + if altered_function_response is None: + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + # Step 6: If alternative response exists from after_tool_callback, use it + # instead of the original function response. + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow long running function to return None to not provide function + # response. + if not function_response: + return None + + # Note: State deltas are not applied here - they are collected in + # tool_context.actions.state_delta and applied later when the session + # service processes the events - # Step 5: If no overrides are provided from the plugins, further run the - # canonical after_tool_callbacks. - if altered_function_response is None: - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - # Step 6: If alternative response exists from after_tool_callback, use it - # instead of the original function response. - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow long running function to return None to not provide function - # response. - if not function_response: - return None - - # Note: State deltas are not applied here - they are collected in - # tool_context.actions.state_delta and applied later when the session - # service processes the events - - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + return function_response_event + + if is_telemetry_enabled(agent): + with tracer.start_as_current_span(f"execute_tool {tool.name}"): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, ) return function_response_event - - if is_telemetry_enabled(agent): - with tracer.start_as_current_span(f"execute_tool {tool.name}"): - try: - function_response_event = await _run_with_trace() - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise - else: - return await _run_with_trace() + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise + else: + return await _run_with_trace() async def handle_function_calls_live( @@ -442,55 +452,57 @@ async def handle_function_calls_live( function_call_event: Event, tools_dict: dict[str, BaseTool], ) -> Event: - """Calls the functions and returns the function response event.""" - from ...agents.llm_agent import LlmAgent - - agent = cast(LlmAgent, invocation_context.agent) - function_calls = function_call_event.get_function_calls() - - if not function_calls: - return None - - # Create async lock for active_streaming_tools modifications - streaming_lock = asyncio.Lock() - - # Create tasks for parallel execution - tasks = [ - asyncio.create_task( - _execute_single_function_call_live( - invocation_context, - function_call, - tools_dict, - agent, - streaming_lock, - ) - ) - for function_call in function_calls - ] + """Calls the functions and returns the function response event.""" + from ...agents.llm_agent import LlmAgent - # Wait for all tasks to complete - function_response_events = await asyncio.gather(*tasks) + agent = cast(LlmAgent, invocation_context.agent) + function_calls = function_call_event.get_function_calls() - # Filter out None results - function_response_events = [ - event for event in function_response_events if event is not None - ] + if not function_calls: + return None - if not function_response_events: - return None + # Create async lock for active_streaming_tools modifications + streaming_lock = asyncio.Lock() + + # Create tasks for parallel execution + tasks = [ + asyncio.create_task( + _execute_single_function_call_live( + invocation_context, + function_call, + tools_dict, + agent, + streaming_lock, + ) + ) + for function_call in function_calls + ] + + # Wait for all tasks to complete + function_response_events = await asyncio.gather(*tasks) + + # Filter out None results + function_response_events = [ + event for event in function_response_events if event is not None + ] + + if not function_response_events: + return None - merged_event = merge_parallel_function_response_events(function_response_events) + merged_event = merge_parallel_function_response_events( + function_response_events + ) - if len(function_response_events) > 1 and is_telemetry_enabled(agent): - # this is needed for debug traces of parallel calls - # individual response with tool.name is traced in __build_response_event - # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span("execute_tool (merged)"): - trace_merged_tool_calls( - response_event_id=merged_event.id, - function_response_event=merged_event, - ) - return merged_event + if len(function_response_events) > 1 and is_telemetry_enabled(agent): + # this is needed for debug traces of parallel calls + # individual response with tool.name is traced in __build_response_event + # (we drop tool.name from span name here as this is merged event) + with tracer.start_as_current_span("execute_tool (merged)"): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + ) + return merged_event async def _execute_single_function_call_live( @@ -500,91 +512,93 @@ async def _execute_single_function_call_live( agent: LlmAgent, streaming_lock: asyncio.Lock, ) -> Optional[Event]: - """Execute a single function call for live mode with thread safety.""" - tool, tool_context = _get_tool_and_context( - invocation_context, function_call, tools_dict - ) + """Execute a single function call for live mode with thread safety.""" + tool, tool_context = _get_tool_and_context( + invocation_context, function_call, tools_dict + ) - function_args = copy.deepcopy(function_call.args) if function_call.args else {} + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) - async def _run_with_trace(): - nonlocal function_args + async def _run_with_trace(): + nonlocal function_args - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_response = None + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_response = None - # Handle before_tool_callbacks - iterate through the canonical callback - # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break - - if function_response is None: - function_response = await _process_function_live_helper( - tool, - tool_context, - function_call, - function_args, - invocation_context, - streaming_lock, - ) + # Handle before_tool_callbacks - iterate through the canonical callback + # list + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + if function_response is None: + function_response = await _process_function_live_helper( + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock, + ) + + # Calls after_tool_callback if it exists. + altered_function_response = None + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow async function to return None to not provide function response. + if not function_response: + return None - # Calls after_tool_callback if it exists. - altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow async function to return None to not provide function response. - if not function_response: - return None - - # Note: State deltas are not applied here - they are collected in - # tool_context.actions.state_delta and applied later when the session - # service processes the events - - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context + # Note: State deltas are not applied here - they are collected in + # tool_context.actions.state_delta and applied later when the session + # service processes the events + + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + return function_response_event + + if is_telemetry_enabled(agent): + with tracer.start_as_current_span(f"execute_tool {tool.name}"): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, ) return function_response_event - - if is_telemetry_enabled(agent): - with tracer.start_as_current_span(f"execute_tool {tool.name}"): - try: - function_response_event = await _run_with_trace() - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise - else: - return await _run_with_trace() + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise + else: + return await _run_with_trace() async def _process_function_live_helper( @@ -595,134 +609,136 @@ async def _process_function_live_helper( invocation_context, streaming_lock: asyncio.Lock, ): - function_response = None - # Check if this is a stop_streaming function call - if function_call.name == "stop_streaming" and "function_name" in function_args: - function_name = function_args["function_name"] - # Thread-safe access to active_streaming_tools - async with streaming_lock: - active_tasks = invocation_context.active_streaming_tools - if ( - active_tasks - and function_name in active_tasks - and active_tasks[function_name].task - and not active_tasks[function_name].task.done() - ): - task = active_tasks[function_name].task - else: - task = None - - if task: - task.cancel() - try: - # Wait for the task to be cancelled - await asyncio.wait_for(task, timeout=1.0) - except (asyncio.CancelledError, asyncio.TimeoutError): - # Log the specific condition - if task.cancelled(): - logging.info("Task %s was cancelled successfully", function_name) - elif task.done(): - logging.info("Task %s completed during cancellation", function_name) - else: - logging.warning( - "Task %s might still be running after cancellation timeout", - function_name, - ) - function_response = { - "status": f"The task is not cancelled yet for {function_name}." - } - if not function_response: - # Clean up the reference under lock - async with streaming_lock: - if ( - invocation_context.active_streaming_tools - and function_name in invocation_context.active_streaming_tools - ): - invocation_context.active_streaming_tools[ - function_name - ].task = None - - function_response = { - "status": f"Successfully stopped streaming function {function_name}" - } + function_response = None + # Check if this is a stop_streaming function call + if ( + function_call.name == "stop_streaming" + and "function_name" in function_args + ): + function_name = function_args["function_name"] + # Thread-safe access to active_streaming_tools + async with streaming_lock: + active_tasks = invocation_context.active_streaming_tools + if ( + active_tasks + and function_name in active_tasks + and active_tasks[function_name].task + and not active_tasks[function_name].task.done() + ): + task = active_tasks[function_name].task + else: + task = None + + if task: + task.cancel() + try: + # Wait for the task to be cancelled + await asyncio.wait_for(task, timeout=1.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + # Log the specific condition + if task.cancelled(): + logging.info("Task %s was cancelled successfully", function_name) + elif task.done(): + logging.info("Task %s completed during cancellation", function_name) else: - function_response = { - "status": f"No active streaming function named {function_name} found" - } - elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func): - # for streaming tool use case - # we require the function to be an async generator function - async def run_tool_and_update_queue(tool, function_args, tool_context): - try: - async with Aclosing( - __call_tool_live( - tool=tool, - args=function_args, - tool_context=tool_context, - invocation_context=invocation_context, - ) - ) as agen: - async for result in agen: - updated_content = types.Content( - role="user", - parts=[ - types.Part.from_text( - text=f"Function {tool.name} returned: {result}" - ) - ], - ) - invocation_context.live_request_queue.send_content( - updated_content - ) - except asyncio.CancelledError: - raise # Re-raise to properly propagate the cancellation - - task = asyncio.create_task( - run_tool_and_update_queue(tool, function_args, tool_context) - ) - - # Register streaming tool using original logic + logging.warning( + "Task %s might still be running after cancellation timeout", + function_name, + ) + function_response = { + "status": f"The task is not cancelled yet for {function_name}." + } + if not function_response: + # Clean up the reference under lock async with streaming_lock: - if invocation_context.active_streaming_tools is None: - invocation_context.active_streaming_tools = {} - - if tool.name in invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools[tool.name].task = task - else: - invocation_context.active_streaming_tools[tool.name] = ( - ActiveStreamingTool(task=task) - ) - - # Immediately return a pending response. - # This is required by current live model. + if ( + invocation_context.active_streaming_tools + and function_name in invocation_context.active_streaming_tools + ): + invocation_context.active_streaming_tools[function_name].task = None + function_response = { - "status": ( - "The function is running asynchronously and the results are" " pending." - ) + "status": f"Successfully stopped streaming function {function_name}" } else: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context + function_response = { + "status": f"No active streaming function named {function_name} found" + } + elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func): + # for streaming tool use case + # we require the function to be an async generator function + async def run_tool_and_update_queue(tool, function_args, tool_context): + try: + async with Aclosing( + __call_tool_live( + tool=tool, + args=function_args, + tool_context=tool_context, + invocation_context=invocation_context, + ) + ) as agen: + async for result in agen: + updated_content = types.Content( + role="user", + parts=[ + types.Part.from_text( + text=f"Function {tool.name} returned: {result}" + ) + ], + ) + invocation_context.live_request_queue.send_content(updated_content) + except asyncio.CancelledError: + raise # Re-raise to properly propagate the cancellation + + task = asyncio.create_task( + run_tool_and_update_queue(tool, function_args, tool_context) + ) + + # Register streaming tool using original logic + async with streaming_lock: + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + + if tool.name in invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools[tool.name].task = task + else: + invocation_context.active_streaming_tools[tool.name] = ( + ActiveStreamingTool(task=task) ) - return function_response - - -def _get_tool(function_call: types.FunctionCall, tools_dict: dict[str, BaseTool]): - """Returns the tool corresponding to the function call.""" - if function_call.name not in tools_dict: - available = list(tools_dict.keys()) - error_msg = ( - f"Tool '{function_call.name}' not found.\nAvailable tools:" - f" {', '.join(available)}\n\nPossible causes:\n 1. LLM hallucinated" - " the function name - review agent instruction clarity\n 2. Tool not" - " registered - verify agent.tools list\n 3. Name mismatch - check for" - " typos\n\nSuggested fixes:\n - Review agent instruction to ensure" - " tool usage is clear\n - Verify tool is included in agent.tools" - " list\n - Check for typos in function name" + + # Immediately return a pending response. + # This is required by current live model. + function_response = { + "status": ( + "The function is running asynchronously and the results are" + " pending." ) - raise ValueError(error_msg) + } + else: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + return function_response - return tools_dict[function_call.name] + +def _get_tool( + function_call: types.FunctionCall, tools_dict: dict[str, BaseTool] +): + """Returns the tool corresponding to the function call.""" + if function_call.name not in tools_dict: + available = list(tools_dict.keys()) + error_msg = ( + f"Tool '{function_call.name}' not found.\nAvailable tools:" + f" {', '.join(available)}\n\nPossible causes:\n 1. LLM hallucinated" + " the function name - review agent instruction clarity\n 2. Tool not" + " registered - verify agent.tools list\n 3. Name mismatch - check for" + " typos\n\nSuggested fixes:\n - Review agent instruction to ensure" + " tool usage is clear\n - Verify tool is included in agent.tools" + " list\n - Check for typos in function name" + ) + raise ValueError(error_msg) + + return tools_dict[function_call.name] def _create_tool_context( @@ -730,12 +746,12 @@ def _create_tool_context( function_call: types.FunctionCall, tool_confirmation: Optional[ToolConfirmation] = None, ): - """Creates a ToolContext object.""" - return ToolContext( - invocation_context=invocation_context, - function_call_id=function_call.id, - tool_confirmation=tool_confirmation, - ) + """Creates a ToolContext object.""" + return ToolContext( + invocation_context=invocation_context, + function_call_id=function_call.id, + tool_confirmation=tool_confirmation, + ) def _get_tool_and_context( @@ -744,15 +760,15 @@ def _get_tool_and_context( tools_dict: dict[str, BaseTool], tool_confirmation: Optional[ToolConfirmation] = None, ): - """Returns the tool and tool context corresponding to the function call.""" - tool = _get_tool(function_call, tools_dict) - tool_context = _create_tool_context( - invocation_context, - function_call, - tool_confirmation, - ) + """Returns the tool and tool context corresponding to the function call.""" + tool = _get_tool(function_call, tools_dict) + tool_context = _create_tool_context( + invocation_context, + function_call, + tool_confirmation, + ) - return (tool, tool_context) + return (tool, tool_context) async def __call_tool_live( @@ -761,16 +777,16 @@ async def __call_tool_live( tool_context: ToolContext, invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: - """Calls the tool asynchronously (awaiting the coroutine).""" - async with Aclosing( - tool._call_live( - args=args, - tool_context=tool_context, - invocation_context=invocation_context, - ) - ) as agen: - async for item in agen: - yield item + """Calls the tool asynchronously (awaiting the coroutine).""" + async with Aclosing( + tool._call_live( + args=args, + tool_context=tool_context, + invocation_context=invocation_context, + ) + ) as agen: + async for item in agen: + yield item async def __call_tool_async( @@ -778,8 +794,8 @@ async def __call_tool_async( args: dict[str, Any], tool_context: ToolContext, ) -> Any: - """Calls the tool.""" - return await tool.run_async(args=args, tool_context=tool_context) + """Calls the tool.""" + return await tool.run_async(args=args, tool_context=tool_context) def __build_response_event( @@ -788,111 +804,111 @@ def __build_response_event( tool_context: ToolContext, invocation_context: InvocationContext, ) -> Event: - # Specs requires the result to be a dict. - if not isinstance(function_result, dict): - function_result = {"result": function_result} + # Specs requires the result to be a dict. + if not isinstance(function_result, dict): + function_result = {"result": function_result} - part_function_response = types.Part.from_function_response( - name=tool.name, response=function_result - ) - part_function_response.function_response.id = tool_context.function_call_id + part_function_response = types.Part.from_function_response( + name=tool.name, response=function_result + ) + part_function_response.function_response.id = tool_context.function_call_id - content = types.Content( - role="user", - parts=[part_function_response], - ) + content = types.Content( + role="user", + parts=[part_function_response], + ) - function_response_event = Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - content=content, - actions=tool_context.actions, - branch=invocation_context.branch, - ) + function_response_event = Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + content=content, + actions=tool_context.actions, + branch=invocation_context.branch, + ) - return function_response_event + return function_response_event def deep_merge_dicts(d1: dict, d2: dict) -> dict: - """Recursively merges d2 into d1.""" - for key, value in d2.items(): - if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict): - d1[key] = deep_merge_dicts(d1[key], value) - else: - d1[key] = value - return d1 + """Recursively merges d2 into d1.""" + for key, value in d2.items(): + if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict): + d1[key] = deep_merge_dicts(d1[key], value) + else: + d1[key] = value + return d1 def merge_parallel_function_response_events( function_response_events: list["Event"], ) -> "Event": - if not function_response_events: - raise ValueError("No function response events provided.") - - if len(function_response_events) == 1: - return function_response_events[0] - merged_parts = [] - for event in function_response_events: - if event.content: - for part in event.content.parts or []: - merged_parts.append(part) - - # Use the first event as the "base" for common attributes - base_event = function_response_events[0] - - # Merge actions from all events - merged_actions_data: dict[str, Any] = {} - for event in function_response_events: - if event.actions: - # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate. - merged_actions_data = deep_merge_dicts( - merged_actions_data, - event.actions.model_dump(exclude_none=True, by_alias=True), - ) - - merged_actions = EventActions.model_validate(merged_actions_data) - - # Create the new merged event - merged_event = Event( - invocation_id=base_event.invocation_id, - author=base_event.author, - branch=base_event.branch, - content=types.Content(role="user", parts=merged_parts), - actions=merged_actions, # Optionally merge actions if required - ) - - # Use the base_event as the timestamp - merged_event.timestamp = base_event.timestamp - return merged_event + if not function_response_events: + raise ValueError("No function response events provided.") + + if len(function_response_events) == 1: + return function_response_events[0] + merged_parts = [] + for event in function_response_events: + if event.content: + for part in event.content.parts or []: + merged_parts.append(part) + + # Use the first event as the "base" for common attributes + base_event = function_response_events[0] + + # Merge actions from all events + merged_actions_data: dict[str, Any] = {} + for event in function_response_events: + if event.actions: + # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate. + merged_actions_data = deep_merge_dicts( + merged_actions_data, + event.actions.model_dump(exclude_none=True, by_alias=True), + ) + + merged_actions = EventActions.model_validate(merged_actions_data) + + # Create the new merged event + merged_event = Event( + invocation_id=base_event.invocation_id, + author=base_event.author, + branch=base_event.branch, + content=types.Content(role="user", parts=merged_parts), + actions=merged_actions, # Optionally merge actions if required + ) + + # Use the base_event as the timestamp + merged_event.timestamp = base_event.timestamp + return merged_event def find_matching_function_call( events: list[Event], ) -> Optional[Event]: - """Finds the function call event that matches the function response id of the last event.""" - if not events: - return None - - last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): - - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event + """Finds the function call event that matches the function response id of the last event.""" + if not events: return None + + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index 1606ea11f4..9747c1043c 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -19,12 +19,12 @@ import hashlib import json import logging -from opentelemetry.trace import Span import time from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from opentelemetry.trace import Span from ..utils.feature_decorator import experimental from .cache_metadata import CacheMetadata @@ -34,461 +34,470 @@ logger = logging.getLogger("google_adk." + __name__) if TYPE_CHECKING: - from google.genai import Client + from google.genai import Client @experimental class GeminiContextCacheManager: - """Manages context cache lifecycle for Gemini models. + """Manages context cache lifecycle for Gemini models. + + This manager handles cache creation, validation, cleanup, and metadata + population for Gemini context caching. It uses content hashing to determine + cache compatibility and implements efficient caching strategies. + """ + + def __init__(self, genai_client: Client, disable_telemetry: bool = False): + """Initialize cache manager with shared client. - This manager handles cache creation, validation, cleanup, and metadata - population for Gemini context caching. It uses content hashing to determine - cache compatibility and implements efficient caching strategies. + Args: + genai_client: The GenAI client to use for cache operations. + disable_telemetry: A bool to flag whether or not telemetry should be disabled. """ + self.genai_client = genai_client + self.disable_telemetry = disable_telemetry + + async def handle_context_caching( + self, llm_request: LlmRequest + ) -> Optional[CacheMetadata]: + """Handle context caching for Gemini models. + + Validates existing cache or creates a new one if needed. Applies + the cache to the request by setting cached_content and removing cached + contents from the request. + + Args: + llm_request: Request that may contain cache config and metadata. + Modified in-place to use the cache. - def __init__(self, genai_client: Client, disable_telemetry: bool = False): - """Initialize cache manager with shared client. - - Args: - genai_client: The GenAI client to use for cache operations. - disable_telemetry: A bool to flag whether or not telemetry should be disabled. - """ - self.genai_client = genai_client - self.disable_telemetry = disable_telemetry - - async def handle_context_caching( - self, llm_request: LlmRequest - ) -> Optional[CacheMetadata]: - """Handle context caching for Gemini models. - - Validates existing cache or creates a new one if needed. Applies - the cache to the request by setting cached_content and removing cached - contents from the request. - - Args: - llm_request: Request that may contain cache config and metadata. - Modified in-place to use the cache. - - Returns: - Cache metadata to be included in response, or None if caching failed - """ - # Check if we have existing cache metadata and if it's valid - if llm_request.cache_metadata: - logger.debug( - "Found existing cache metadata: %s", - llm_request.cache_metadata, + Returns: + Cache metadata to be included in response, or None if caching failed + """ + # Check if we have existing cache metadata and if it's valid + if llm_request.cache_metadata: + logger.debug( + "Found existing cache metadata: %s", + llm_request.cache_metadata, + ) + if await self._is_cache_valid(llm_request): + # Valid cache found - use it + logger.debug( + "Cache is valid, reusing cache: %s", + llm_request.cache_metadata.cache_name, + ) + cache_name = llm_request.cache_metadata.cache_name + cache_contents_count = llm_request.cache_metadata.contents_count + self._apply_cache_to_request( + llm_request, cache_name, cache_contents_count + ) + return llm_request.cache_metadata.model_copy() + else: + # Invalid cache - clean it up and check if we should create new one + old_cache_metadata = llm_request.cache_metadata + + # Only cleanup if there's an active cache + if old_cache_metadata.cache_name is not None: + logger.debug( + "Cache is invalid, cleaning up: %s", + old_cache_metadata.cache_name, + ) + await self.cleanup_cache(old_cache_metadata.cache_name) + + # Calculate current fingerprint using contents count from old metadata + cache_contents_count = old_cache_metadata.contents_count + current_fingerprint = self._generate_cache_fingerprint( + llm_request, cache_contents_count + ) + + # If fingerprints match, create new cache (expired but same content) + if current_fingerprint == old_cache_metadata.fingerprint: + logger.debug( + "Fingerprints match after invalidation, creating new cache" + ) + cache_metadata = await self._create_new_cache_with_contents( + llm_request, cache_contents_count + ) + if cache_metadata: + self._apply_cache_to_request( + llm_request, cache_metadata.cache_name, cache_contents_count ) - if await self._is_cache_valid(llm_request): - # Valid cache found - use it - logger.debug( - "Cache is valid, reusing cache: %s", - llm_request.cache_metadata.cache_name, - ) - cache_name = llm_request.cache_metadata.cache_name - cache_contents_count = llm_request.cache_metadata.contents_count - self._apply_cache_to_request( - llm_request, cache_name, cache_contents_count - ) - return llm_request.cache_metadata.model_copy() - else: - # Invalid cache - clean it up and check if we should create new one - old_cache_metadata = llm_request.cache_metadata - - # Only cleanup if there's an active cache - if old_cache_metadata.cache_name is not None: - logger.debug( - "Cache is invalid, cleaning up: %s", - old_cache_metadata.cache_name, - ) - await self.cleanup_cache(old_cache_metadata.cache_name) - - # Calculate current fingerprint using contents count from old metadata - cache_contents_count = old_cache_metadata.contents_count - current_fingerprint = self._generate_cache_fingerprint( - llm_request, cache_contents_count - ) - - # If fingerprints match, create new cache (expired but same content) - if current_fingerprint == old_cache_metadata.fingerprint: - logger.debug( - "Fingerprints match after invalidation, creating new cache" - ) - cache_metadata = await self._create_new_cache_with_contents( - llm_request, cache_contents_count - ) - if cache_metadata: - self._apply_cache_to_request( - llm_request, cache_metadata.cache_name, cache_contents_count - ) - return cache_metadata - - # Fingerprints don't match - recalculate with total contents - logger.debug( - "Fingerprints don't match, returning fingerprint-only metadata" - ) - total_contents_count = len(llm_request.contents) - fingerprint_for_all = self._generate_cache_fingerprint( - llm_request, total_contents_count - ) - return CacheMetadata( - fingerprint=fingerprint_for_all, - contents_count=total_contents_count, - ) - - # No existing cache metadata - return fingerprint-only metadata - # We don't create cache without previous fingerprint to match - logger.debug("No existing cache metadata, creating fingerprint-only metadata") + return cache_metadata + + # Fingerprints don't match - recalculate with total contents + logger.debug( + "Fingerprints don't match, returning fingerprint-only metadata" + ) total_contents_count = len(llm_request.contents) - fingerprint = self._generate_cache_fingerprint( + fingerprint_for_all = self._generate_cache_fingerprint( llm_request, total_contents_count ) return CacheMetadata( - fingerprint=fingerprint, + fingerprint=fingerprint_for_all, contents_count=total_contents_count, ) - def _find_count_of_contents_to_cache(self, contents: list[types.Content]) -> int: - """Find the number of contents to cache based on user content strategy. - - Strategy: Find the last continuous batch of user contents and cache - all contents before them. - - Args: - contents: List of contents from the LLM request - - Returns: - Number of contents to cache (can be 0 if all contents are user contents) - """ - if not contents: - return 0 - - # Find the last continuous batch of user contents - last_user_batch_start = len(contents) - - # Scan backwards to find the start of the last user content batch - for i in range(len(contents) - 1, -1, -1): - if contents[i].role == "user": - last_user_batch_start = i - else: - # Found non-user content, stop the batch - break - - # Cache all contents before the last user batch - # This ensures we always have some user content to send to the API - return last_user_batch_start - - async def _is_cache_valid(self, llm_request: LlmRequest) -> bool: - """Check if the cache from request metadata is still valid. - - Validates that it's an active cache (not fingerprint-only), checks expiry, - cache intervals, and fingerprint compatibility. - - Args: - llm_request: Request containing cache metadata to validate - - Returns: - True if cache is valid, False otherwise - """ - cache_metadata = llm_request.cache_metadata - if not cache_metadata: - return False - - # Fingerprint-only metadata is not a valid active cache - if cache_metadata.cache_name is None: - return False - - # Check if cache has expired - if time.time() >= cache_metadata.expire_time: - logger.info("Cache expired: %s", cache_metadata.cache_name) - return False - - # Check if cache has been used for too many invocations - if cache_metadata.invocations_used > llm_request.cache_config.cache_intervals: - logger.info( - "Cache exceeded cache intervals: %s (%d > %d intervals)", - cache_metadata.cache_name, - cache_metadata.invocations_used, - llm_request.cache_config.cache_intervals, - ) - return False - - # Check if fingerprint matches using cached contents count - current_fingerprint = self._generate_cache_fingerprint( - llm_request, cache_metadata.contents_count - ) - if current_fingerprint != cache_metadata.fingerprint: - logger.debug("Cache content fingerprint mismatch") - return False - - return True - - def _generate_cache_fingerprint( - self, llm_request: LlmRequest, cache_contents_count: int - ) -> str: - """Generate a fingerprint for cache validation. - - Includes system instruction, tools, tool_config, and first N contents. - - Args: - llm_request: Request to generate fingerprint for - cache_contents_count: Number of contents to include in fingerprint - - Returns: - 16-character hexadecimal fingerprint representing the cached state - """ - # Create fingerprint from system instruction, tools, tool_config, and first N contents - fingerprint_data = {} - - if llm_request.config and llm_request.config.system_instruction: - fingerprint_data["system_instruction"] = ( - llm_request.config.system_instruction - ) - - if llm_request.config and llm_request.config.tools: - # Simplified: just dump types.Tool instances to JSON - tools_data = [] - for tool in llm_request.config.tools: - if isinstance(tool, types.Tool): - tools_data.append(tool.model_dump()) - fingerprint_data["tools"] = tools_data - - if llm_request.config and llm_request.config.tool_config: - fingerprint_data["tool_config"] = ( - llm_request.config.tool_config.model_dump() - ) + # No existing cache metadata - return fingerprint-only metadata + # We don't create cache without previous fingerprint to match + logger.debug( + "No existing cache metadata, creating fingerprint-only metadata" + ) + total_contents_count = len(llm_request.contents) + fingerprint = self._generate_cache_fingerprint( + llm_request, total_contents_count + ) + return CacheMetadata( + fingerprint=fingerprint, + contents_count=total_contents_count, + ) + + def _find_count_of_contents_to_cache( + self, contents: list[types.Content] + ) -> int: + """Find the number of contents to cache based on user content strategy. + + Strategy: Find the last continuous batch of user contents and cache + all contents before them. + + Args: + contents: List of contents from the LLM request + + Returns: + Number of contents to cache (can be 0 if all contents are user contents) + """ + if not contents: + return 0 - # Include first N contents in fingerprint - if cache_contents_count > 0 and llm_request.contents: - contents_data = [] - for i in range(min(cache_contents_count, len(llm_request.contents))): - content = llm_request.contents[i] - contents_data.append(content.model_dump()) - fingerprint_data["cached_contents"] = contents_data - - # Generate hash using str() instead of json.dumps() to handle bytes - fingerprint_str = str(fingerprint_data) - return hashlib.sha256(fingerprint_str.encode()).hexdigest()[:16] - - async def _create_new_cache_with_contents( - self, llm_request: LlmRequest, cache_contents_count: int - ) -> Optional[CacheMetadata]: - """Create a new cache with specified number of contents. - - Args: - llm_request: Request to create cache for - cache_contents_count: Number of contents to include in cache - - Returns: - Cache metadata if successful, None otherwise - """ - # Check if we have token count from previous response for cache size validation - if llm_request.cacheable_contents_token_count is None: - logger.info( - "No previous token count available, skipping cache creation for" - " initial request" - ) - return None - - if ( - llm_request.cacheable_contents_token_count - < llm_request.cache_config.min_tokens - ): - logger.info( - "Previous request too small for caching (%d < %d tokens)", - llm_request.cacheable_contents_token_count, - llm_request.cache_config.min_tokens, - ) - return None - - try: - # Create cache using Gemini API directly - return await self._create_gemini_cache(llm_request, cache_contents_count) - except Exception as e: - logger.warning("Failed to create cache: %s", e) - return None - - def _estimate_request_tokens(self, llm_request: LlmRequest) -> int: - """Estimate token count for the request. - - This is a rough estimation based on content text length. - - Args: - llm_request: Request to estimate tokens for - - Returns: - Estimated token count - """ - total_chars = 0 - - # System instruction - if llm_request.config and llm_request.config.system_instruction: - total_chars += len(llm_request.config.system_instruction) - - # Tools - if llm_request.config and llm_request.config.tools: - for tool in llm_request.config.tools: - if isinstance(tool, types.Tool): - tool_str = json.dumps(tool.model_dump()) - total_chars += len(tool_str) - - # Contents - for content in llm_request.contents: - for part in content.parts: - if part.text: - total_chars += len(part.text) - - # Rough estimate: 4 characters per token - return total_chars // 4 - - async def _create_gemini_cache_with_optional_tracing( - self, llm_request: LlmRequest, cache_contents_count: int - ) -> CacheMetadata: - """Create cache using Gemini API. - - Args: - llm_request: Request to create cache for - cache_contents_count: Number of contents to cache - - Returns: - Cache metadata with precise creation timestamp - """ - - if not self.disable_telemetry: - from ..telemetry.tracing import tracer - - with tracer.start_as_current_span("create_cache") as span: - return await self._create_gemini_cache_body( - llm_request=llm_request, - cache_contents_count=cache_contents_count, - span=span, - ) - else: - return await self._create_gemini_cache_body( - llm_request=llm_request, cache_contents_count=cache_contents_count - ) + # Find the last continuous batch of user contents + last_user_batch_start = len(contents) - async def _create_gemini_cache_body( - self, - llm_request: LlmRequest, - cache_contents_count: int, - span: Optional[Span] = None, - ) -> CacheMetadata: - """Create cache using Gemini API. - - Args: - llm_request: Request to create cache for - cache_contents_count: Number of contents to cache - - Returns: - Cache metadata with precise creation timestamp - """ - - # Prepare cache contents (first N contents + system instruction + tools) - cache_contents = llm_request.contents[:cache_contents_count] - - cache_config = types.CreateCachedContentConfig( - contents=cache_contents, - ttl=llm_request.cache_config.ttl_string, - display_name=( - f"adk-cache-{int(time.time())}-{cache_contents_count}contents" - ), - ) + # Scan backwards to find the start of the last user content batch + for i in range(len(contents) - 1, -1, -1): + if contents[i].role == "user": + last_user_batch_start = i + else: + # Found non-user content, stop the batch + break - # Add system instruction if present - if llm_request.config and llm_request.config.system_instruction: - cache_config.system_instruction = llm_request.config.system_instruction - logger.debug( - "Added system instruction to cache config (length=%d)", - len(llm_request.config.system_instruction), - ) + # Cache all contents before the last user batch + # This ensures we always have some user content to send to the API + return last_user_batch_start - # Add tools if present - if llm_request.config and llm_request.config.tools: - cache_config.tools = llm_request.config.tools + async def _is_cache_valid(self, llm_request: LlmRequest) -> bool: + """Check if the cache from request metadata is still valid. - # Add tool config if present - if llm_request.config and llm_request.config.tool_config: - cache_config.tool_config = llm_request.config.tool_config + Validates that it's an active cache (not fingerprint-only), checks expiry, + cache intervals, and fingerprint compatibility. - if span is not None: - span.set_attribute("cache_contents_count", cache_contents_count) - span.set_attribute("model", llm_request.model) - span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) + Args: + llm_request: Request containing cache metadata to validate - logger.debug( - "Creating cache with model %s and config: %s", - llm_request.model, - cache_config, - ) - cached_content = await self.genai_client.aio.caches.create( - model=llm_request.model, - config=cache_config, - ) - # Set precise creation timestamp right after cache creation - created_at = time.time() - logger.info("Cache created successfully: %s", cached_content.name) + Returns: + True if cache is valid, False otherwise + """ + cache_metadata = llm_request.cache_metadata + if not cache_metadata: + return False + + # Fingerprint-only metadata is not a valid active cache + if cache_metadata.cache_name is None: + return False + + # Check if cache has expired + if time.time() >= cache_metadata.expire_time: + logger.info("Cache expired: %s", cache_metadata.cache_name) + return False + + # Check if cache has been used for too many invocations + if ( + cache_metadata.invocations_used + > llm_request.cache_config.cache_intervals + ): + logger.info( + "Cache exceeded cache intervals: %s (%d > %d intervals)", + cache_metadata.cache_name, + cache_metadata.invocations_used, + llm_request.cache_config.cache_intervals, + ) + return False + + # Check if fingerprint matches using cached contents count + current_fingerprint = self._generate_cache_fingerprint( + llm_request, cache_metadata.contents_count + ) + if current_fingerprint != cache_metadata.fingerprint: + logger.debug("Cache content fingerprint mismatch") + return False + + return True + + def _generate_cache_fingerprint( + self, llm_request: LlmRequest, cache_contents_count: int + ) -> str: + """Generate a fingerprint for cache validation. + + Includes system instruction, tools, tool_config, and first N contents. + + Args: + llm_request: Request to generate fingerprint for + cache_contents_count: Number of contents to include in fingerprint + + Returns: + 16-character hexadecimal fingerprint representing the cached state + """ + # Create fingerprint from system instruction, tools, tool_config, and first N contents + fingerprint_data = {} + + if llm_request.config and llm_request.config.system_instruction: + fingerprint_data["system_instruction"] = ( + llm_request.config.system_instruction + ) + + if llm_request.config and llm_request.config.tools: + # Simplified: just dump types.Tool instances to JSON + tools_data = [] + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool): + tools_data.append(tool.model_dump()) + fingerprint_data["tools"] = tools_data + + if llm_request.config and llm_request.config.tool_config: + fingerprint_data["tool_config"] = ( + llm_request.config.tool_config.model_dump() + ) + + # Include first N contents in fingerprint + if cache_contents_count > 0 and llm_request.contents: + contents_data = [] + for i in range(min(cache_contents_count, len(llm_request.contents))): + content = llm_request.contents[i] + contents_data.append(content.model_dump()) + fingerprint_data["cached_contents"] = contents_data + + # Generate hash using str() instead of json.dumps() to handle bytes + fingerprint_str = str(fingerprint_data) + return hashlib.sha256(fingerprint_str.encode()).hexdigest()[:16] + + async def _create_new_cache_with_contents( + self, llm_request: LlmRequest, cache_contents_count: int + ) -> Optional[CacheMetadata]: + """Create a new cache with specified number of contents. + + Args: + llm_request: Request to create cache for + cache_contents_count: Number of contents to include in cache + + Returns: + Cache metadata if successful, None otherwise + """ + # Check if we have token count from previous response for cache size validation + if llm_request.cacheable_contents_token_count is None: + logger.info( + "No previous token count available, skipping cache creation for" + " initial request" + ) + return None + + if ( + llm_request.cacheable_contents_token_count + < llm_request.cache_config.min_tokens + ): + logger.info( + "Previous request too small for caching (%d < %d tokens)", + llm_request.cacheable_contents_token_count, + llm_request.cache_config.min_tokens, + ) + return None + + try: + # Create cache using Gemini API directly + return await self._create_gemini_cache_with_optional_tracing( + llm_request, cache_contents_count + ) + except Exception as e: + logger.warning("Failed to create cache: %s", e) + return None + + def _estimate_request_tokens(self, llm_request: LlmRequest) -> int: + """Estimate token count for the request. + + This is a rough estimation based on content text length. + + Args: + llm_request: Request to estimate tokens for + + Returns: + Estimated token count + """ + total_chars = 0 + + # System instruction + if llm_request.config and llm_request.config.system_instruction: + total_chars += len(llm_request.config.system_instruction) + + # Tools + if llm_request.config and llm_request.config.tools: + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool): + tool_str = json.dumps(tool.model_dump()) + total_chars += len(tool_str) + + # Contents + for content in llm_request.contents: + for part in content.parts: + if part.text: + total_chars += len(part.text) + + # Rough estimate: 4 characters per token + return total_chars // 4 + + async def _create_gemini_cache_with_optional_tracing( + self, llm_request: LlmRequest, cache_contents_count: int + ) -> CacheMetadata: + """Create cache using Gemini API. + + Args: + llm_request: Request to create cache for + cache_contents_count: Number of contents to cache + + Returns: + Cache metadata with precise creation timestamp + """ - if span is not None: - span.set_attribute("cache_name", cached_content.name) + if not self.disable_telemetry: + from ..telemetry.tracing import tracer - # Return complete cache metadata with precise timing - return CacheMetadata( - cache_name=cached_content.name, - expire_time=created_at + llm_request.cache_config.ttl_seconds, - fingerprint=self._generate_cache_fingerprint( - llm_request, cache_contents_count - ), - invocations_used=1, - contents_count=cache_contents_count, - created_at=created_at, + with tracer.start_as_current_span("create_cache") as span: + return await self._create_gemini_cache_body( + llm_request=llm_request, + cache_contents_count=cache_contents_count, + span=span, ) + else: + return await self._create_gemini_cache_body( + llm_request=llm_request, cache_contents_count=cache_contents_count + ) + + async def _create_gemini_cache_body( + self, + llm_request: LlmRequest, + cache_contents_count: int, + span: Optional[Span] = None, + ) -> CacheMetadata: + """Create cache using Gemini API. + + Args: + llm_request: Request to create cache for + cache_contents_count: Number of contents to cache + + Returns: + Cache metadata with precise creation timestamp + """ - async def cleanup_cache(self, cache_name: str) -> None: - """Clean up cache by deleting it. - - Args: - cache_name: Name of cache to delete - """ - logger.debug("Attempting to delete cache: %s", cache_name) - try: - await self.genai_client.aio.caches.delete(name=cache_name) - logger.info("Cache cleaned up: %s", cache_name) - except Exception as e: - logger.warning("Failed to cleanup cache %s: %s", cache_name, e) - - def _apply_cache_to_request( - self, - llm_request: LlmRequest, - cache_name: str, - cache_contents_count: int, - ) -> None: - """Apply cache to the request by modifying it to use cached content. - - Args: - llm_request: Request to modify - cache_name: Name of cache to use - cache_contents_count: Number of contents that are cached - """ - # Remove system instruction, tools, and tool config from request config since they're in cache - if llm_request.config: - llm_request.config.system_instruction = None - llm_request.config.tools = None - llm_request.config.tool_config = None - - # Set cached content reference - llm_request.config.cached_content = cache_name - - # Remove cached contents from the request (keep only uncached contents) - llm_request.contents = llm_request.contents[cache_contents_count:] - - def populate_cache_metadata_in_response( - self, llm_response: LlmResponse, cache_metadata: CacheMetadata - ) -> None: - """Populate cache metadata in LLM response. - - Args: - llm_response: Response to populate metadata in - cache_metadata: Cache metadata to copy into response - """ - # Create a copy of cache metadata for the response - llm_response.cache_metadata = cache_metadata.model_copy() + # Prepare cache contents (first N contents + system instruction + tools) + cache_contents = llm_request.contents[:cache_contents_count] + + cache_config = types.CreateCachedContentConfig( + contents=cache_contents, + ttl=llm_request.cache_config.ttl_string, + display_name=( + f"adk-cache-{int(time.time())}-{cache_contents_count}contents" + ), + ) + + # Add system instruction if present + if llm_request.config and llm_request.config.system_instruction: + cache_config.system_instruction = llm_request.config.system_instruction + logger.debug( + "Added system instruction to cache config (length=%d)", + len(llm_request.config.system_instruction), + ) + + # Add tools if present + if llm_request.config and llm_request.config.tools: + cache_config.tools = llm_request.config.tools + + # Add tool config if present + if llm_request.config and llm_request.config.tool_config: + cache_config.tool_config = llm_request.config.tool_config + + if span is not None: + span.set_attribute("cache_contents_count", cache_contents_count) + span.set_attribute("model", llm_request.model) + span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) + + logger.debug( + "Creating cache with model %s and config: %s", + llm_request.model, + cache_config, + ) + cached_content = await self.genai_client.aio.caches.create( + model=llm_request.model, + config=cache_config, + ) + # Set precise creation timestamp right after cache creation + created_at = time.time() + logger.info("Cache created successfully: %s", cached_content.name) + + if span is not None: + span.set_attribute("cache_name", cached_content.name) + + # Return complete cache metadata with precise timing + return CacheMetadata( + cache_name=cached_content.name, + expire_time=created_at + llm_request.cache_config.ttl_seconds, + fingerprint=self._generate_cache_fingerprint( + llm_request, cache_contents_count + ), + invocations_used=1, + contents_count=cache_contents_count, + created_at=created_at, + ) + + async def cleanup_cache(self, cache_name: str) -> None: + """Clean up cache by deleting it. + + Args: + cache_name: Name of cache to delete + """ + logger.debug("Attempting to delete cache: %s", cache_name) + try: + await self.genai_client.aio.caches.delete(name=cache_name) + logger.info("Cache cleaned up: %s", cache_name) + except Exception as e: + logger.warning("Failed to cleanup cache %s: %s", cache_name, e) + + def _apply_cache_to_request( + self, + llm_request: LlmRequest, + cache_name: str, + cache_contents_count: int, + ) -> None: + """Apply cache to the request by modifying it to use cached content. + + Args: + llm_request: Request to modify + cache_name: Name of cache to use + cache_contents_count: Number of contents that are cached + """ + # Remove system instruction, tools, and tool config from request config since they're in cache + if llm_request.config: + llm_request.config.system_instruction = None + llm_request.config.tools = None + llm_request.config.tool_config = None + + # Set cached content reference + llm_request.config.cached_content = cache_name + + # Remove cached contents from the request (keep only uncached contents) + llm_request.contents = llm_request.contents[cache_contents_count:] + + def populate_cache_metadata_in_response( + self, llm_response: LlmResponse, cache_metadata: CacheMetadata + ) -> None: + """Populate cache metadata in LLM response. + + Args: + llm_response: Response to populate metadata in + cache_metadata: Cache metadata to copy into response + """ + # Create a copy of cache metadata for the response + llm_response.cache_metadata = cache_metadata.model_copy() diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c0f5725e0b..bbb3b80dc2 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -40,9 +40,9 @@ from .llm_response import LlmResponse if TYPE_CHECKING: - from google.genai import Client + from google.genai import Client - from .llm_request import LlmRequest + from .llm_request import LlmRequest logger = logging.getLogger("google_adk." + __name__) @@ -58,42 +58,42 @@ class _ResourceExhaustedError(ClientError): - """Represents an resources exhausted error received from the Model.""" - - def __init__( - self, - client_error: ClientError, - ): - super().__init__( - code=client_error.code, - response_json=client_error.details, - response=client_error.response, - ) + """Represents an resources exhausted error received from the Model.""" + + def __init__( + self, + client_error: ClientError, + ): + super().__init__( + code=client_error.code, + response_json=client_error.details, + response=client_error.response, + ) - def __str__(self): - # We don't get override the actual message on ClientError, so we override - # this method instead. This will ensure that when the exception is - # stringified (for either publishing the exception on console or to logs) - # we put in the required details for the developer. - base_message = super().__str__() - return f"{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}" + def __str__(self): + # We don't get override the actual message on ClientError, so we override + # this method instead. This will ensure that when the exception is + # stringified (for either publishing the exception on console or to logs) + # we put in the required details for the developer. + base_message = super().__str__() + return f"{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}" class Gemini(BaseLlm): - """Integration for Gemini models. + """Integration for Gemini models. - Attributes: - model: The name of the Gemini model. - use_interactions_api: Whether to use the interactions API for model - invocation. - """ + Attributes: + model: The name of the Gemini model. + use_interactions_api: Whether to use the interactions API for model + invocation. + """ - model: str = "gemini-2.5-flash" + model: str = "gemini-2.5-flash" - speech_config: Optional[types.SpeechConfig] = None + speech_config: Optional[types.SpeechConfig] = None - use_interactions_api: bool = False - """Whether to use the interactions API for model invocation. + use_interactions_api: bool = False + """Whether to use the interactions API for model invocation. When enabled, uses the interactions API (client.aio.interactions.create()) instead of the traditional generate_content API. The interactions API @@ -110,8 +110,8 @@ class Gemini(BaseLlm): ``` """ - retry_options: Optional[types.HttpRetryOptions] = None - """Allow Gemini to retry failed responses. + retry_options: Optional[types.HttpRetryOptions] = None + """Allow Gemini to retry failed responses. Sample: ```python @@ -127,430 +127,435 @@ class Gemini(BaseLlm): ``` """ - disable_telemetry: bool = True - """A bool to flag whether or not telemetry should be being disabled for Gemini LLM interactions. + disable_telemetry: bool = False + """A bool to flag whether or not telemetry should be being disabled for Gemini LLM interactions. """ - @classmethod - @override - def supported_models(cls) -> list[str]: - """Provides the list of supported models. - - Returns: - A list of supported models. - """ - - return [ - r"gemini-.*", - # model optimizer pattern - r"model-optimizer-.*", - # fine-tuned vertex endpoint pattern - r"projects\/.+\/locations\/.+\/endpoints\/.+", - # vertex gemini long name - r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+", - ] - - async def generate_content_async( - self, llm_request: LlmRequest, stream: bool = False - ) -> AsyncGenerator[LlmResponse, None]: - """Sends a request to the Gemini model. - - Args: - llm_request: LlmRequest, the request to send to the Gemini model. - stream: bool = False, whether to do streaming call. - - Yields: - LlmResponse: The model response. - """ - await self._preprocess_request(llm_request) - self._maybe_append_user_content(llm_request) - - # Handle context caching if configured - cache_metadata = None - cache_manager = None - if llm_request.cache_config: - from .gemini_context_cache_manager import GeminiContextCacheManager - - if not self.disable_telemetry: - from ..telemetry.tracing import tracer - - with tracer.start_as_current_span("handle_context_caching") as span: - cache_manager = GeminiContextCacheManager( - self.api_client, disable_telemetry=self.disable_telemetry - ) - cache_metadata = await cache_manager.handle_context_caching( - llm_request - ) - if cache_metadata: - if cache_metadata.cache_name: - span.set_attribute("cache_action", "active_cache") - span.set_attribute("cache_name", cache_metadata.cache_name) - else: - span.set_attribute("cache_action", "fingerprint_only") + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + + return [ + r"gemini-.*", + # model optimizer pattern + r"model-optimizer-.*", + # fine-tuned vertex endpoint pattern + r"projects\/.+\/locations\/.+\/endpoints\/.+", + # vertex gemini long name + r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+", + ] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Sends a request to the Gemini model. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + stream: bool = False, whether to do streaming call. + + Yields: + LlmResponse: The model response. + """ + await self._preprocess_request(llm_request) + self._maybe_append_user_content(llm_request) + + # Handle context caching if configured + cache_metadata = None + cache_manager = None + if llm_request.cache_config: + from .gemini_context_cache_manager import GeminiContextCacheManager + + if not self.disable_telemetry: + from ..telemetry.tracing import tracer + + with tracer.start_as_current_span("handle_context_caching") as span: + cache_manager = GeminiContextCacheManager( + self.api_client, disable_telemetry=self.disable_telemetry + ) + cache_metadata = await cache_manager.handle_context_caching( + llm_request + ) + if cache_metadata: + if cache_metadata.cache_name: + span.set_attribute("cache_action", "active_cache") + span.set_attribute("cache_name", cache_metadata.cache_name) else: - cache_manager = GeminiContextCacheManager( - self.api_client, disable_telemetry=self.disable_telemetry - ) - cache_metadata = await cache_manager.handle_context_caching(llm_request) - - logger.info( - "Sending out request, model: %s, backend: %s, stream: %s", - llm_request.model, - self._api_backend, - stream, + span.set_attribute("cache_action", "fingerprint_only") + else: + cache_manager = GeminiContextCacheManager( + self.api_client, disable_telemetry=self.disable_telemetry ) + cache_metadata = await cache_manager.handle_context_caching(llm_request) - # Always add tracking headers to custom headers given it will override - # the headers set in the api client constructor to avoid tracking headers - # being dropped if user provides custom headers or overrides the api client. - if llm_request.config: - if not llm_request.config.http_options: - llm_request.config.http_options = types.HttpOptions() - llm_request.config.http_options.headers = self._merge_tracking_headers( - llm_request.config.http_options.headers - ) + logger.info( + "Sending out request, model: %s, backend: %s, stream: %s", + llm_request.model, + self._api_backend, + stream, + ) - try: - # Use interactions API if enabled - if self.use_interactions_api: - async for llm_response in self._generate_content_via_interactions( - llm_request, stream - ): - yield llm_response - return - - logger.debug(_build_request_log(llm_request)) - - if stream: - responses = await self.api_client.aio.models.generate_content_stream( - model=llm_request.model, - contents=llm_request.contents, - config=llm_request.config, - ) - - # for sse, similar as bidi (see receive method in - # gemini_llm_connection.py), we need to mark those text content as - # partial and after all partial contents are sent, we send an - # accumulated event which contains all the previous partial content. The - # only difference is bidi rely on complete_turn flag to detect end while - # sse depends on finish_reason. - aggregator = StreamingResponseAggregator() - async with Aclosing(responses) as agen: - async for response in agen: - logger.debug(_build_response_log(response)) - async with Aclosing( - aggregator.process_response(response) - ) as aggregator_gen: - async for llm_response in aggregator_gen: - yield llm_response - if (close_result := aggregator.close()) is not None: - # Populate cache metadata in the final aggregated response for - # streaming - if cache_metadata: - cache_manager.populate_cache_metadata_in_response( - close_result, cache_metadata - ) - yield close_result + # Always add tracking headers to custom headers given it will override + # the headers set in the api client constructor to avoid tracking headers + # being dropped if user provides custom headers or overrides the api client. + if llm_request.config: + if not llm_request.config.http_options: + llm_request.config.http_options = types.HttpOptions() + llm_request.config.http_options.headers = self._merge_tracking_headers( + llm_request.config.http_options.headers + ) - else: - response = await self.api_client.aio.models.generate_content( - model=llm_request.model, - contents=llm_request.contents, - config=llm_request.config, - ) - logger.info("Response received from the model.") - logger.debug(_build_response_log(response)) - - llm_response = LlmResponse.create(response) - if cache_metadata: - cache_manager.populate_cache_metadata_in_response( - llm_response, cache_metadata - ) - yield llm_response - except ClientError as ce: - if ce.code == 429: - # We expect running into a Resource Exhausted error to be a common - # client error that developers would run into. We enhance the messaging - # with possible fixes to this issue. - raise _ResourceExhaustedError(ce) from ce - - raise ce - - async def _generate_content_via_interactions( - self, - llm_request: LlmRequest, - stream: bool, - ) -> AsyncGenerator[LlmResponse, None]: - """Generate content using the interactions API. - - The interactions API provides stateful conversation capabilities. When - previous_interaction_id is set in the request, the API chains interactions - instead of requiring full conversation history. - - Note: Context caching is not used with the Interactions API since it - maintains conversation state via previous_interaction_id. - - Args: - llm_request: The LLM request to send. - stream: Whether to stream the response. - - Yields: - LlmResponse objects converted from interaction responses. - """ - from .interactions_utils import generate_content_via_interactions - - async for llm_response in generate_content_via_interactions( - api_client=self.api_client, - llm_request=llm_request, - stream=stream, + try: + # Use interactions API if enabled + if self.use_interactions_api: + async for llm_response in self._generate_content_via_interactions( + llm_request, stream ): - yield llm_response + yield llm_response + return - @cached_property - def api_client(self) -> Client: - """Provides the api client. + logger.debug(_build_request_log(llm_request)) - Returns: - The api client. - """ - from google.genai import Client + if stream: + responses = await self.api_client.aio.models.generate_content_stream( + model=llm_request.model, + contents=llm_request.contents, + config=llm_request.config, + ) - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers(), - retry_options=self.retry_options, + # for sse, similar as bidi (see receive method in + # gemini_llm_connection.py), we need to mark those text content as + # partial and after all partial contents are sent, we send an + # accumulated event which contains all the previous partial content. The + # only difference is bidi rely on complete_turn flag to detect end while + # sse depends on finish_reason. + aggregator = StreamingResponseAggregator() + async with Aclosing(responses) as agen: + async for response in agen: + logger.debug(_build_response_log(response)) + async with Aclosing( + aggregator.process_response(response) + ) as aggregator_gen: + async for llm_response in aggregator_gen: + yield llm_response + if (close_result := aggregator.close()) is not None: + # Populate cache metadata in the final aggregated response for + # streaming + if cache_metadata: + cache_manager.populate_cache_metadata_in_response( + close_result, cache_metadata ) - ) + yield close_result - @cached_property - def _api_backend(self) -> GoogleLLMVariant: - return ( - GoogleLLMVariant.VERTEX_AI - if self.api_client.vertexai - else GoogleLLMVariant.GEMINI_API + else: + response = await self.api_client.aio.models.generate_content( + model=llm_request.model, + contents=llm_request.contents, + config=llm_request.config, ) + logger.info("Response received from the model.") + logger.debug(_build_response_log(response)) + + llm_response = LlmResponse.create(response) + if cache_metadata: + cache_manager.populate_cache_metadata_in_response( + llm_response, cache_metadata + ) + yield llm_response + except ClientError as ce: + if ce.code == 429: + # We expect running into a Resource Exhausted error to be a common + # client error that developers would run into. We enhance the messaging + # with possible fixes to this issue. + raise _ResourceExhaustedError(ce) from ce + + raise ce + + async def _generate_content_via_interactions( + self, + llm_request: LlmRequest, + stream: bool, + ) -> AsyncGenerator[LlmResponse, None]: + """Generate content using the interactions API. + + The interactions API provides stateful conversation capabilities. When + previous_interaction_id is set in the request, the API chains interactions + instead of requiring full conversation history. + + Note: Context caching is not used with the Interactions API since it + maintains conversation state via previous_interaction_id. + + Args: + llm_request: The LLM request to send. + stream: Whether to stream the response. + + Yields: + LlmResponse objects converted from interaction responses. + """ + from .interactions_utils import generate_content_via_interactions - def _tracking_headers(self) -> dict[str, str]: - labels = get_client_labels() - header_value = " ".join(labels) - tracking_headers = { - "x-goog-api-client": header_value, - "user-agent": header_value, - } - return tracking_headers - - @cached_property - def _live_api_version(self) -> str: - if self._api_backend == GoogleLLMVariant.VERTEX_AI: - # use beta version for vertex api - return "v1beta1" - else: - # use v1alpha for using API KEY from Google AI Studio - return "v1alpha" - - @cached_property - def _live_api_client(self) -> Client: - from google.genai import Client - - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers(), api_version=self._live_api_version - ) + async for llm_response in generate_content_via_interactions( + api_client=self.api_client, + llm_request=llm_request, + stream=stream, + ): + yield llm_response + + @cached_property + def api_client(self) -> Client: + """Provides the api client. + + Returns: + The api client. + """ + from google.genai import Client + + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers(), + retry_options=self.retry_options, ) + ) - @contextlib.asynccontextmanager - async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: - """Connects to the Gemini model and returns an llm connection. - - Args: - llm_request: LlmRequest, the request to send to the Gemini model. - - Yields: - BaseLlmConnection, the connection to the Gemini model. - """ - # add tracking headers to custom headers and set api_version given - # the customized http options will override the one set in the api client - # constructor - if ( - llm_request.live_connect_config - and llm_request.live_connect_config.http_options - ): - if not llm_request.live_connect_config.http_options.headers: - llm_request.live_connect_config.http_options.headers = {} - llm_request.live_connect_config.http_options.headers.update( - self._tracking_headers() - ) - llm_request.live_connect_config.http_options.api_version = ( - self._live_api_version - ) + @cached_property + def _api_backend(self) -> GoogleLLMVariant: + return ( + GoogleLLMVariant.VERTEX_AI + if self.api_client.vertexai + else GoogleLLMVariant.GEMINI_API + ) - if self.speech_config is not None: - llm_request.live_connect_config.speech_config = self.speech_config + def _tracking_headers(self) -> dict[str, str]: + labels = get_client_labels() + header_value = " ".join(labels) + tracking_headers = { + "x-goog-api-client": header_value, + "user-agent": header_value, + } + return tracking_headers + + @cached_property + def _live_api_version(self) -> str: + if self._api_backend == GoogleLLMVariant.VERTEX_AI: + # use beta version for vertex api + return "v1beta1" + else: + # use v1alpha for using API KEY from Google AI Studio + return "v1alpha" + + @cached_property + def _live_api_client(self) -> Client: + from google.genai import Client - llm_request.live_connect_config.system_instruction = types.Content( - role="system", - parts=[types.Part.from_text(text=llm_request.config.system_instruction)], + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers(), api_version=self._live_api_version ) - if ( - llm_request.live_connect_config.session_resumption - and llm_request.live_connect_config.session_resumption.transparent - ): - logger.debug( - "session resumption config: %s", - llm_request.live_connect_config.session_resumption, - ) - logger.debug( - "self._api_backend: %s", - self._api_backend, - ) - if self._api_backend == GoogleLLMVariant.GEMINI_API: - raise ValueError( - "Transparent session resumption is only supported for Vertex AI" - " backend. Please use Vertex AI backend." - ) - llm_request.live_connect_config.tools = llm_request.config.tools - logger.info("Connecting to live for model: %s", llm_request.model) - logger.debug("Connecting to live with llm_request:%s", llm_request) - logger.debug("Live connect config: %s", llm_request.live_connect_config) - async with self._live_api_client.aio.live.connect( - model=llm_request.model, config=llm_request.live_connect_config - ) as live_session: - yield GeminiLlmConnection(live_session, api_backend=self._api_backend) - - async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: - """Adapt the google computer use predefined functions to the adk computer use toolset.""" - - from ..tools.computer_use.computer_use_toolset import ComputerUseToolset - - async def convert_wait_to_wait_5_seconds(wait_func): - async def wait_5_seconds(): - return await wait_func(5) - - return wait_5_seconds - - await ComputerUseToolset.adapt_computer_use_tool( - "wait", convert_wait_to_wait_5_seconds, llm_request + ) + + @contextlib.asynccontextmanager + async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: + """Connects to the Gemini model and returns an llm connection. + + Args: + llm_request: LlmRequest, the request to send to the Gemini model. + + Yields: + BaseLlmConnection, the connection to the Gemini model. + """ + # add tracking headers to custom headers and set api_version given + # the customized http options will override the one set in the api client + # constructor + if ( + llm_request.live_connect_config + and llm_request.live_connect_config.http_options + ): + if not llm_request.live_connect_config.http_options.headers: + llm_request.live_connect_config.http_options.headers = {} + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers() + ) + llm_request.live_connect_config.http_options.api_version = ( + self._live_api_version + ) + + if self.speech_config is not None: + llm_request.live_connect_config.speech_config = self.speech_config + + llm_request.live_connect_config.system_instruction = types.Content( + role="system", + parts=[ + types.Part.from_text(text=llm_request.config.system_instruction) + ], + ) + if ( + llm_request.live_connect_config.session_resumption + and llm_request.live_connect_config.session_resumption.transparent + ): + logger.debug( + "session resumption config: %s", + llm_request.live_connect_config.session_resumption, + ) + logger.debug( + "self._api_backend: %s", + self._api_backend, + ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: + raise ValueError( + "Transparent session resumption is only supported for Vertex AI" + " backend. Please use Vertex AI backend." ) + llm_request.live_connect_config.tools = llm_request.config.tools + logger.info("Connecting to live for model: %s", llm_request.model) + logger.debug("Connecting to live with llm_request:%s", llm_request) + logger.debug("Live connect config: %s", llm_request.live_connect_config) + async with self._live_api_client.aio.live.connect( + model=llm_request.model, config=llm_request.live_connect_config + ) as live_session: + yield GeminiLlmConnection(live_session, api_backend=self._api_backend) + + async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: + """Adapt the google computer use predefined functions to the adk computer use toolset.""" - async def _preprocess_request(self, llm_request: LlmRequest) -> None: - - if self._api_backend == GoogleLLMVariant.GEMINI_API: - # Using API key from Google AI Studio to call model doesn't support labels. - if llm_request.config: - llm_request.config.labels = None - - if llm_request.contents: - for content in llm_request.contents: - if not content.parts: - continue - for part in content.parts: - # Create copies to avoid mutating the original objects - if part.inline_data: - part.inline_data = copy.copy(part.inline_data) - _remove_display_name_if_present(part.inline_data) - if part.file_data: - part.file_data = copy.copy(part.file_data) - _remove_display_name_if_present(part.file_data) - - # Initialize config if needed - if llm_request.config and llm_request.config.tools: - # Check if computer use is configured - for tool in llm_request.config.tools: - if isinstance(tool, types.Tool) and tool.computer_use: - llm_request.config.system_instruction = None - await self._adapt_computer_use_tool(llm_request) - - def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: - """Merge tracking headers to the given headers.""" - headers = headers or {} - for key, tracking_header_value in self._tracking_headers().items(): - custom_value = headers.get(key, None) - if not custom_value: - headers[key] = tracking_header_value - continue - - # Merge tracking headers with existing headers and avoid duplicates. - value_parts = tracking_header_value.split(" ") - for custom_value_part in custom_value.split(" "): - if custom_value_part not in value_parts: - value_parts.append(custom_value_part) - headers[key] = " ".join(value_parts) - return headers + from ..tools.computer_use.computer_use_toolset import ComputerUseToolset + + async def convert_wait_to_wait_5_seconds(wait_func): + async def wait_5_seconds(): + return await wait_func(5) + + return wait_5_seconds + + await ComputerUseToolset.adapt_computer_use_tool( + "wait", convert_wait_to_wait_5_seconds, llm_request + ) + + async def _preprocess_request(self, llm_request: LlmRequest) -> None: + + if self._api_backend == GoogleLLMVariant.GEMINI_API: + # Using API key from Google AI Studio to call model doesn't support labels. + if llm_request.config: + llm_request.config.labels = None + + if llm_request.contents: + for content in llm_request.contents: + if not content.parts: + continue + for part in content.parts: + # Create copies to avoid mutating the original objects + if part.inline_data: + part.inline_data = copy.copy(part.inline_data) + _remove_display_name_if_present(part.inline_data) + if part.file_data: + part.file_data = copy.copy(part.file_data) + _remove_display_name_if_present(part.file_data) + + # Initialize config if needed + if llm_request.config and llm_request.config.tools: + # Check if computer use is configured + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool) and tool.computer_use: + llm_request.config.system_instruction = None + await self._adapt_computer_use_tool(llm_request) + + def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: + """Merge tracking headers to the given headers.""" + headers = headers or {} + for key, tracking_header_value in self._tracking_headers().items(): + custom_value = headers.get(key, None) + if not custom_value: + headers[key] = tracking_header_value + continue + + # Merge tracking headers with existing headers and avoid duplicates. + value_parts = tracking_header_value.split(" ") + for custom_value_part in custom_value.split(" "): + if custom_value_part not in value_parts: + value_parts.append(custom_value_part) + headers[key] = " ".join(value_parts) + return headers def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: - param_str = "{}" - if func_decl.parameters and func_decl.parameters.properties: - param_str = str( - { - k: v.model_dump(exclude_none=True) - for k, v in func_decl.parameters.properties.items() - } - ) - elif func_decl.parameters_json_schema: - param_str = str(func_decl.parameters_json_schema) + param_str = "{}" + if func_decl.parameters and func_decl.parameters.properties: + param_str = str({ + k: v.model_dump(exclude_none=True) + for k, v in func_decl.parameters.properties.items() + }) + elif func_decl.parameters_json_schema: + param_str = str(func_decl.parameters_json_schema) - return_str = "" - if func_decl.response: - return_str = "-> " + str(func_decl.response.model_dump(exclude_none=True)) - elif func_decl.response_json_schema: - return_str = "-> " + str(func_decl.response_json_schema) + return_str = "" + if func_decl.response: + return_str = "-> " + str(func_decl.response.model_dump(exclude_none=True)) + elif func_decl.response_json_schema: + return_str = "-> " + str(func_decl.response_json_schema) - return f"{func_decl.name}: {param_str} {return_str}" + return f"{func_decl.name}: {param_str} {return_str}" def _build_request_log(req: LlmRequest) -> str: - # Find which tool contains function_declarations - function_decls: list[types.FunctionDeclaration] = [] - function_decl_tool_index: Optional[int] = None - - if req.config.tools: - for idx, tool in enumerate(req.config.tools): - if tool.function_declarations: - function_decls = cast( - list[types.FunctionDeclaration], tool.function_declarations - ) - function_decl_tool_index = idx - break - - function_logs = ( - [_build_function_declaration_log(func_decl) for func_decl in function_decls] - if function_decls - else [] - ) - contents_logs = [ - content.model_dump_json( + # Find which tool contains function_declarations + function_decls: list[types.FunctionDeclaration] = [] + function_decl_tool_index: Optional[int] = None + + if req.config.tools: + for idx, tool in enumerate(req.config.tools): + if tool.function_declarations: + function_decls = cast( + list[types.FunctionDeclaration], tool.function_declarations + ) + function_decl_tool_index = idx + break + + function_logs = ( + [ + _build_function_declaration_log(func_decl) + for func_decl in function_decls + ] + if function_decls + else [] + ) + contents_logs = [ + content.model_dump_json( + exclude_none=True, + exclude={ + "parts": { + i: _EXCLUDED_PART_FIELD for i in range(len(content.parts)) + } + }, + ) + for content in req.contents + ] + + # Build exclusion dict for config logging + tools_exclusion = ( + {function_decl_tool_index: {"function_declarations"}} + if function_decl_tool_index is not None + else True + ) + + try: + config_log = str( + req.config.model_dump( exclude_none=True, exclude={ - "parts": {i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))} + "system_instruction": True, + "tools": tools_exclusion if req.config.tools else True, }, ) - for content in req.contents - ] - - # Build exclusion dict for config logging - tools_exclusion = ( - {function_decl_tool_index: {"function_declarations"}} - if function_decl_tool_index is not None - else True ) + except Exception: + config_log = repr(req.config) - try: - config_log = str( - req.config.model_dump( - exclude_none=True, - exclude={ - "system_instruction": True, - "tools": tools_exclusion if req.config.tools else True, - }, - ) - ) - except Exception: - config_log = repr(req.config) - - return f""" + return f""" LLM Request: ----------------------------------------------------------- System Instruction: @@ -569,13 +574,13 @@ def _build_request_log(req: LlmRequest) -> str: def _build_response_log(resp: types.GenerateContentResponse) -> str: - function_calls_text = [] - if function_calls := resp.function_calls: - for func_call in function_calls: - function_calls_text.append( - f"name: {func_call.name}, args: {func_call.args}" - ) - return f""" + function_calls_text = [] + if function_calls := resp.function_calls: + for func_call in function_calls: + function_calls_text.append( + f"name: {func_call.name}, args: {func_call.args}" + ) + return f""" LLM Response: ----------------------------------------------------------- Text: @@ -593,10 +598,10 @@ def _build_response_log(resp: types.GenerateContentResponse) -> str: def _remove_display_name_if_present( data_obj: Union[types.Blob, types.FileData, None], ): - """Sets display_name to None for the Gemini API (non-Vertex) backend. + """Sets display_name to None for the Gemini API (non-Vertex) backend. - This backend does not support the display_name parameter for file uploads, - so it must be removed to prevent request failures. - """ - if data_obj and data_obj.display_name: - data_obj.display_name = None + This backend does not support the display_name parameter for file uploads, + so it must be removed to prevent request failures. + """ + if data_obj and data_obj.display_name: + data_obj.display_name = None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 2ad9594ebd..3dd455cfc9 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -69,1475 +69,1480 @@ def _is_tool_call_or_response(event: Event) -> bool: - return bool(event.get_function_calls() or event.get_function_responses()) + return bool(event.get_function_calls() or event.get_function_responses()) def _is_transcription(event: Event) -> bool: - return ( - event.input_transcription is not None or event.output_transcription is not None - ) + return ( + event.input_transcription is not None + or event.output_transcription is not None + ) def _has_non_empty_transcription_text(transcription) -> bool: - return bool(transcription and transcription.text and transcription.text.strip()) + return bool( + transcription and transcription.text and transcription.text.strip() + ) class Runner: - """The Runner class is used to run agents. - - It manages the execution of an agent within a session, handling message - processing, event generation, and interaction with various services like - artifact storage, session management, and memory. - - Attributes: - app_name: The application name of the runner. - agent: The root agent to run. + """The Runner class is used to run agents. + + It manages the execution of an agent within a session, handling message + processing, event generation, and interaction with various services like + artifact storage, session management, and memory. + + Attributes: + app_name: The application name of the runner. + agent: The root agent to run. + artifact_service: The artifact service for the runner. + plugin_manager: The plugin manager for the runner. + session_service: The session service for the runner. + memory_service: The memory service for the runner. + credential_service: The credential service for the runner. + context_cache_config: The context cache config for the runner. + resumability_config: The resumability config for the application. + """ + + app_name: str + """The app name of the runner.""" + agent: BaseAgent + """The root agent to run.""" + artifact_service: Optional[BaseArtifactService] = None + """The artifact service for the runner.""" + plugin_manager: PluginManager + """The plugin manager for the runner.""" + session_service: BaseSessionService + """The session service for the runner.""" + memory_service: Optional[BaseMemoryService] = None + """The memory service for the runner.""" + credential_service: Optional[BaseCredentialService] = None + """The credential service for the runner.""" + context_cache_config: Optional[ContextCacheConfig] = None + """The context cache config for the runner.""" + resumability_config: Optional[ResumabilityConfig] = None + """The resumability config for the application.""" + + def __init__( + self, + *, + app: Optional[App] = None, + app_name: Optional[str] = None, + agent: Optional[BaseAgent] = None, + plugins: Optional[List[BasePlugin]] = None, + artifact_service: Optional[BaseArtifactService] = None, + session_service: BaseSessionService, + memory_service: Optional[BaseMemoryService] = None, + credential_service: Optional[BaseCredentialService] = None, + plugin_close_timeout: float = 5.0, + ): + """Initializes the Runner. + + Developers should provide either an `app` instance or both `app_name` and + `agent`. Providing a mix of `app` and `app_name`/`agent` will result in a + `ValueError`. Providing `app` is the recommended way to create a runner. + + Args: + app: An optional `App` instance. If provided, `app_name` and `agent` + should not be specified. + app_name: The application name of the runner. Required if `app` is not + provided. + agent: The root agent to run. Required if `app` is not provided. + plugins: Deprecated. A list of plugins for the runner. Please use the + `app` argument to provide plugins instead. artifact_service: The artifact service for the runner. - plugin_manager: The plugin manager for the runner. session_service: The session service for the runner. memory_service: The memory service for the runner. credential_service: The credential service for the runner. - context_cache_config: The context cache config for the runner. - resumability_config: The resumability config for the application. + plugin_close_timeout: The timeout in seconds for plugin close methods. + + Raises: + ValueError: If `app` is provided along with `app_name` or `plugins`, or + if `app` is not provided but either `app_name` or `agent` is missing. """ + self.app = app + ( + self.app_name, + self.agent, + self.context_cache_config, + self.resumability_config, + plugins, + ) = self._validate_runner_params(app, app_name, agent, plugins) + self.artifact_service = artifact_service + self.session_service = session_service + self.memory_service = memory_service + self.credential_service = credential_service + self.plugin_manager = PluginManager( + plugins=plugins, close_timeout=plugin_close_timeout + ) + ( + self._agent_origin_app_name, + self._agent_origin_dir, + ) = self._infer_agent_origin(self.agent) + self._app_name_alignment_hint: Optional[str] = None + self._enforce_app_name_alignment() + + def _validate_runner_params( + self, + app: Optional[App], + app_name: Optional[str], + agent: Optional[BaseAgent], + plugins: Optional[List[BasePlugin]], + ) -> tuple[ + str, + BaseAgent, + Optional[ContextCacheConfig], + Optional[ResumabilityConfig], + Optional[List[BasePlugin]], + ]: + """Validates and extracts runner parameters. + + Args: + app: An optional `App` instance. + app_name: The application name of the runner. + agent: The root agent to run. + plugins: A list of plugins for the runner. - app_name: str - """The app name of the runner.""" - agent: BaseAgent - """The root agent to run.""" - artifact_service: Optional[BaseArtifactService] = None - """The artifact service for the runner.""" - plugin_manager: PluginManager - """The plugin manager for the runner.""" - session_service: BaseSessionService - """The session service for the runner.""" - memory_service: Optional[BaseMemoryService] = None - """The memory service for the runner.""" - credential_service: Optional[BaseCredentialService] = None - """The credential service for the runner.""" - context_cache_config: Optional[ContextCacheConfig] = None - """The context cache config for the runner.""" - resumability_config: Optional[ResumabilityConfig] = None - """The resumability config for the application.""" - - def __init__( - self, - *, - app: Optional[App] = None, - app_name: Optional[str] = None, - agent: Optional[BaseAgent] = None, - plugins: Optional[List[BasePlugin]] = None, - artifact_service: Optional[BaseArtifactService] = None, - session_service: BaseSessionService, - memory_service: Optional[BaseMemoryService] = None, - credential_service: Optional[BaseCredentialService] = None, - plugin_close_timeout: float = 5.0, - ): - """Initializes the Runner. - - Developers should provide either an `app` instance or both `app_name` and - `agent`. Providing a mix of `app` and `app_name`/`agent` will result in a - `ValueError`. Providing `app` is the recommended way to create a runner. - - Args: - app: An optional `App` instance. If provided, `app_name` and `agent` - should not be specified. - app_name: The application name of the runner. Required if `app` is not - provided. - agent: The root agent to run. Required if `app` is not provided. - plugins: Deprecated. A list of plugins for the runner. Please use the - `app` argument to provide plugins instead. - artifact_service: The artifact service for the runner. - session_service: The session service for the runner. - memory_service: The memory service for the runner. - credential_service: The credential service for the runner. - plugin_close_timeout: The timeout in seconds for plugin close methods. - - Raises: - ValueError: If `app` is provided along with `app_name` or `plugins`, or - if `app` is not provided but either `app_name` or `agent` is missing. - """ - self.app = app - ( - self.app_name, - self.agent, - self.context_cache_config, - self.resumability_config, - plugins, - ) = self._validate_runner_params(app, app_name, agent, plugins) - self.artifact_service = artifact_service - self.session_service = session_service - self.memory_service = memory_service - self.credential_service = credential_service - self.plugin_manager = PluginManager( - plugins=plugins, close_timeout=plugin_close_timeout - ) - ( - self._agent_origin_app_name, - self._agent_origin_dir, - ) = self._infer_agent_origin(self.agent) - self._app_name_alignment_hint: Optional[str] = None - self._enforce_app_name_alignment() - - def _validate_runner_params( - self, - app: Optional[App], - app_name: Optional[str], - agent: Optional[BaseAgent], - plugins: Optional[List[BasePlugin]], - ) -> tuple[ - str, - BaseAgent, - Optional[ContextCacheConfig], - Optional[ResumabilityConfig], - Optional[List[BasePlugin]], - ]: - """Validates and extracts runner parameters. - - Args: - app: An optional `App` instance. - app_name: The application name of the runner. - agent: The root agent to run. - plugins: A list of plugins for the runner. - - Returns: - A tuple containing (app_name, agent, context_cache_config, - resumability_config, plugins). - - Raises: - ValueError: If parameters are invalid. - """ - if plugins is not None: - warnings.warn( - "The `plugins` argument is deprecated. Please use the `app` argument" - " to provide plugins instead.", - DeprecationWarning, - ) + Returns: + A tuple containing (app_name, agent, context_cache_config, + resumability_config, plugins). - if app: - if app_name: - raise ValueError( - "When app is provided, app_name should not be provided." - ) - if agent: - raise ValueError("When app is provided, agent should not be provided.") - if plugins: - raise ValueError( - "When app is provided, plugins should not be provided and should be" - " provided in the app instead." - ) - app_name = app.name - agent = app.root_agent - plugins = app.plugins - context_cache_config = app.context_cache_config - resumability_config = app.resumability_config - elif not app_name or not agent: - raise ValueError("Either app or both app_name and agent must be provided.") - else: - context_cache_config = None - resumability_config = None - - return app_name, agent, context_cache_config, resumability_config, plugins - - def _infer_agent_origin( - self, agent: BaseAgent - ) -> tuple[Optional[str], Optional[Path]]: - """Infer the origin app name and directory from an agent's module location. - - Returns: - A tuple of (origin_app_name, origin_path): - - origin_app_name: The inferred app name (directory name containing the - agent), or None if inference is not possible/applicable. - - origin_path: The directory path where the agent is defined, or None - if the path cannot be determined. - - Both values are None when: - - The agent has no associated module - - The agent is defined in google.adk.* (ADK internal modules) - - The module has no __file__ attribute - """ - # First, check for metadata set by AgentLoader (most reliable source). - # AgentLoader sets these attributes when loading agents. - origin_app_name = getattr(agent, "_adk_origin_app_name", None) - origin_path = getattr(agent, "_adk_origin_path", None) - if origin_app_name is not None and origin_path is not None: - return origin_app_name, origin_path - - # Fall back to heuristic inference for programmatic usage. - module = inspect.getmodule(agent.__class__) - if not module: - return None, None - - # Skip ADK internal modules. When users instantiate LlmAgent directly - # (not subclassed), inspect.getmodule() returns the ADK module. This - # could falsely match 'agents' in 'google/adk/agents/' path. - if module.__name__.startswith("google.adk."): - return None, None - - module_file = getattr(module, "__file__", None) - if not module_file: - return None, None - module_path = Path(module_file).resolve() - project_root = Path.cwd() - try: - relative_path = module_path.relative_to(project_root) - except ValueError: - return None, module_path.parent - origin_dir = module_path.parent - if "agents" not in relative_path.parts: - return None, origin_dir - origin_name = origin_dir.name - if origin_name.startswith("."): - return None, origin_dir - return origin_name, origin_dir - - def _enforce_app_name_alignment(self) -> None: - origin_name = self._agent_origin_app_name - origin_dir = self._agent_origin_dir - if not origin_name or origin_name.startswith("__"): - self._app_name_alignment_hint = None - return - if origin_name == self.app_name: - self._app_name_alignment_hint = None - return - origin_location = str(origin_dir) if origin_dir else origin_name - mismatch_details = ( - "The runner is configured with app name " - f'"{self.app_name}", but the root agent was loaded from ' - f'"{origin_location}", which implies app name "{origin_name}".' - ) - resolution = ( - "Ensure the runner app_name matches that directory or pass app_name " - "explicitly when constructing the runner." + Raises: + ValueError: If parameters are invalid. + """ + if plugins is not None: + warnings.warn( + "The `plugins` argument is deprecated. Please use the `app` argument" + " to provide plugins instead.", + DeprecationWarning, + ) + + if app: + if app_name: + raise ValueError( + "When app is provided, app_name should not be provided." ) - self._app_name_alignment_hint = f"{mismatch_details} {resolution}" - logger.warning("App name mismatch detected. %s", mismatch_details) - - def _format_session_not_found_message(self, session_id: str) -> str: - message = f"Session not found: {session_id}" - if not self._app_name_alignment_hint: - return message - return ( - f"{message}. {self._app_name_alignment_hint} " - "The mismatch prevents the runner from locating the session." + if agent: + raise ValueError("When app is provided, agent should not be provided.") + if plugins: + raise ValueError( + "When app is provided, plugins should not be provided and should be" + " provided in the app instead." ) + app_name = app.name + agent = app.root_agent + plugins = app.plugins + context_cache_config = app.context_cache_config + resumability_config = app.resumability_config + elif not app_name or not agent: + raise ValueError( + "Either app or both app_name and agent must be provided." + ) + else: + context_cache_config = None + resumability_config = None + + return app_name, agent, context_cache_config, resumability_config, plugins + + def _infer_agent_origin( + self, agent: BaseAgent + ) -> tuple[Optional[str], Optional[Path]]: + """Infer the origin app name and directory from an agent's module location. + + Returns: + A tuple of (origin_app_name, origin_path): + - origin_app_name: The inferred app name (directory name containing the + agent), or None if inference is not possible/applicable. + - origin_path: The directory path where the agent is defined, or None + if the path cannot be determined. + + Both values are None when: + - The agent has no associated module + - The agent is defined in google.adk.* (ADK internal modules) + - The module has no __file__ attribute + """ + # First, check for metadata set by AgentLoader (most reliable source). + # AgentLoader sets these attributes when loading agents. + origin_app_name = getattr(agent, "_adk_origin_app_name", None) + origin_path = getattr(agent, "_adk_origin_path", None) + if origin_app_name is not None and origin_path is not None: + return origin_app_name, origin_path + + # Fall back to heuristic inference for programmatic usage. + module = inspect.getmodule(agent.__class__) + if not module: + return None, None + + # Skip ADK internal modules. When users instantiate LlmAgent directly + # (not subclassed), inspect.getmodule() returns the ADK module. This + # could falsely match 'agents' in 'google/adk/agents/' path. + if module.__name__.startswith("google.adk."): + return None, None + + module_file = getattr(module, "__file__", None) + if not module_file: + return None, None + module_path = Path(module_file).resolve() + project_root = Path.cwd() + try: + relative_path = module_path.relative_to(project_root) + except ValueError: + return None, module_path.parent + origin_dir = module_path.parent + if "agents" not in relative_path.parts: + return None, origin_dir + origin_name = origin_dir.name + if origin_name.startswith("."): + return None, origin_dir + return origin_name, origin_dir + + def _enforce_app_name_alignment(self) -> None: + origin_name = self._agent_origin_app_name + origin_dir = self._agent_origin_dir + if not origin_name or origin_name.startswith("__"): + self._app_name_alignment_hint = None + return + if origin_name == self.app_name: + self._app_name_alignment_hint = None + return + origin_location = str(origin_dir) if origin_dir else origin_name + mismatch_details = ( + "The runner is configured with app name " + f'"{self.app_name}", but the root agent was loaded from ' + f'"{origin_location}", which implies app name "{origin_name}".' + ) + resolution = ( + "Ensure the runner app_name matches that directory or pass app_name " + "explicitly when constructing the runner." + ) + self._app_name_alignment_hint = f"{mismatch_details} {resolution}" + logger.warning("App name mismatch detected. %s", mismatch_details) - def run( - self, - *, - user_id: str, - session_id: str, - new_message: types.Content, - run_config: Optional[RunConfig] = None, - ) -> Generator[Event, None, None]: - """Runs the agent. - - NOTE: - This sync interface is only for local testing and convenience purpose. - Consider using `run_async` for production usage. - - If event compaction is enabled in the App configuration, it will be - performed after all agent events for the current invocation have been - yielded. The generator will only finish iterating after event - compaction is complete. - - Args: - user_id: The user ID of the session. - session_id: The session ID of the session. - new_message: A new message to append to the session. - run_config: The run config for the agent. - - Yields: - The events generated by the agent. - """ - run_config = run_config or RunConfig() - event_queue = queue.Queue() - - async def _invoke_run_async(): - try: - async with Aclosing( - self.run_async( - user_id=user_id, - session_id=session_id, - new_message=new_message, - run_config=run_config, - ) - ) as agen: - async for event in agen: - event_queue.put(event) - finally: - event_queue.put(None) - - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_run_async()) - finally: - event_queue.put(None) - - thread = create_thread(target=_asyncio_thread_main) - thread.start() - - # consumes and re-yield the events from background thread. - while True: - event = event_queue.get() - if event is None: - break - else: - yield event - - thread.join() - - async def run_async( - self, - *, - user_id: str, - session_id: str, - invocation_id: Optional[str] = None, - new_message: Optional[types.Content] = None, - state_delta: Optional[dict[str, Any]] = None, - run_config: Optional[RunConfig] = None, - ) -> AsyncGenerator[Event, None]: - """Main entry method to run the agent in this runner. - - If event compaction is enabled in the App configuration, it will be - performed after all agent events for the current invocation have been - yielded. The async generator will only finish iterating after event - compaction is complete. However, this does not block new `run_async` - calls for subsequent user queries, which can be started concurrently. - - Args: - user_id: The user ID of the session. - session_id: The session ID of the session. - invocation_id: The invocation ID of the session, set this to resume an - interrupted invocation. - new_message: A new message to append to the session. - state_delta: Optional state changes to apply to the session. - run_config: The run config for the agent. - - Yields: - The events generated by the agent. - - Raises: - ValueError: If the session is not found; If both invocation_id and - new_message are None. - """ - run_config = run_config or RunConfig() - - if new_message and not new_message.role: - new_message.role = "user" - - async def _run_body( - new_message: Optional[types.Content] = None, - invocation_id: Optional[str] = None, - ) -> AsyncGenerator[Event, None]: - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - message = self._format_session_not_found_message(session_id) - raise ValueError(message) - if not invocation_id and not new_message: - raise ValueError( - "Running an agent requires either a new_message or an " - "invocation_id to resume a previous invocation. " - f"Session: {session_id}, User: {user_id}" - ) - - if invocation_id: - if ( - not self.resumability_config - or not self.resumability_config.is_resumable - ): - raise ValueError( - f"invocation_id: {invocation_id} is provided but the app is not" - " resumable." - ) - invocation_context = await self._setup_context_for_resumed_invocation( - session=session, - new_message=new_message, - invocation_id=invocation_id, - run_config=run_config, - state_delta=state_delta, - ) - if invocation_context.end_of_agents.get(invocation_context.agent.name): - # Directly return if the current agent in invocation context is - # already final. - return - else: - invocation_context = await self._setup_context_for_new_invocation( - session=session, - new_message=new_message, # new_message is not None. - run_config=run_config, - state_delta=state_delta, - ) - - async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async with Aclosing(ctx.agent.run_async(ctx)) as agen: - async for event in agen: - yield event - - async with Aclosing( - self._exec_with_plugin( - invocation_context=invocation_context, - session=session, - execute_fn=execute, - is_live_call=False, - ) - ) as agen: - async for event in agen: - yield event - # Run compaction after all events are yielded from the agent. - # (We don't compact in the middle of an invocation, we only compact at - # the end of an invocation.) - if self.app and self.app.events_compaction_config: - logger.debug("Running event compactor.") - await _run_compaction_for_sliding_window( - self.app, session, self.session_service - ) + def _format_session_not_found_message(self, session_id: str) -> str: + message = f"Session not found: {session_id}" + if not self._app_name_alignment_hint: + return message + return ( + f"{message}. {self._app_name_alignment_hint} " + "The mismatch prevents the runner from locating the session." + ) - async def _run_with_optional_trace( - agent: BaseAgent, - new_message: Optional[types.Content] = None, - invocation_id: Optional[str] = None, - ) -> AsyncGenerator[Event, None]: - if is_telemetry_enabled(agent): - with tracer.start_as_current_span("invocation"): - async with Aclosing( - _run_body(new_message=new_message, invocation_id=invocation_id) - ) as agen: - async for e in agen: - yield e - else: - async with Aclosing( - _run_body(new_message=new_message, invocation_id=invocation_id) - ) as agen: - async for e in agen: - yield e + def run( + self, + *, + user_id: str, + session_id: str, + new_message: types.Content, + run_config: Optional[RunConfig] = None, + ) -> Generator[Event, None, None]: + """Runs the agent. + + NOTE: + This sync interface is only for local testing and convenience purpose. + Consider using `run_async` for production usage. + + If event compaction is enabled in the App configuration, it will be + performed after all agent events for the current invocation have been + yielded. The generator will only finish iterating after event + compaction is complete. + + Args: + user_id: The user ID of the session. + session_id: The session ID of the session. + new_message: A new message to append to the session. + run_config: The run config for the agent. + + Yields: + The events generated by the agent. + """ + run_config = run_config or RunConfig() + event_queue = queue.Queue() + async def _invoke_run_async(): + try: async with Aclosing( - _run_with_optional_trace(self.agent, new_message, invocation_id) + self.run_async( + user_id=user_id, + session_id=session_id, + new_message=new_message, + run_config=run_config, + ) ) as agen: - async for event in agen: - yield event - - async def rewind_async( - self, - *, - user_id: str, - session_id: str, - rewind_before_invocation_id: str, - ) -> None: - """Rewinds the session to before the specified invocation.""" - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise ValueError(f"Session not found: {session_id}") - - rewind_event_index = -1 - for i, event in enumerate(session.events): - if event.invocation_id == rewind_before_invocation_id: - rewind_event_index = i - break + async for event in agen: + event_queue.put(event) + finally: + event_queue.put(None) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_run_async()) + finally: + event_queue.put(None) + + thread = create_thread(target=_asyncio_thread_main) + thread.start() + + # consumes and re-yield the events from background thread. + while True: + event = event_queue.get() + if event is None: + break + else: + yield event + + thread.join() + + async def run_async( + self, + *, + user_id: str, + session_id: str, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, + run_config: Optional[RunConfig] = None, + ) -> AsyncGenerator[Event, None]: + """Main entry method to run the agent in this runner. + + If event compaction is enabled in the App configuration, it will be + performed after all agent events for the current invocation have been + yielded. The async generator will only finish iterating after event + compaction is complete. However, this does not block new `run_async` + calls for subsequent user queries, which can be started concurrently. + + Args: + user_id: The user ID of the session. + session_id: The session ID of the session. + invocation_id: The invocation ID of the session, set this to resume an + interrupted invocation. + new_message: A new message to append to the session. + state_delta: Optional state changes to apply to the session. + run_config: The run config for the agent. + + Yields: + The events generated by the agent. + + Raises: + ValueError: If the session is not found; If both invocation_id and + new_message are None. + """ + run_config = run_config or RunConfig() - if rewind_event_index == -1: - raise ValueError(f"Invocation ID not found: {rewind_before_invocation_id}") + if new_message and not new_message.role: + new_message.role = "user" - # Compute state delta to reverse changes - state_delta = await self._compute_state_delta_for_rewind( - session, rewind_event_index + async def _run_body( + new_message: Optional[types.Content] = None, + invocation_id: Optional[str] = None, + ) -> AsyncGenerator[Event, None]: + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + message = self._format_session_not_found_message(session_id) + raise ValueError(message) + if not invocation_id and not new_message: + raise ValueError( + "Running an agent requires either a new_message or an " + "invocation_id to resume a previous invocation. " + f"Session: {session_id}, User: {user_id}" ) - # Compute artifact delta to reverse changes - artifact_delta = await self._compute_artifact_delta_for_rewind( - session, rewind_event_index + if invocation_id: + if ( + not self.resumability_config + or not self.resumability_config.is_resumable + ): + raise ValueError( + f"invocation_id: {invocation_id} is provided but the app is not" + " resumable." + ) + invocation_context = await self._setup_context_for_resumed_invocation( + session=session, + new_message=new_message, + invocation_id=invocation_id, + run_config=run_config, + state_delta=state_delta, ) - - # Create rewind event - rewind_event = Event( - invocation_id=new_invocation_context_id(), - author="user", - actions=EventActions( - rewind_before_invocation_id=rewind_before_invocation_id, - state_delta=state_delta, - artifact_delta=artifact_delta, - ), + if invocation_context.end_of_agents.get(invocation_context.agent.name): + # Directly return if the current agent in invocation context is + # already final. + return + else: + invocation_context = await self._setup_context_for_new_invocation( + session=session, + new_message=new_message, # new_message is not None. + run_config=run_config, + state_delta=state_delta, ) - logger.info("Rewinding session to invocation: %s", rewind_event) - - await self.session_service.append_event(session=session, event=rewind_event) - - async def _compute_state_delta_for_rewind( - self, session: Session, rewind_event_index: int - ) -> dict[str, Any]: - """Computes the state delta to reverse changes.""" - state_at_rewind_point: dict[str, Any] = {} - for i in range(rewind_event_index): - if session.events[i].actions.state_delta: - for k, v in session.events[i].actions.state_delta.items(): - if k.startswith("app:") or k.startswith("user:"): - continue - if v is None: - state_at_rewind_point.pop(k, None) - else: - state_at_rewind_point[k] = v - - current_state = session.state - rewind_state_delta = {} - - # 1. Add/update keys in rewind_state_delta to match state_at_rewind_point. - for key, value_at_rewind in state_at_rewind_point.items(): - if key not in current_state or current_state[key] != value_at_rewind: - rewind_state_delta[key] = value_at_rewind - - # 2. Set keys to None in rewind_state_delta if they are in current_state - # but not in state_at_rewind_point. These keys were added after the - # rewind point and need to be removed. - for key in current_state: - if key.startswith("app:") or key.startswith("user:"): - continue - if key not in state_at_rewind_point: - rewind_state_delta[key] = None - - return rewind_state_delta - - async def _compute_artifact_delta_for_rewind( - self, session: Session, rewind_event_index: int - ) -> dict[str, int]: - """Computes the artifact delta to reverse changes.""" - if not self.artifact_service: - return {} - - versions_at_rewind_point: dict[str, int] = {} - for i in range(rewind_event_index): - event = session.events[i] - if event.actions.artifact_delta: - versions_at_rewind_point.update(event.actions.artifact_delta) - - current_versions: dict[str, int] = {} - for event in session.events: - if event.actions.artifact_delta: - current_versions.update(event.actions.artifact_delta) - - rewind_artifact_delta = {} - for filename, vn in current_versions.items(): - if filename.startswith("user:"): - # User artifacts are not restored on rewind. - continue - vt = versions_at_rewind_point.get(filename) - if vt == vn: - continue - - rewind_artifact_delta[filename] = vn + 1 - if vt is None: - # Artifact did not exist at rewind point. Mark it as inaccessible. - artifact = types.Part( - inline_data=types.Blob( - mime_type="application/octet-stream", data=b"" - ) - ) - else: - # Artifact version changed after rewind point. Restore to version at - # rewind point. - artifact_uri = artifact_util.get_artifact_uri( - app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, - filename=filename, - version=vt, - ) - artifact = types.Part(file_data=types.FileData(file_uri=artifact_uri)) - await self.artifact_service.save_artifact( - app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, - filename=filename, - artifact=artifact, - ) + async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + async with Aclosing(ctx.agent.run_async(ctx)) as agen: + async for event in agen: + yield event + + async with Aclosing( + self._exec_with_plugin( + invocation_context=invocation_context, + session=session, + execute_fn=execute, + is_live_call=False, + ) + ) as agen: + async for event in agen: + yield event + # Run compaction after all events are yielded from the agent. + # (We don't compact in the middle of an invocation, we only compact at + # the end of an invocation.) + if self.app and self.app.events_compaction_config: + logger.debug("Running event compactor.") + await _run_compaction_for_sliding_window( + self.app, session, self.session_service + ) - return rewind_artifact_delta - - def _should_append_event(self, event: Event, is_live_call: bool) -> bool: - """Checks if an event should be appended to the session.""" - # Don't append audio response from model in live mode to session. - # The data is appended to artifacts with a reference in file_data in the - # event. - # We should append non-partial events only.For example, non-finished(partial) - # transcription events should not be appended. - # Function call and function response events should be appended. - # Other control events should be appended. - if is_live_call and contents._is_live_model_audio_event_with_inline_data(event): - # We don't append live model audio events with inline data to avoid - # storing large blobs in the session. However, events with file_data - # (references to artifacts) should be appended. - return False - return True - - async def _exec_with_plugin( - self, - invocation_context: InvocationContext, - session: Session, - execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]], - is_live_call: bool = False, + async def _run_with_optional_trace( + agent: BaseAgent, + new_message: Optional[types.Content] = None, + invocation_id: Optional[str] = None, ) -> AsyncGenerator[Event, None]: - """Wraps execution with plugin callbacks. - - Args: - invocation_context: The invocation context - session: The current session - execute_fn: A callable that returns an AsyncGenerator of Events - is_live_call: Whether this is a live call + if is_telemetry_enabled(agent): + with tracer.start_as_current_span("invocation"): + async with Aclosing( + _run_body(new_message=new_message, invocation_id=invocation_id) + ) as agen: + async for e in agen: + yield e + else: + async with Aclosing( + _run_body(new_message=new_message, invocation_id=invocation_id) + ) as agen: + async for e in agen: + yield e + + async with Aclosing( + _run_with_optional_trace(self.agent, new_message, invocation_id) + ) as agen: + async for event in agen: + yield event + + async def rewind_async( + self, + *, + user_id: str, + session_id: str, + rewind_before_invocation_id: str, + ) -> None: + """Rewinds the session to before the specified invocation.""" + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise ValueError(f"Session not found: {session_id}") + + rewind_event_index = -1 + for i, event in enumerate(session.events): + if event.invocation_id == rewind_before_invocation_id: + rewind_event_index = i + break + + if rewind_event_index == -1: + raise ValueError( + f"Invocation ID not found: {rewind_before_invocation_id}" + ) + + # Compute state delta to reverse changes + state_delta = await self._compute_state_delta_for_rewind( + session, rewind_event_index + ) - Yields: - Events from the execution, including any generated by plugins - """ + # Compute artifact delta to reverse changes + artifact_delta = await self._compute_artifact_delta_for_rewind( + session, rewind_event_index + ) - plugin_manager = invocation_context.plugin_manager + # Create rewind event + rewind_event = Event( + invocation_id=new_invocation_context_id(), + author="user", + actions=EventActions( + rewind_before_invocation_id=rewind_before_invocation_id, + state_delta=state_delta, + artifact_delta=artifact_delta, + ), + ) - # Step 1: Run the before_run callbacks to see if we should early exit. - early_exit_result = await plugin_manager.run_before_run_callback( - invocation_context=invocation_context - ) - if isinstance(early_exit_result, types.Content): - early_exit_event = Event( - invocation_id=invocation_context.invocation_id, - author="model", - content=early_exit_result, + logger.info("Rewinding session to invocation: %s", rewind_event) + + await self.session_service.append_event(session=session, event=rewind_event) + + async def _compute_state_delta_for_rewind( + self, session: Session, rewind_event_index: int + ) -> dict[str, Any]: + """Computes the state delta to reverse changes.""" + state_at_rewind_point: dict[str, Any] = {} + for i in range(rewind_event_index): + if session.events[i].actions.state_delta: + for k, v in session.events[i].actions.state_delta.items(): + if k.startswith("app:") or k.startswith("user:"): + continue + if v is None: + state_at_rewind_point.pop(k, None) + else: + state_at_rewind_point[k] = v + + current_state = session.state + rewind_state_delta = {} + + # 1. Add/update keys in rewind_state_delta to match state_at_rewind_point. + for key, value_at_rewind in state_at_rewind_point.items(): + if key not in current_state or current_state[key] != value_at_rewind: + rewind_state_delta[key] = value_at_rewind + + # 2. Set keys to None in rewind_state_delta if they are in current_state + # but not in state_at_rewind_point. These keys were added after the + # rewind point and need to be removed. + for key in current_state: + if key.startswith("app:") or key.startswith("user:"): + continue + if key not in state_at_rewind_point: + rewind_state_delta[key] = None + + return rewind_state_delta + + async def _compute_artifact_delta_for_rewind( + self, session: Session, rewind_event_index: int + ) -> dict[str, int]: + """Computes the artifact delta to reverse changes.""" + if not self.artifact_service: + return {} + + versions_at_rewind_point: dict[str, int] = {} + for i in range(rewind_event_index): + event = session.events[i] + if event.actions.artifact_delta: + versions_at_rewind_point.update(event.actions.artifact_delta) + + current_versions: dict[str, int] = {} + for event in session.events: + if event.actions.artifact_delta: + current_versions.update(event.actions.artifact_delta) + + rewind_artifact_delta = {} + for filename, vn in current_versions.items(): + if filename.startswith("user:"): + # User artifacts are not restored on rewind. + continue + vt = versions_at_rewind_point.get(filename) + if vt == vn: + continue + + rewind_artifact_delta[filename] = vn + 1 + if vt is None: + # Artifact did not exist at rewind point. Mark it as inaccessible. + artifact = types.Part( + inline_data=types.Blob( + mime_type="application/octet-stream", data=b"" ) - if self._should_append_event(early_exit_event, is_live_call): - await self.session_service.append_event( - session=session, - event=early_exit_event, - ) - yield early_exit_event - else: - # Step 2: Otherwise continue with normal execution - # Note for live/bidi: - # the transcription may arrive later then the action(function call - # event and thus function response event). In this case, the order of - # transcription and function call event will be wrong if we just - # append as it arrives. To address this, we should check if there is - # transcription going on. If there is transcription going on, we - # should hold on appending the function call event until the - # transcription is finished. The transcription in progress can be - # identified by checking if the transcription event is partial. When - # the next transcription event is not partial, it means the previous - # transcription is finished. Then if there is any buffered function - # call event, we should append them after this finished(non-parital) - # transcription event. - buffered_events: list[Event] = [] - is_transcribing: bool = False - - async with Aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-paritla event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text( - event.input_transcription - ) - or _has_non_empty_transcription_text( - event.output_transcription - ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - "Appending transcription finished event: %s", event - ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=session, event=event - ) - - for buffered_event in buffered_events: - logger.debug( - "Appending buffered event: %s", buffered_event - ) - await self.session_service.append_event( - session=session, event=buffered_event - ) - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug( - "Appending non-buffered event: %s", event - ) - await self.session_service.append_event( - session=session, event=event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=session, event=event - ) - - # Step 3: Run the on_event callbacks to optionally modify the event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - yield (modified_event if modified_event else event) - - # Step 4: Run the after_run callbacks to perform global cleanup tasks or - # finalizing logs and metrics data. - # This does NOT emit any event. - await plugin_manager.run_after_run_callback( - invocation_context=invocation_context ) - - async def _append_new_message_to_session( - self, - *, - session: Session, - new_message: types.Content, - invocation_context: InvocationContext, - save_input_blobs_as_artifacts: bool = False, - state_delta: Optional[dict[str, Any]] = None, - ): - """Appends a new message to the session. - - Args: - session: The session to append the message to. - new_message: The new message to append. - invocation_context: The invocation context for the message. - save_input_blobs_as_artifacts: Whether to save input blobs as artifacts. - state_delta: Optional state changes to apply to the session. - """ - if not new_message.parts: - raise ValueError("No parts in the new_message.") - - if self.artifact_service and save_input_blobs_as_artifacts: - # Issue deprecation warning - warnings.warn( - "The 'save_input_blobs_as_artifacts' parameter is deprecated. Use" - " SaveFilesAsArtifactsPlugin instead for better control and" - " flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for" - " migration guidance.", - DeprecationWarning, - stacklevel=3, - ) - # The runner directly saves the artifacts (if applicable) in the - # user message and replaces the artifact data with a file name - # placeholder. - for i, part in enumerate(new_message.parts): - if part.inline_data is None: - continue - file_name = f"artifact_{invocation_context.invocation_id}_{i}" - await self.artifact_service.save_artifact( - app_name=self.app_name, - user_id=session.user_id, - session_id=session.id, - filename=file_name, - artifact=part, - ) - new_message.parts[i] = types.Part( - text=f"Uploaded file: {file_name}. It is saved into artifacts" - ) - # Appends only. We do not yield the event because it's not from the model. - if state_delta: - event = Event( - invocation_id=invocation_context.invocation_id, - author="user", - actions=EventActions(state_delta=state_delta), - content=new_message, - ) - else: - event = Event( - invocation_id=invocation_context.invocation_id, - author="user", - content=new_message, - ) - # If new_message is a function response, find the matching function call - # and use its branch as the new event's branch. - if function_call := invocation_context._find_matching_function_call(event): - event.branch = function_call.branch - - await self.session_service.append_event(session=session, event=event) - - async def run_live( - self, - *, - user_id: Optional[str] = None, - session_id: Optional[str] = None, - live_request_queue: LiveRequestQueue, - run_config: Optional[RunConfig] = None, - session: Optional[Session] = None, - ) -> AsyncGenerator[Event, None]: - """Runs the agent in live mode (experimental feature). - - The `run_live` method yields a stream of `Event` objects, but not all - yielded events are saved to the session. Here's a breakdown: - - **Events Yielded to Callers:** - * **Live Model Audio Events with Inline Data:** Events containing raw - audio `Blob` data(`inline_data`). - * **Live Model Audio Events with File Data:** Both input and ouput audio - data are aggregated into a audio file saved into artifacts. The - reference to the file is saved in the event as `file_data`. - * **Usage Metadata:** Events containing token usage. - * **Transcription Events:** Both partial and non-partial transcription - events are yielded. - * **Function Call and Response Events:** Always saved. - * **Other Control Events:** Most control events are saved. - - **Events Saved to the Session:** - * **Live Model Audio Events with File Data:** Both input and ouput audio - data are aggregated into a audio file saved into artifacts. The - reference to the file is saved as event in the `file_data` to session - if RunConfig.save_live_model_audio_to_session is True. - * **Usage Metadata Events:** Saved to the session. - * **Non-Partial Transcription Events:** Non-partial transcription events - are saved. - * **Function Call and Response Events:** Always saved. - * **Other Control Events:** Most control events are saved. - - **Events Not Saved to the Session:** - * **Live Model Audio Events with Inline Data:** Events containing raw - audio `Blob` data are **not** saved to the session. - - Args: - user_id: The user ID for the session. Required if `session` is None. - session_id: The session ID for the session. Required if `session` is - None. - live_request_queue: The queue for live requests. - run_config: The run config for the agent. - session: The session to use. This parameter is deprecated, please use - `user_id` and `session_id` instead. - - Yields: - AsyncGenerator[Event, None]: An asynchronous generator that yields - `Event` - objects as they are produced by the agent during its live execution. - - .. warning:: - This feature is **experimental** and its API or behavior may change - in future releases. - - .. NOTE:: - Either `session` or both `user_id` and `session_id` must be provided. - """ - run_config = run_config or RunConfig() - # Some native audio models requires the modality to be set. So we set it to - # AUDIO by default. - if run_config.response_modalities is None: - run_config.response_modalities = ["AUDIO"] - if session is None and (user_id is None or session_id is None): - raise ValueError( - "Either session or user_id and session_id must be provided." - ) - if session is not None: - warnings.warn( - "The `session` parameter is deprecated. Please use `user_id` and" - " `session_id` instead.", - DeprecationWarning, - stacklevel=2, - ) - if not session: - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise ValueError(f"Session not found: {session_id}") - invocation_context = self._new_invocation_context_for_live( - session, - live_request_queue=live_request_queue, - run_config=run_config, + else: + # Artifact version changed after rewind point. Restore to version at + # rewind point. + artifact_uri = artifact_util.get_artifact_uri( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + filename=filename, + version=vt, ) + artifact = types.Part(file_data=types.FileData(file_uri=artifact_uri)) + await self.artifact_service.save_artifact( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + filename=filename, + artifact=artifact, + ) + + return rewind_artifact_delta + + def _should_append_event(self, event: Event, is_live_call: bool) -> bool: + """Checks if an event should be appended to the session.""" + # Don't append audio response from model in live mode to session. + # The data is appended to artifacts with a reference in file_data in the + # event. + # We should append non-partial events only.For example, non-finished(partial) + # transcription events should not be appended. + # Function call and function response events should be appended. + # Other control events should be appended. + if is_live_call and contents._is_live_model_audio_event_with_inline_data( + event + ): + # We don't append live model audio events with inline data to avoid + # storing large blobs in the session. However, events with file_data + # (references to artifacts) should be appended. + return False + return True + + async def _exec_with_plugin( + self, + invocation_context: InvocationContext, + session: Session, + execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]], + is_live_call: bool = False, + ) -> AsyncGenerator[Event, None]: + """Wraps execution with plugin callbacks. + + Args: + invocation_context: The invocation context + session: The current session + execute_fn: A callable that returns an AsyncGenerator of Events + is_live_call: Whether this is a live call + + Yields: + Events from the execution, including any generated by plugins + """ - root_agent = self.agent - invocation_context.agent = self._find_agent_to_run(session, root_agent) - - # Pre-processing for live streaming tools - # Inspect the tool's parameters to find if it uses LiveRequestQueue - invocation_context.active_streaming_tools = {} - # TODO(hangfei): switch to use canonical_tools. - # for shell agents, there is no tools associated with it so we should skip. - if hasattr(invocation_context.agent, "tools"): - import inspect - - for tool in invocation_context.agent.tools: - # We use `inspect.signature()` to examine the tool's underlying function (`tool.func`). - # This approach is deliberately chosen over `typing.get_type_hints()` for robustness. - # - # The Problem with `get_type_hints()`: - # `get_type_hints()` attempts to resolve forward-referenced (string-based) type - # annotations. This resolution can easily fail with a `NameError` (e.g., "Union not found") - # if the type isn't available in the scope where `get_type_hints()` is called. - # This is a common and brittle issue in framework code that inspects functions - # defined in separate user modules. - # - # Why `inspect.signature()` is Better Here: - # `inspect.signature()` does NOT resolve the annotations; it retrieves the raw - # annotation object as it was defined on the function. This allows us to - # perform a direct and reliable identity check (`param.annotation is LiveRequestQueue`) - # without risking a `NameError`. - callable_to_inspect = tool.func if hasattr(tool, "func") else tool - # Ensure the target is actually callable before inspecting to avoid errors. - if not callable(callable_to_inspect): - continue - for param in inspect.signature(callable_to_inspect).parameters.values(): - if param.annotation is LiveRequestQueue: - if not invocation_context.active_streaming_tools: - invocation_context.active_streaming_tools = {} - active_streaming_tool = ActiveStreamingTool( - stream=LiveRequestQueue() - ) - invocation_context.active_streaming_tools[tool.__name__] = ( - active_streaming_tool - ) - - async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async with Aclosing(ctx.agent.run_live(ctx)) as agen: - async for event in agen: - yield event + plugin_manager = invocation_context.plugin_manager - async with Aclosing( - self._exec_with_plugin( - invocation_context=invocation_context, - session=session, - execute_fn=execute, - is_live_call=True, - ) - ) as agen: - async for event in agen: - yield event - - def _find_agent_to_run(self, session: Session, root_agent: BaseAgent) -> BaseAgent: - """Finds the agent to run to continue the session. - - A qualified agent must be either of: - - - The agent that returned a function call and the last user message is a - function response to this function call. - - The root agent. - - An LlmAgent who replied last and is capable to transfer to any other agent - in the agent hierarchy. - - Args: - session: The session to find the agent for. - root_agent: The root agent of the runner. - - Returns: - The agent to run. (the active agent that should reply to the latest user - message) - """ - # If the last event is a function response, should send this response to - # the agent that returned the corresponding function call regardless the - # type of the agent. e.g. a remote a2a agent may surface a credential - # request as a special long running function tool call. - event = find_matching_function_call(session.events) - if event and event.author: - return root_agent.find_agent(event.author) - - def _event_filter(event: Event) -> bool: - """Filters out user-authored events and agent state change events.""" - if event.author == "user": - return False - if event.actions.agent_state is not None or event.actions.end_of_agent: - return False - return True - - for event in filter(_event_filter, reversed(session.events)): - if event.author == root_agent.name: - # Found root agent. - return root_agent - if not (agent := root_agent.find_sub_agent(event.author)): - # Agent not found, continue looking. - logger.warning( - "Event from an unknown agent: %s, event id: %s", - event.author, - event.id, + # Step 1: Run the before_run callbacks to see if we should early exit. + early_exit_result = await plugin_manager.run_before_run_callback( + invocation_context=invocation_context + ) + if isinstance(early_exit_result, types.Content): + early_exit_event = Event( + invocation_id=invocation_context.invocation_id, + author="model", + content=early_exit_result, + ) + if self._should_append_event(early_exit_event, is_live_call): + await self.session_service.append_event( + session=session, + event=early_exit_event, + ) + yield early_exit_event + else: + # Step 2: Otherwise continue with normal execution + # Note for live/bidi: + # the transcription may arrive later then the action(function call + # event and thus function response event). In this case, the order of + # transcription and function call event will be wrong if we just + # append as it arrives. To address this, we should check if there is + # transcription going on. If there is transcription going on, we + # should hold on appending the function call event until the + # transcription is finished. The transcription in progress can be + # identified by checking if the transcription event is partial. When + # the next transcription event is not partial, it means the previous + # transcription is finished. Then if there is any buffered function + # call event, we should append them after this finished(non-parital) + # transcription event. + buffered_events: list[Event] = [] + is_transcribing: bool = False + + async with Aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response event which is + # non-partial + buffered_events.append(event) + continue + # Note for live/bidi: for audio response, it's considered as + # non-paritla event(event.partial=None) + # event.partial=False and event.partial=None are considered as + # non-partial event; event.partial=True is considered as partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + "Appending transcription finished event: %s", event ) - continue - if self._is_transferable_across_agent_tree(agent): - return agent - # Falls back to root agent if no suitable agents are found in the session. - return root_agent + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=event + ) + + for buffered_event in buffered_events: + logger.debug("Appending buffered event: %s", buffered_event) + await self.session_service.append_event( + session=session, event=buffered_event + ) + buffered_events = [] + else: + # non-transcription event or empty transcription event, for + # example, event that stores blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug("Appending non-buffered event: %s", event) + await self.session_service.append_event( + session=session, event=event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=session, event=event + ) + + # Step 3: Run the on_event callbacks to optionally modify the event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event + ) + yield (modified_event if modified_event else event) + + # Step 4: Run the after_run callbacks to perform global cleanup tasks or + # finalizing logs and metrics data. + # This does NOT emit any event. + await plugin_manager.run_after_run_callback( + invocation_context=invocation_context + ) - def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool: - """Whether the agent to run can transfer to any other agent in the agent tree. - - This typically means all agent_to_run's ancestor can transfer to their - parent_agent all the way to the root_agent. - - Args: - agent_to_run: The agent to check for transferability. - - Returns: - True if the agent can transfer, False otherwise. - """ - agent = agent_to_run - while agent: - if not isinstance(agent, LlmAgent): - # Only LLM-based Agent can provide agent transfer capability. - return False - if agent.disallow_transfer_to_parent: - return False - agent = agent.parent_agent - return True - - async def run_debug( - self, - user_messages: str | list[str], - *, - user_id: str = "debug_user_id", - session_id: str = "debug_session_id", - run_config: RunConfig | None = None, - quiet: bool = False, - verbose: bool = False, - ) -> list[Event]: - """Debug helper for quick agent experimentation and testing. - - This convenience method is designed for developers getting started with ADK - who want to quickly test agents without dealing with session management, - content formatting, or event streaming. It automatically handles common - boilerplate while hiding complexity. - - IMPORTANT: This is for debugging and experimentation only. For production - use, please use the standard run_async() method which provides full control - over session management, event streaming, and error handling. - - Args: - user_messages: Message(s) to send to the agent. Can be: - Single string: - "What is 2+2?" - List of strings: ["Hello!", "What's my name?"] - user_id: User identifier. Defaults to "debug_user_id". - session_id: Session identifier for conversation persistence. Defaults to - "debug_session_id". Reuse the same ID to continue a conversation. - run_config: Optional configuration for the agent execution. - quiet: If True, suppresses console output. Defaults to False (output - shown). - verbose: If True, shows detailed tool calls and responses. Defaults to - False for cleaner output showing only final agent responses. - - Returns: - list[Event]: All events from all messages. - - Raises: - ValueError: If session creation/retrieval fails. - - Examples: - Quick debugging: - >>> runner = InMemoryRunner(agent=my_agent) - >>> await runner.run_debug("What is 2+2?") - - Multiple queries in conversation: - >>> await runner.run_debug(["Hello!", "What's my name?"]) - - Continue a debug session: - >>> await runner.run_debug("What did we discuss?") # Continues default - session - - Separate debug sessions: - >>> await runner.run_debug("Hi", user_id="alice", session_id="debug1") - >>> await runner.run_debug("Hi", user_id="bob", session_id="debug2") - - Capture events for inspection: - >>> events = await runner.run_debug("Analyze this") - >>> for event in events: - ... inspect_event(event) - - Note: - For production applications requiring: - - Custom session/memory services (Spanner, Cloud SQL, etc.) - - Fine-grained event processing and streaming - - Error recovery and resumability - - Performance optimization - Please use run_async() with proper configuration. - """ - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id + async def _append_new_message_to_session( + self, + *, + session: Session, + new_message: types.Content, + invocation_context: InvocationContext, + save_input_blobs_as_artifacts: bool = False, + state_delta: Optional[dict[str, Any]] = None, + ): + """Appends a new message to the session. + + Args: + session: The session to append the message to. + new_message: The new message to append. + invocation_context: The invocation context for the message. + save_input_blobs_as_artifacts: Whether to save input blobs as artifacts. + state_delta: Optional state changes to apply to the session. + """ + if not new_message.parts: + raise ValueError("No parts in the new_message.") + + if self.artifact_service and save_input_blobs_as_artifacts: + # Issue deprecation warning + warnings.warn( + "The 'save_input_blobs_as_artifacts' parameter is deprecated. Use" + " SaveFilesAsArtifactsPlugin instead for better control and" + " flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for" + " migration guidance.", + DeprecationWarning, + stacklevel=3, + ) + # The runner directly saves the artifacts (if applicable) in the + # user message and replaces the artifact data with a file name + # placeholder. + for i, part in enumerate(new_message.parts): + if part.inline_data is None: + continue + file_name = f"artifact_{invocation_context.invocation_id}_{i}" + await self.artifact_service.save_artifact( + app_name=self.app_name, + user_id=session.user_id, + session_id=session.id, + filename=file_name, + artifact=part, ) - if not session: - session = await self.session_service.create_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not quiet: - print(f"\n ### Created new session: {session_id}") - elif not quiet: - print(f"\n ### Continue session: {session_id}") - - collected_events: list[Event] = [] + new_message.parts[i] = types.Part( + text=f"Uploaded file: {file_name}. It is saved into artifacts" + ) + # Appends only. We do not yield the event because it's not from the model. + if state_delta: + event = Event( + invocation_id=invocation_context.invocation_id, + author="user", + actions=EventActions(state_delta=state_delta), + content=new_message, + ) + else: + event = Event( + invocation_id=invocation_context.invocation_id, + author="user", + content=new_message, + ) + # If new_message is a function response, find the matching function call + # and use its branch as the new event's branch. + if function_call := invocation_context._find_matching_function_call(event): + event.branch = function_call.branch + + await self.session_service.append_event(session=session, event=event) + + async def run_live( + self, + *, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + live_request_queue: LiveRequestQueue, + run_config: Optional[RunConfig] = None, + session: Optional[Session] = None, + ) -> AsyncGenerator[Event, None]: + """Runs the agent in live mode (experimental feature). + + The `run_live` method yields a stream of `Event` objects, but not all + yielded events are saved to the session. Here's a breakdown: + + **Events Yielded to Callers:** + * **Live Model Audio Events with Inline Data:** Events containing raw + audio `Blob` data(`inline_data`). + * **Live Model Audio Events with File Data:** Both input and ouput audio + data are aggregated into a audio file saved into artifacts. The + reference to the file is saved in the event as `file_data`. + * **Usage Metadata:** Events containing token usage. + * **Transcription Events:** Both partial and non-partial transcription + events are yielded. + * **Function Call and Response Events:** Always saved. + * **Other Control Events:** Most control events are saved. + + **Events Saved to the Session:** + * **Live Model Audio Events with File Data:** Both input and ouput audio + data are aggregated into a audio file saved into artifacts. The + reference to the file is saved as event in the `file_data` to session + if RunConfig.save_live_model_audio_to_session is True. + * **Usage Metadata Events:** Saved to the session. + * **Non-Partial Transcription Events:** Non-partial transcription events + are saved. + * **Function Call and Response Events:** Always saved. + * **Other Control Events:** Most control events are saved. + + **Events Not Saved to the Session:** + * **Live Model Audio Events with Inline Data:** Events containing raw + audio `Blob` data are **not** saved to the session. + + Args: + user_id: The user ID for the session. Required if `session` is None. + session_id: The session ID for the session. Required if `session` is + None. + live_request_queue: The queue for live requests. + run_config: The run config for the agent. + session: The session to use. This parameter is deprecated, please use + `user_id` and `session_id` instead. + + Yields: + AsyncGenerator[Event, None]: An asynchronous generator that yields + `Event` + objects as they are produced by the agent during its live execution. + + .. warning:: + This feature is **experimental** and its API or behavior may change + in future releases. + + .. NOTE:: + Either `session` or both `user_id` and `session_id` must be provided. + """ + run_config = run_config or RunConfig() + # Some native audio models requires the modality to be set. So we set it to + # AUDIO by default. + if run_config.response_modalities is None: + run_config.response_modalities = ["AUDIO"] + if session is None and (user_id is None or session_id is None): + raise ValueError( + "Either session or user_id and session_id must be provided." + ) + if session is not None: + warnings.warn( + "The `session` parameter is deprecated. Please use `user_id` and" + " `session_id` instead.", + DeprecationWarning, + stacklevel=2, + ) + if not session: + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise ValueError(f"Session not found: {session_id}") + invocation_context = self._new_invocation_context_for_live( + session, + live_request_queue=live_request_queue, + run_config=run_config, + ) - if isinstance(user_messages, str): - user_messages = [user_messages] + root_agent = self.agent + invocation_context.agent = self._find_agent_to_run(session, root_agent) - for message in user_messages: - if not quiet: - print(f"\nUser > {message}") + # Pre-processing for live streaming tools + # Inspect the tool's parameters to find if it uses LiveRequestQueue + invocation_context.active_streaming_tools = {} + # TODO(hangfei): switch to use canonical_tools. + # for shell agents, there is no tools associated with it so we should skip. + if hasattr(invocation_context.agent, "tools"): + import inspect - async for event in self.run_async( - user_id=user_id, - session_id=session.id, - new_message=types.UserContent(parts=[types.Part(text=message)]), - run_config=run_config, - ): - if not quiet: - print_event(event, verbose=verbose) - - collected_events.append(event) - - return collected_events - - async def _setup_context_for_new_invocation( - self, - *, - session: Session, - new_message: types.Content, - run_config: RunConfig, - state_delta: Optional[dict[str, Any]], - ) -> InvocationContext: - """Sets up the context for a new invocation. - - Args: - session: The session to set up the invocation context for. - new_message: The new message to process and append to the session. - run_config: The run config of the agent. - state_delta: Optional state changes to apply to the session. - - Returns: - The invocation context for the new invocation. - """ - # Step 1: Create invocation context in memory. - invocation_context = self._new_invocation_context( - session, - new_message=new_message, - run_config=run_config, - ) - # Step 2: Handle new message, by running callbacks and appending to - # session. - await self._handle_new_message( - session=session, - new_message=new_message, - invocation_context=invocation_context, - run_config=run_config, - state_delta=state_delta, - ) - # Step 3: Set agent to run for the invocation. - invocation_context.agent = self._find_agent_to_run(session, self.agent) - return invocation_context - - async def _setup_context_for_resumed_invocation( - self, - *, - session: Session, - new_message: Optional[types.Content], - invocation_id: Optional[str], - run_config: RunConfig, - state_delta: Optional[dict[str, Any]], - ) -> InvocationContext: - """Sets up the context for a resumed invocation. - - Args: - session: The session to set up the invocation context for. - new_message: The new message to process and append to the session. - invocation_id: The invocation id to resume. - run_config: The run config of the agent. - state_delta: Optional state changes to apply to the session. - - Returns: - The invocation context for the resumed invocation. - - Raises: - ValueError: If the session has no events to resume; If no user message is - available for resuming the invocation; Or if the app is not resumable. - """ - if not session.events: - raise ValueError(f"Session {session.id} has no events to resume.") - - # Step 1: Maybe retrieve a previous user message for the invocation. - user_message = new_message or self._find_user_message_for_invocation( - session.events, invocation_id - ) - if not user_message: - raise ValueError( - f"No user message available for resuming invocation: {invocation_id}" + for tool in invocation_context.agent.tools: + # We use `inspect.signature()` to examine the tool's underlying function (`tool.func`). + # This approach is deliberately chosen over `typing.get_type_hints()` for robustness. + # + # The Problem with `get_type_hints()`: + # `get_type_hints()` attempts to resolve forward-referenced (string-based) type + # annotations. This resolution can easily fail with a `NameError` (e.g., "Union not found") + # if the type isn't available in the scope where `get_type_hints()` is called. + # This is a common and brittle issue in framework code that inspects functions + # defined in separate user modules. + # + # Why `inspect.signature()` is Better Here: + # `inspect.signature()` does NOT resolve the annotations; it retrieves the raw + # annotation object as it was defined on the function. This allows us to + # perform a direct and reliable identity check (`param.annotation is LiveRequestQueue`) + # without risking a `NameError`. + callable_to_inspect = tool.func if hasattr(tool, "func") else tool + # Ensure the target is actually callable before inspecting to avoid errors. + if not callable(callable_to_inspect): + continue + for param in inspect.signature(callable_to_inspect).parameters.values(): + if param.annotation is LiveRequestQueue: + if not invocation_context.active_streaming_tools: + invocation_context.active_streaming_tools = {} + active_streaming_tool = ActiveStreamingTool( + stream=LiveRequestQueue() ) - # Step 2: Create invocation context. - invocation_context = self._new_invocation_context( - session, - new_message=user_message, - run_config=run_config, - invocation_id=invocation_id, - ) - # Step 3: Maybe handle new message. - if new_message: - await self._handle_new_message( - session=session, - new_message=user_message, - invocation_context=invocation_context, - run_config=run_config, - state_delta=state_delta, + invocation_context.active_streaming_tools[tool.__name__] = ( + active_streaming_tool ) - # Step 4: Populate agent states for the current invocation. - invocation_context.populate_invocation_agent_states() - # Step 5: Set agent to run for the invocation. - # - # If the root agent is not found in end_of_agents, it means the invocation - # started from a sub-agent and paused on a sub-agent. - # We should find the appropriate agent to run to continue the invocation. - if self.agent.name not in invocation_context.end_of_agents: - invocation_context.agent = self._find_agent_to_run(session, self.agent) - return invocation_context - - def _find_user_message_for_invocation( - self, events: list[Event], invocation_id: str - ) -> Optional[types.Content]: - """Finds the user message that started a specific invocation.""" - for event in events: - if ( - event.invocation_id == invocation_id - and event.author == "user" - and event.content - and event.content.parts - and event.content.parts[0].text - ): - return event.content - return None - - def _new_invocation_context( - self, - session: Session, - *, - invocation_id: Optional[str] = None, - new_message: Optional[types.Content] = None, - live_request_queue: Optional[LiveRequestQueue] = None, - run_config: Optional[RunConfig] = None, - ) -> InvocationContext: - """Creates a new invocation context. - - Args: - session: The session for the context. - invocation_id: The invocation id for the context. - new_message: The new message for the context. - live_request_queue: The live request queue for the context. - run_config: The run config for the context. - - Returns: - The new invocation context. - """ - run_config = run_config or RunConfig() - invocation_id = invocation_id or new_invocation_context_id() - - if run_config.support_cfc and isinstance(self.agent, LlmAgent): - model_name = self.agent.canonical_model.model - if not model_name.startswith("gemini-2"): - raise ValueError( - f"CFC is not supported for model: {model_name} in agent:" - f" {self.agent.name}" - ) - if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): - self.agent.code_executor = BuiltInCodeExecutor() - - return InvocationContext( - artifact_service=self.artifact_service, - session_service=self.session_service, - memory_service=self.memory_service, - credential_service=self.credential_service, - plugin_manager=self.plugin_manager, - context_cache_config=self.context_cache_config, - invocation_id=invocation_id, - agent=self.agent, + + async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + async with Aclosing(ctx.agent.run_live(ctx)) as agen: + async for event in agen: + yield event + + async with Aclosing( + self._exec_with_plugin( + invocation_context=invocation_context, session=session, - user_content=new_message, - live_request_queue=live_request_queue, - run_config=run_config, - resumability_config=self.resumability_config, + execute_fn=execute, + is_live_call=True, ) - - def _new_invocation_context_for_live( - self, - session: Session, - *, - live_request_queue: Optional[LiveRequestQueue] = None, - run_config: Optional[RunConfig] = None, - ) -> InvocationContext: - """Creates a new invocation context for live multi-agent.""" - run_config = run_config or RunConfig() - - # For live multi-agent, we need model's text transcription as context for - # next agent. - if self.agent.sub_agents and live_request_queue: - if not run_config.response_modalities: - # default - run_config.response_modalities = ["AUDIO"] - if not run_config.output_audio_transcription: - run_config.output_audio_transcription = ( - types.AudioTranscriptionConfig() - ) - elif "TEXT" not in run_config.response_modalities: - if not run_config.output_audio_transcription: - run_config.output_audio_transcription = ( - types.AudioTranscriptionConfig() - ) - if not run_config.input_audio_transcription: - # need this input transcription for agent transferring in live mode. - run_config.input_audio_transcription = types.AudioTranscriptionConfig() - return self._new_invocation_context( - session, - live_request_queue=live_request_queue, - run_config=run_config, + ) as agen: + async for event in agen: + yield event + + def _find_agent_to_run( + self, session: Session, root_agent: BaseAgent + ) -> BaseAgent: + """Finds the agent to run to continue the session. + + A qualified agent must be either of: + + - The agent that returned a function call and the last user message is a + function response to this function call. + - The root agent. + - An LlmAgent who replied last and is capable to transfer to any other agent + in the agent hierarchy. + + Args: + session: The session to find the agent for. + root_agent: The root agent of the runner. + + Returns: + The agent to run. (the active agent that should reply to the latest user + message) + """ + # If the last event is a function response, should send this response to + # the agent that returned the corresponding function call regardless the + # type of the agent. e.g. a remote a2a agent may surface a credential + # request as a special long running function tool call. + event = find_matching_function_call(session.events) + if event and event.author: + return root_agent.find_agent(event.author) + + def _event_filter(event: Event) -> bool: + """Filters out user-authored events and agent state change events.""" + if event.author == "user": + return False + if event.actions.agent_state is not None or event.actions.end_of_agent: + return False + return True + + for event in filter(_event_filter, reversed(session.events)): + if event.author == root_agent.name: + # Found root agent. + return root_agent + if not (agent := root_agent.find_sub_agent(event.author)): + # Agent not found, continue looking. + logger.warning( + "Event from an unknown agent: %s, event id: %s", + event.author, + event.id, ) + continue + if self._is_transferable_across_agent_tree(agent): + return agent + # Falls back to root agent if no suitable agents are found in the session. + return root_agent - async def _handle_new_message( - self, - *, - session: Session, - new_message: types.Content, - invocation_context: InvocationContext, - run_config: RunConfig, - state_delta: Optional[dict[str, Any]], - ) -> None: - """Handles a new message by running callbacks and appending to session. - - Args: - session: The session of the new message. - new_message: The new message to process and append to the session. - invocation_context: The invocation context to use for the message - handling. - run_config: The run config of the agent. - state_delta: Optional state changes to apply to the session. - """ - modified_user_message = ( - await invocation_context.plugin_manager.run_on_user_message_callback( - invocation_context=invocation_context, user_message=new_message - ) - ) - if modified_user_message is not None: - new_message = modified_user_message - invocation_context.user_content = new_message + def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool: + """Whether the agent to run can transfer to any other agent in the agent tree. - if new_message: - await self._append_new_message_to_session( - session=session, - new_message=new_message, - invocation_context=invocation_context, - save_input_blobs_as_artifacts=run_config.save_input_blobs_as_artifacts, - state_delta=state_delta, - ) + This typically means all agent_to_run's ancestor can transfer to their + parent_agent all the way to the root_agent. - def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: - toolsets = set() - if isinstance(agent, LlmAgent): - for tool_union in agent.tools: - if isinstance(tool_union, BaseToolset): - toolsets.add(tool_union) - for sub_agent in agent.sub_agents: - toolsets.update(self._collect_toolset(sub_agent)) - return toolsets - - async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): - """Clean up toolsets with proper task context management.""" - if not toolsets_to_close: - return - - # This maintains the same task context throughout cleanup - for toolset in toolsets_to_close: - try: - logger.info("Closing toolset: %s", type(toolset).__name__) - # Use asyncio.wait_for to add timeout protection - await asyncio.wait_for(toolset.close(), timeout=10.0) - logger.info("Successfully closed toolset: %s", type(toolset).__name__) - except asyncio.TimeoutError: - logger.warning("Toolset %s cleanup timed out", type(toolset).__name__) - except asyncio.CancelledError as e: - # Handle cancel scope issues in Python 3.10 and 3.11 with anyio - # - # Root cause: MCP library uses anyio.CancelScope() in RequestResponder.__enter__() - # and __exit__() methods. When asyncio.wait_for() creates a new task for cleanup, - # the cancel scope is entered in one task context but exited in another. - # - # Python 3.12+ fixes: Enhanced task context management (Task.get_context()), - # improved context propagation across task boundaries, and better cancellation - # handling prevent the cross-task cancel scope violation. - logger.warning( - "Toolset %s cleanup cancelled: %s", type(toolset).__name__, e - ) - except Exception as e: - logger.error("Error closing toolset %s: %s", type(toolset).__name__, e) + Args: + agent_to_run: The agent to check for transferability. - async def close(self): - """Closes the runner.""" - logger.info("Closing runner...") - # Close Toolsets - await self._cleanup_toolsets(self._collect_toolset(self.agent)) + Returns: + True if the agent can transfer, False otherwise. + """ + agent = agent_to_run + while agent: + if not isinstance(agent, LlmAgent): + # Only LLM-based Agent can provide agent transfer capability. + return False + if agent.disallow_transfer_to_parent: + return False + agent = agent.parent_agent + return True + + async def run_debug( + self, + user_messages: str | list[str], + *, + user_id: str = "debug_user_id", + session_id: str = "debug_session_id", + run_config: RunConfig | None = None, + quiet: bool = False, + verbose: bool = False, + ) -> list[Event]: + """Debug helper for quick agent experimentation and testing. + + This convenience method is designed for developers getting started with ADK + who want to quickly test agents without dealing with session management, + content formatting, or event streaming. It automatically handles common + boilerplate while hiding complexity. + + IMPORTANT: This is for debugging and experimentation only. For production + use, please use the standard run_async() method which provides full control + over session management, event streaming, and error handling. + + Args: + user_messages: Message(s) to send to the agent. Can be: - Single string: + "What is 2+2?" - List of strings: ["Hello!", "What's my name?"] + user_id: User identifier. Defaults to "debug_user_id". + session_id: Session identifier for conversation persistence. Defaults to + "debug_session_id". Reuse the same ID to continue a conversation. + run_config: Optional configuration for the agent execution. + quiet: If True, suppresses console output. Defaults to False (output + shown). + verbose: If True, shows detailed tool calls and responses. Defaults to + False for cleaner output showing only final agent responses. + + Returns: + list[Event]: All events from all messages. + + Raises: + ValueError: If session creation/retrieval fails. + + Examples: + Quick debugging: + >>> runner = InMemoryRunner(agent=my_agent) + >>> await runner.run_debug("What is 2+2?") + + Multiple queries in conversation: + >>> await runner.run_debug(["Hello!", "What's my name?"]) + + Continue a debug session: + >>> await runner.run_debug("What did we discuss?") # Continues default + session + + Separate debug sessions: + >>> await runner.run_debug("Hi", user_id="alice", session_id="debug1") + >>> await runner.run_debug("Hi", user_id="bob", session_id="debug2") + + Capture events for inspection: + >>> events = await runner.run_debug("Analyze this") + >>> for event in events: + ... inspect_event(event) + + Note: + For production applications requiring: + - Custom session/memory services (Spanner, Cloud SQL, etc.) + - Fine-grained event processing and streaming + - Error recovery and resumability + - Performance optimization + Please use run_async() with proper configuration. + """ + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + session = await self.session_service.create_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not quiet: + print(f"\n ### Created new session: {session_id}") + elif not quiet: + print(f"\n ### Continue session: {session_id}") + + collected_events: list[Event] = [] + + if isinstance(user_messages, str): + user_messages = [user_messages] + + for message in user_messages: + if not quiet: + print(f"\nUser > {message}") + + async for event in self.run_async( + user_id=user_id, + session_id=session.id, + new_message=types.UserContent(parts=[types.Part(text=message)]), + run_config=run_config, + ): + if not quiet: + print_event(event, verbose=verbose) + + collected_events.append(event) + + return collected_events + + async def _setup_context_for_new_invocation( + self, + *, + session: Session, + new_message: types.Content, + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> InvocationContext: + """Sets up the context for a new invocation. + + Args: + session: The session to set up the invocation context for. + new_message: The new message to process and append to the session. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + + Returns: + The invocation context for the new invocation. + """ + # Step 1: Create invocation context in memory. + invocation_context = self._new_invocation_context( + session, + new_message=new_message, + run_config=run_config, + ) + # Step 2: Handle new message, by running callbacks and appending to + # session. + await self._handle_new_message( + session=session, + new_message=new_message, + invocation_context=invocation_context, + run_config=run_config, + state_delta=state_delta, + ) + # Step 3: Set agent to run for the invocation. + invocation_context.agent = self._find_agent_to_run(session, self.agent) + return invocation_context + + async def _setup_context_for_resumed_invocation( + self, + *, + session: Session, + new_message: Optional[types.Content], + invocation_id: Optional[str], + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> InvocationContext: + """Sets up the context for a resumed invocation. + + Args: + session: The session to set up the invocation context for. + new_message: The new message to process and append to the session. + invocation_id: The invocation id to resume. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + + Returns: + The invocation context for the resumed invocation. + + Raises: + ValueError: If the session has no events to resume; If no user message is + available for resuming the invocation; Or if the app is not resumable. + """ + if not session.events: + raise ValueError(f"Session {session.id} has no events to resume.") - # Close Plugins - if self.plugin_manager: - await self.plugin_manager.close() + # Step 1: Maybe retrieve a previous user message for the invocation. + user_message = new_message or self._find_user_message_for_invocation( + session.events, invocation_id + ) + if not user_message: + raise ValueError( + f"No user message available for resuming invocation: {invocation_id}" + ) + # Step 2: Create invocation context. + invocation_context = self._new_invocation_context( + session, + new_message=user_message, + run_config=run_config, + invocation_id=invocation_id, + ) + # Step 3: Maybe handle new message. + if new_message: + await self._handle_new_message( + session=session, + new_message=user_message, + invocation_context=invocation_context, + run_config=run_config, + state_delta=state_delta, + ) + # Step 4: Populate agent states for the current invocation. + invocation_context.populate_invocation_agent_states() + # Step 5: Set agent to run for the invocation. + # + # If the root agent is not found in end_of_agents, it means the invocation + # started from a sub-agent and paused on a sub-agent. + # We should find the appropriate agent to run to continue the invocation. + if self.agent.name not in invocation_context.end_of_agents: + invocation_context.agent = self._find_agent_to_run(session, self.agent) + return invocation_context + + def _find_user_message_for_invocation( + self, events: list[Event], invocation_id: str + ) -> Optional[types.Content]: + """Finds the user message that started a specific invocation.""" + for event in events: + if ( + event.invocation_id == invocation_id + and event.author == "user" + and event.content + and event.content.parts + and event.content.parts[0].text + ): + return event.content + return None + + def _new_invocation_context( + self, + session: Session, + *, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + live_request_queue: Optional[LiveRequestQueue] = None, + run_config: Optional[RunConfig] = None, + ) -> InvocationContext: + """Creates a new invocation context. + + Args: + session: The session for the context. + invocation_id: The invocation id for the context. + new_message: The new message for the context. + live_request_queue: The live request queue for the context. + run_config: The run config for the context. + + Returns: + The new invocation context. + """ + run_config = run_config or RunConfig() + invocation_id = invocation_id or new_invocation_context_id() + + if run_config.support_cfc and isinstance(self.agent, LlmAgent): + model_name = self.agent.canonical_model.model + if not model_name.startswith("gemini-2"): + raise ValueError( + f"CFC is not supported for model: {model_name} in agent:" + f" {self.agent.name}" + ) + if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): + self.agent.code_executor = BuiltInCodeExecutor() + + return InvocationContext( + artifact_service=self.artifact_service, + session_service=self.session_service, + memory_service=self.memory_service, + credential_service=self.credential_service, + plugin_manager=self.plugin_manager, + context_cache_config=self.context_cache_config, + invocation_id=invocation_id, + agent=self.agent, + session=session, + user_content=new_message, + live_request_queue=live_request_queue, + run_config=run_config, + resumability_config=self.resumability_config, + ) - logger.info("Runner closed.") + def _new_invocation_context_for_live( + self, + session: Session, + *, + live_request_queue: Optional[LiveRequestQueue] = None, + run_config: Optional[RunConfig] = None, + ) -> InvocationContext: + """Creates a new invocation context for live multi-agent.""" + run_config = run_config or RunConfig() + + # For live multi-agent, we need model's text transcription as context for + # next agent. + if self.agent.sub_agents and live_request_queue: + if not run_config.response_modalities: + # default + run_config.response_modalities = ["AUDIO"] + if not run_config.output_audio_transcription: + run_config.output_audio_transcription = ( + types.AudioTranscriptionConfig() + ) + elif "TEXT" not in run_config.response_modalities: + if not run_config.output_audio_transcription: + run_config.output_audio_transcription = ( + types.AudioTranscriptionConfig() + ) + if not run_config.input_audio_transcription: + # need this input transcription for agent transferring in live mode. + run_config.input_audio_transcription = types.AudioTranscriptionConfig() + return self._new_invocation_context( + session, + live_request_queue=live_request_queue, + run_config=run_config, + ) - if sys.version_info < (3, 11): - Self = "Runner" # pylint: disable=invalid-name - else: - from typing import Self # pylint: disable=g-import-not-at-top + async def _handle_new_message( + self, + *, + session: Session, + new_message: types.Content, + invocation_context: InvocationContext, + run_config: RunConfig, + state_delta: Optional[dict[str, Any]], + ) -> None: + """Handles a new message by running callbacks and appending to session. + + Args: + session: The session of the new message. + new_message: The new message to process and append to the session. + invocation_context: The invocation context to use for the message + handling. + run_config: The run config of the agent. + state_delta: Optional state changes to apply to the session. + """ + modified_user_message = ( + await invocation_context.plugin_manager.run_on_user_message_callback( + invocation_context=invocation_context, user_message=new_message + ) + ) + if modified_user_message is not None: + new_message = modified_user_message + invocation_context.user_content = new_message + + if new_message: + await self._append_new_message_to_session( + session=session, + new_message=new_message, + invocation_context=invocation_context, + save_input_blobs_as_artifacts=run_config.save_input_blobs_as_artifacts, + state_delta=state_delta, + ) + + def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: + toolsets = set() + if isinstance(agent, LlmAgent): + for tool_union in agent.tools: + if isinstance(tool_union, BaseToolset): + toolsets.add(tool_union) + for sub_agent in agent.sub_agents: + toolsets.update(self._collect_toolset(sub_agent)) + return toolsets + + async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): + """Clean up toolsets with proper task context management.""" + if not toolsets_to_close: + return + + # This maintains the same task context throughout cleanup + for toolset in toolsets_to_close: + try: + logger.info("Closing toolset: %s", type(toolset).__name__) + # Use asyncio.wait_for to add timeout protection + await asyncio.wait_for(toolset.close(), timeout=10.0) + logger.info("Successfully closed toolset: %s", type(toolset).__name__) + except asyncio.TimeoutError: + logger.warning("Toolset %s cleanup timed out", type(toolset).__name__) + except asyncio.CancelledError as e: + # Handle cancel scope issues in Python 3.10 and 3.11 with anyio + # + # Root cause: MCP library uses anyio.CancelScope() in RequestResponder.__enter__() + # and __exit__() methods. When asyncio.wait_for() creates a new task for cleanup, + # the cancel scope is entered in one task context but exited in another. + # + # Python 3.12+ fixes: Enhanced task context management (Task.get_context()), + # improved context propagation across task boundaries, and better cancellation + # handling prevent the cross-task cancel scope violation. + logger.warning( + "Toolset %s cleanup cancelled: %s", type(toolset).__name__, e + ) + except Exception as e: + logger.error("Error closing toolset %s: %s", type(toolset).__name__, e) - async def __aenter__(self) -> Self: - """Async context manager entry.""" - return self + async def close(self): + """Closes the runner.""" + logger.info("Closing runner...") + # Close Toolsets + await self._cleanup_toolsets(self._collect_toolset(self.agent)) - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit.""" - await self.close() - return False # Don't suppress exceptions from the async with block + # Close Plugins + if self.plugin_manager: + await self.plugin_manager.close() + logger.info("Runner closed.") -class InMemoryRunner(Runner): - """An in-memory Runner for testing and development. + if sys.version_info < (3, 11): + Self = "Runner" # pylint: disable=invalid-name + else: + from typing import Self # pylint: disable=g-import-not-at-top + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + return False # Don't suppress exceptions from the async with block - This runner uses in-memory implementations for artifact, session, and memory - services, providing a lightweight and self-contained environment for agent - execution. - Attributes: +class InMemoryRunner(Runner): + """An in-memory Runner for testing and development. + + This runner uses in-memory implementations for artifact, session, and memory + services, providing a lightweight and self-contained environment for agent + execution. + + Attributes: + agent: The root agent to run. + app_name: The application name of the runner. Defaults to + 'InMemoryRunner'. + """ + + def __init__( + self, + agent: Optional[BaseAgent] = None, + *, + app_name: Optional[str] = None, + plugins: Optional[list[BasePlugin]] = None, + app: Optional[App] = None, + plugin_close_timeout: float = 5.0, + ): + """Initializes the InMemoryRunner. + + Args: agent: The root agent to run. app_name: The application name of the runner. Defaults to 'InMemoryRunner'. + plugins: Optional list of plugins for the runner. + app: Optional App instance. + plugin_close_timeout: The timeout in seconds for plugin close methods. """ - - def __init__( - self, - agent: Optional[BaseAgent] = None, - *, - app_name: Optional[str] = None, - plugins: Optional[list[BasePlugin]] = None, - app: Optional[App] = None, - plugin_close_timeout: float = 5.0, - ): - """Initializes the InMemoryRunner. - - Args: - agent: The root agent to run. - app_name: The application name of the runner. Defaults to - 'InMemoryRunner'. - plugins: Optional list of plugins for the runner. - app: Optional App instance. - plugin_close_timeout: The timeout in seconds for plugin close methods. - """ - if app is None and app_name is None: - app_name = "InMemoryRunner" - super().__init__( - app_name=app_name, - agent=agent, - artifact_service=InMemoryArtifactService(), - plugins=plugins, - app=app, - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), - plugin_close_timeout=plugin_close_timeout, - ) + if app is None and app_name is None: + app_name = "InMemoryRunner" + super().__init__( + app_name=app_name, + agent=agent, + artifact_service=InMemoryArtifactService(), + plugins=plugins, + app=app, + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + plugin_close_timeout=plugin_close_timeout, + ) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 3b57c4328b..9e67ac6c0e 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -50,11 +50,11 @@ # Needed to avoid circular imports if TYPE_CHECKING: - from ..agents.base_agent import BaseAgent - from ..agents.invocation_context import InvocationContext - from ..models.llm_request import LlmRequest - from ..models.llm_response import LlmResponse - from ..tools.base_tool import BaseTool + from ..agents.base_agent import BaseAgent + from ..agents.invocation_context import InvocationContext + from ..models.llm_request import LlmRequest + from ..models.llm_response import LlmResponse + from ..tools.base_tool import BaseTool tracer = trace.get_tracer( instrumenting_module_name="gcp.vertex.agent", @@ -65,55 +65,55 @@ def _safe_json_serialize(obj) -> str: - """Convert any Python object to a JSON-serializable type or string. + """Convert any Python object to a JSON-serializable type or string. - Args: - obj: The object to serialize. + Args: + obj: The object to serialize. - Returns: - The JSON-serialized object string or if the object cannot be serialized. - """ + Returns: + The JSON-serialized object string or if the object cannot be serialized. + """ - try: - # Try direct JSON serialization first - return json.dumps( - obj, ensure_ascii=False, default=lambda o: "" - ) - except (TypeError, OverflowError): - return "" + try: + # Try direct JSON serialization first + return json.dumps( + obj, ensure_ascii=False, default=lambda o: "" + ) + except (TypeError, OverflowError): + return "" def trace_agent_invocation( span: trace.Span, agent: BaseAgent, ctx: InvocationContext ) -> None: - """Sets span attributes immediately available on agent invocation according to OTEL semconv version 1.37. + """Sets span attributes immediately available on agent invocation according to OTEL semconv version 1.37. - Args: - span: Span on which attributes are set. - agent: Agent from which attributes are gathered. - ctx: InvocationContext from which attributes are gathered. + Args: + span: Span on which attributes are set. + agent: Agent from which attributes are gathered. + ctx: InvocationContext from which attributes are gathered. - Inference related fields are not set, due to their planned removal from invoke_agent span: - https://github.com/open-telemetry/semantic-conventions/issues/2632 + Inference related fields are not set, due to their planned removal from invoke_agent span: + https://github.com/open-telemetry/semantic-conventions/issues/2632 - `gen_ai.agent.id` is not set because currently it's unclear what attributes this field should have, specifically: - - In which scope should it be unique (globally, given project, given agentic flow, given deployment). - - Should it be unchanging between deployments, and how this should this be achieved. + `gen_ai.agent.id` is not set because currently it's unclear what attributes this field should have, specifically: + - In which scope should it be unique (globally, given project, given agentic flow, given deployment). + - Should it be unchanging between deployments, and how this should this be achieved. - `gen_ai.data_source.id` is not set because it's not available. - Closest type which could contain this information is types.GroundingMetadata, which does not have an ID. + `gen_ai.data_source.id` is not set because it's not available. + Closest type which could contain this information is types.GroundingMetadata, which does not have an ID. - `server.*` attributes are not set pending confirmation from aabmass. - """ + `server.*` attributes are not set pending confirmation from aabmass. + """ - # Required - span.set_attribute(GEN_AI_OPERATION_NAME, "invoke_agent") + # Required + span.set_attribute(GEN_AI_OPERATION_NAME, "invoke_agent") - # Conditionally Required - span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) + # Conditionally Required + span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) - span.set_attribute(GEN_AI_AGENT_NAME, agent.name) - span.set_attribute(GEN_AI_CONVERSATION_ID, ctx.session.id) + span.set_attribute(GEN_AI_AGENT_NAME, agent.name) + span.set_attribute(GEN_AI_CONVERSATION_ID, ctx.session.id) def trace_tool_call( @@ -121,112 +121,112 @@ def trace_tool_call( args: dict[str, Any], function_response_event: Optional[Event], ): - """Traces tool call. + """Traces tool call. - Args: - tool: The tool that was called. - args: The arguments to the tool call. - function_response_event: The event with the function response details. - """ - span = trace.get_current_span() + Args: + tool: The tool that was called. + args: The arguments to the tool call. + function_response_event: The event with the function response details. + """ + span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) - span.set_attribute(GEN_AI_TOOL_NAME, tool.name) + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) + span.set_attribute(GEN_AI_TOOL_NAME, tool.name) - # e.g. FunctionTool - span.set_attribute(GEN_AI_TOOL_TYPE, tool.__class__.__name__) + # e.g. FunctionTool + span.set_attribute(GEN_AI_TOOL_TYPE, tool.__class__.__name__) - # Setting empty llm request and response (as UI expect these) while not - # applicable for tool_response. - span.set_attribute("gcp.vertex.agent.llm_request", "{}") - span.set_attribute("gcp.vertex.agent.llm_response", "{}") + # Setting empty llm request and response (as UI expect these) while not + # applicable for tool_response. + span.set_attribute("gcp.vertex.agent.llm_request", "{}") + span.set_attribute("gcp.vertex.agent.llm_response", "{}") - if _should_add_request_response_to_spans(): - span.set_attribute( - "gcp.vertex.agent.tool_call_args", - _safe_json_serialize(args), - ) - else: - span.set_attribute("gcp.vertex.agent.tool_call_args", {}) - - # Tracing tool response - tool_call_id = "" - tool_response = "" - if ( - function_response_event is not None - and function_response_event.content is not None - and function_response_event.content.parts - ): - response_parts = function_response_event.content.parts - function_response = response_parts[0].function_response - if function_response is not None: - if function_response.id is not None: - tool_call_id = function_response.id - if function_response.response is not None: - tool_response = function_response.response - - span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) - - if not isinstance(tool_response, dict): - tool_response = {"result": tool_response} - if function_response_event is not None: - span.set_attribute("gcp.vertex.agent.event_id", function_response_event.id) - if _should_add_request_response_to_spans(): - span.set_attribute( - "gcp.vertex.agent.tool_response", - _safe_json_serialize(tool_response), - ) - else: - span.set_attribute("gcp.vertex.agent.tool_response", {}) + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.tool_call_args", + _safe_json_serialize(args), + ) + else: + span.set_attribute("gcp.vertex.agent.tool_call_args", {}) + + # Tracing tool response + tool_call_id = "" + tool_response = "" + if ( + function_response_event is not None + and function_response_event.content is not None + and function_response_event.content.parts + ): + response_parts = function_response_event.content.parts + function_response = response_parts[0].function_response + if function_response is not None: + if function_response.id is not None: + tool_call_id = function_response.id + if function_response.response is not None: + tool_response = function_response.response + + span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) + + if not isinstance(tool_response, dict): + tool_response = {"result": tool_response} + if function_response_event is not None: + span.set_attribute("gcp.vertex.agent.event_id", function_response_event.id) + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.tool_response", + _safe_json_serialize(tool_response), + ) + else: + span.set_attribute("gcp.vertex.agent.tool_response", {}) def trace_merged_tool_calls( response_event_id: str, function_response_event: Event, ): - """Traces merged tool call events. + """Traces merged tool call events. - Calling this function is not needed for telemetry purposes. This is provided - for preventing /debug/trace requests (typically sent by web UI). + Calling this function is not needed for telemetry purposes. This is provided + for preventing /debug/trace requests (typically sent by web UI). - Args: - response_event_id: The ID of the response event. - function_response_event: The merged response event. - """ + Args: + response_event_id: The ID of the response event. + function_response_event: The merged response event. + """ - span = trace.get_current_span() + span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") - span.set_attribute(GEN_AI_TOOL_NAME, "(merged tools)") - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)") - span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") + span.set_attribute(GEN_AI_TOOL_NAME, "(merged tools)") + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)") + span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) - # TODO(b/441461932): See if these are still necessary - span.set_attribute("gcp.vertex.agent.tool_call_args", "N/A") - span.set_attribute("gcp.vertex.agent.event_id", response_event_id) - try: - function_response_event_json = function_response_event.model_dumps_json( - exclude_none=True - ) - except Exception: # pylint: disable=broad-exception-caught - function_response_event_json = "" + # TODO(b/441461932): See if these are still necessary + span.set_attribute("gcp.vertex.agent.tool_call_args", "N/A") + span.set_attribute("gcp.vertex.agent.event_id", response_event_id) + try: + function_response_event_json = function_response_event.model_dumps_json( + exclude_none=True + ) + except Exception: # pylint: disable=broad-exception-caught + function_response_event_json = "" - if _should_add_request_response_to_spans(): - span.set_attribute( - "gcp.vertex.agent.tool_response", - function_response_event_json, - ) - else: - span.set_attribute("gcp.vertex.agent.tool_response", {}) - # Setting empty llm request and response (as UI expect these) while not - # applicable for tool_response. - span.set_attribute("gcp.vertex.agent.llm_request", "{}") + if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.llm_response", - "{}", + "gcp.vertex.agent.tool_response", + function_response_event_json, ) + else: + span.set_attribute("gcp.vertex.agent.tool_response", {}) + # Setting empty llm request and response (as UI expect these) while not + # applicable for tool_response. + span.set_attribute("gcp.vertex.agent.llm_request", "{}") + span.set_attribute( + "gcp.vertex.agent.llm_response", + "{}", + ) def trace_call_llm( @@ -235,80 +235,82 @@ def trace_call_llm( llm_request: LlmRequest, llm_response: LlmResponse, ): - """Traces a call to the LLM. - - This function records details about the LLM request and response as - attributes on the current OpenTelemetry span. - - Args: - invocation_context: The invocation context for the current agent run. - event_id: The ID of the event. - llm_request: The LLM request object. - llm_response: The LLM response object. - """ - span = trace.get_current_span() - # Special standard Open Telemetry GenaI attributes that indicate - # that this is a span related to a Generative AI system. - span.set_attribute("gen_ai.system", "gcp.vertex.agent") - span.set_attribute("gen_ai.request.model", llm_request.model) + """Traces a call to the LLM. + + This function records details about the LLM request and response as + attributes on the current OpenTelemetry span. + + Args: + invocation_context: The invocation context for the current agent run. + event_id: The ID of the event. + llm_request: The LLM request object. + llm_response: The LLM response object. + """ + span = trace.get_current_span() + # Special standard Open Telemetry GenaI attributes that indicate + # that this is a span related to a Generative AI system. + span.set_attribute("gen_ai.system", "gcp.vertex.agent") + span.set_attribute("gen_ai.request.model", llm_request.model) + span.set_attribute( + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id + ) + span.set_attribute( + "gcp.vertex.agent.session_id", invocation_context.session.id + ) + span.set_attribute("gcp.vertex.agent.event_id", event_id) + # Consider removing once GenAI SDK provides a way to record this info. + if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.invocation_id", invocation_context.invocation_id + "gcp.vertex.agent.llm_request", + _safe_json_serialize(_build_llm_request_for_trace(llm_request)), ) - span.set_attribute("gcp.vertex.agent.session_id", invocation_context.session.id) - span.set_attribute("gcp.vertex.agent.event_id", event_id) - # Consider removing once GenAI SDK provides a way to record this info. - if _should_add_request_response_to_spans(): - span.set_attribute( - "gcp.vertex.agent.llm_request", - _safe_json_serialize(_build_llm_request_for_trace(llm_request)), - ) - else: - span.set_attribute("gcp.vertex.agent.llm_request", {}) - # Consider removing once GenAI SDK provides a way to record this info. - if llm_request.config: - if llm_request.config.top_p: - span.set_attribute( - "gen_ai.request.top_p", - llm_request.config.top_p, - ) - if llm_request.config.max_output_tokens: - span.set_attribute( - "gen_ai.request.max_tokens", - llm_request.config.max_output_tokens, - ) + else: + span.set_attribute("gcp.vertex.agent.llm_request", {}) + # Consider removing once GenAI SDK provides a way to record this info. + if llm_request.config: + if llm_request.config.top_p: + span.set_attribute( + "gen_ai.request.top_p", + llm_request.config.top_p, + ) + if llm_request.config.max_output_tokens: + span.set_attribute( + "gen_ai.request.max_tokens", + llm_request.config.max_output_tokens, + ) + + try: + llm_response_json = llm_response.model_dump_json(exclude_none=True) + except Exception: # pylint: disable=broad-exception-caught + llm_response_json = "" + + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.llm_response", + llm_response_json, + ) + else: + span.set_attribute("gcp.vertex.agent.llm_response", {}) + if llm_response.usage_metadata is not None: + span.set_attribute( + "gen_ai.usage.input_tokens", + llm_response.usage_metadata.prompt_token_count, + ) + if llm_response.usage_metadata.candidates_token_count is not None: + span.set_attribute( + "gen_ai.usage.output_tokens", + llm_response.usage_metadata.candidates_token_count, + ) + if llm_response.finish_reason: try: - llm_response_json = llm_response.model_dump_json(exclude_none=True) - except Exception: # pylint: disable=broad-exception-caught - llm_response_json = "" - - if _should_add_request_response_to_spans(): - span.set_attribute( - "gcp.vertex.agent.llm_response", - llm_response_json, - ) - else: - span.set_attribute("gcp.vertex.agent.llm_response", {}) - - if llm_response.usage_metadata is not None: - span.set_attribute( - "gen_ai.usage.input_tokens", - llm_response.usage_metadata.prompt_token_count, - ) - if llm_response.usage_metadata.candidates_token_count is not None: - span.set_attribute( - "gen_ai.usage.output_tokens", - llm_response.usage_metadata.candidates_token_count, - ) - if llm_response.finish_reason: - try: - finish_reason_str = llm_response.finish_reason.value.lower() - except AttributeError: - finish_reason_str = str(llm_response.finish_reason).lower() - span.set_attribute( - "gen_ai.response.finish_reasons", - [finish_reason_str], - ) + finish_reason_str = llm_response.finish_reason.value.lower() + except AttributeError: + finish_reason_str = str(llm_response.finish_reason).lower() + span.set_attribute( + "gen_ai.response.finish_reasons", + [finish_reason_str], + ) def trace_send_data( @@ -316,67 +318,67 @@ def trace_send_data( event_id: str, data: list[types.Content], ): - """Traces the sending of data to the agent. - - This function records details about the data sent to the agent as - attributes on the current OpenTelemetry span. - - Args: - invocation_context: The invocation context for the current agent run. - event_id: The ID of the event. - data: A list of content objects. - """ - span = trace.get_current_span() + """Traces the sending of data to the agent. + + This function records details about the data sent to the agent as + attributes on the current OpenTelemetry span. + + Args: + invocation_context: The invocation context for the current agent run. + event_id: The ID of the event. + data: A list of content objects. + """ + span = trace.get_current_span() + span.set_attribute( + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id + ) + span.set_attribute("gcp.vertex.agent.event_id", event_id) + # Once instrumentation is added to the GenAI SDK, consider whether this + # information still needs to be recorded by the Agent Development Kit. + if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.invocation_id", invocation_context.invocation_id + "gcp.vertex.agent.data", + _safe_json_serialize([ + types.Content(role=content.role, parts=content.parts).model_dump( + exclude_none=True + ) + for content in data + ]), ) - span.set_attribute("gcp.vertex.agent.event_id", event_id) - # Once instrumentation is added to the GenAI SDK, consider whether this - # information still needs to be recorded by the Agent Development Kit. - if _should_add_request_response_to_spans(): - span.set_attribute( - "gcp.vertex.agent.data", - _safe_json_serialize( - [ - types.Content(role=content.role, parts=content.parts).model_dump( - exclude_none=True - ) - for content in data - ] - ), - ) - else: - span.set_attribute("gcp.vertex.agent.data", {}) + else: + span.set_attribute("gcp.vertex.agent.data", {}) def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: - """Builds a dictionary representation of the LLM request for tracing. - - This function prepares a dictionary representation of the LlmRequest - object, suitable for inclusion in a trace. It excludes fields that cannot - be serialized (e.g., function pointers) and avoids sending bytes data. - - Args: - llm_request: The LlmRequest object. - - Returns: - A dictionary representation of the LLM request. - """ - # Some fields in LlmRequest are function pointers and cannot be serialized. - result = { - "model": llm_request.model, - "config": llm_request.config.model_dump( - exclude_none=True, exclude="response_schema" - ), - "contents": [], - } - # We do not want to send bytes data to the trace. - for content in llm_request.contents: - parts = [part for part in content.parts if not part.inline_data] - result["contents"].append( - types.Content(role=content.role, parts=parts).model_dump(exclude_none=True) + """Builds a dictionary representation of the LLM request for tracing. + + This function prepares a dictionary representation of the LlmRequest + object, suitable for inclusion in a trace. It excludes fields that cannot + be serialized (e.g., function pointers) and avoids sending bytes data. + + Args: + llm_request: The LlmRequest object. + + Returns: + A dictionary representation of the LLM request. + """ + # Some fields in LlmRequest are function pointers and cannot be serialized. + result = { + "model": llm_request.model, + "config": llm_request.config.model_dump( + exclude_none=True, exclude="response_schema" + ), + "contents": [], + } + # We do not want to send bytes data to the trace. + for content in llm_request.contents: + parts = [part for part in content.parts if not part.inline_data] + result["contents"].append( + types.Content(role=content.role, parts=parts).model_dump( + exclude_none=True ) - return result + ) + return result # Defaults to true for now to preserve backward compatibility. @@ -384,7 +386,7 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: # a deprecation of request/response content in spans by switching the default # to false. def _should_add_request_response_to_spans() -> bool: - disabled_via_env_var = os.getenv( - ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, "true" - ).lower() in ("false", "0") - return not disabled_via_env_var + disabled_via_env_var = os.getenv( + ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, "true" + ).lower() in ("false", "0") + return not disabled_via_env_var diff --git a/src/google/adk/utils/telemetry_utils.py b/src/google/adk/utils/telemetry_utils.py index 0388fc8cfd..5d303c323a 100644 --- a/src/google/adk/utils/telemetry_utils.py +++ b/src/google/adk/utils/telemetry_utils.py @@ -18,46 +18,47 @@ Please do not rely on the implementation details. """ -from .env_utils import is_env_enabled from typing import TYPE_CHECKING +from .env_utils import is_env_enabled + if TYPE_CHECKING: - from ..agents.base_agent import BaseAgent + from ..agents.base_agent import BaseAgent def is_telemetry_enabled(agent: "BaseAgent") -> bool: - """Check if telemetry is enabled for the given agent. - - By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. - - Args: - agent: The agent to check if telemetry is enabled for. - - Returns: - False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. - - Examples: - >>> os.environ['OTEL_SDK_DISABLED'] = 'true' - >>> is_telemetry_disabled(my_agent) - False - - >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 - >>> is_telemetry_disabled(my_agent) - False - - >>> my_agent.disable_telemetry = True - >>> is_telemetry_disabled(my_agent) - False - - >>> os.environ['OTEL_SDK_DISABLED'] = 0 - >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' - >>> my_agent.disable_telemetry = False - >>> is_telemetry_disabled(my_agent) - True - """ - telemetry_disabled = ( - is_env_enabled("OTEL_SDK_DISABLED") - or is_env_enabled("ADK_TELEMETRY_DISABLED") - or getattr(agent, "disable_telemetry", False) - ) - return not telemetry_disabled + """Check if telemetry is enabled for the given agent. + + By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. + + Args: + agent: The agent to check if telemetry is enabled for. + + Returns: + False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. + + Examples: + >>> os.environ['OTEL_SDK_DISABLED'] = 'true' + >>> is_telemetry_disabled(my_agent) + False + + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 + >>> is_telemetry_disabled(my_agent) + False + + >>> my_agent.disable_telemetry = True + >>> is_telemetry_disabled(my_agent) + False + + >>> os.environ['OTEL_SDK_DISABLED'] = 0 + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' + >>> my_agent.disable_telemetry = False + >>> is_telemetry_disabled(my_agent) + True + """ + telemetry_disabled = ( + is_env_enabled("OTEL_SDK_DISABLED") + or is_env_enabled("ADK_TELEMETRY_DISABLED") + or getattr(agent, "disable_telemetry", False) + ) + return not telemetry_disabled diff --git a/tests/integration/telemetry/test_telemetry_disable.py b/tests/integration/telemetry/test_telemetry_disable.py index 5f953362d0..d4ed5704dc 100644 --- a/tests/integration/telemetry/test_telemetry_disable.py +++ b/tests/integration/telemetry/test_telemetry_disable.py @@ -12,17 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( - InMemorySpanExporter, -) -import pytest - from google.adk.agents.llm_agent import Agent from google.adk.telemetry import tracing from google.adk.utils.context_utils import Aclosing from google.genai.types import Part +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +import pytest from tests.unittests.testing_utils import MockModel from tests.unittests.testing_utils import TestInMemoryRunner @@ -30,94 +27,96 @@ @pytest.fixture def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter: - tracer_provider = TracerProvider() - exporter = InMemorySpanExporter() - tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) - real_tracer = tracer_provider.get_tracer(__name__) + tracer_provider = TracerProvider() + exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = tracer_provider.get_tracer(__name__) - monkeypatch.setattr( - tracing.tracer, - "start_as_current_span", - real_tracer.start_as_current_span, - ) - return exporter + monkeypatch.setattr( + tracing.tracer, + "start_as_current_span", + real_tracer.start_as_current_span, + ) + return exporter @pytest.mark.asyncio async def test_telemetry_enabled_records_spans(monkeypatch, span_exporter): - monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) - monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) - agent = Agent( - name="test_agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=False, - ) - runner = TestInMemoryRunner(agent) + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass - spans = span_exporter.get_finished_spans() - assert spans + spans = span_exporter.get_finished_spans() + assert spans @pytest.mark.asyncio -async def test_adk_telemetry_disabled_env_var_disables(monkeypatch, span_exporter): - monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "true") - monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) +async def test_adk_telemetry_disabled_env_var_disables( + monkeypatch, span_exporter +): + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "true") + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) - agent = Agent( - name="test_agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=False, - ) - runner = TestInMemoryRunner(agent) + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass - spans = span_exporter.get_finished_spans() - assert not spans + spans = span_exporter.get_finished_spans() + assert not spans @pytest.mark.asyncio async def test_otel_sdk_env_var_disables_telemetry(monkeypatch, span_exporter): - monkeypatch.setenv("OTEL_SDK_DISABLED", "true") - monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + monkeypatch.setenv("OTEL_SDK_DISABLED", "true") + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) - agent = Agent( - name="test_agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=False, - ) - runner = TestInMemoryRunner(agent) + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass - spans = span_exporter.get_finished_spans() - assert not spans + spans = span_exporter.get_finished_spans() + assert not spans @pytest.mark.asyncio async def test_agent_flag_disables_telemetry(monkeypatch, span_exporter): - monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) - monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) - - agent = Agent( - name="test_agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=True, - ) - runner = TestInMemoryRunner(agent) - - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass - - spans = span_exporter.get_finished_spans() - assert not spans + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=True, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans diff --git a/tests/unittests/agents/test_gemini_context_cache_manager.py b/tests/unittests/agents/test_gemini_context_cache_manager.py index 0443843ae1..0575d4eaff 100644 --- a/tests/unittests/agents/test_gemini_context_cache_manager.py +++ b/tests/unittests/agents/test_gemini_context_cache_manager.py @@ -479,7 +479,9 @@ async def test_create_new_cache_with_proper_ttl(self): with patch.object( self.manager, "_generate_cache_fingerprint", return_value="test_fp" ): - await self.manager._create_gemini_cache(llm_request, cache_contents_count) + await self.manager._create_gemini_cache_with_optional_tracing( + llm_request, cache_contents_count + ) # Verify cache creation call includes TTL create_call = self.manager.genai_client.aio.caches.create.call_args diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 59b66bd622..ab2ad70041 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os +from unittest import mock from pytest import fixture from pytest import FixtureRequest @@ -20,19 +22,19 @@ from pytest import Metafunc _ENV_VARS = { - 'GOOGLE_API_KEY': 'fake_google_api_key', - 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project', - 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location', - 'ADK_ALLOW_WIP_FEATURES': 'true', + "GOOGLE_API_KEY": "fake_google_api_key", + "GOOGLE_CLOUD_PROJECT": "fake_google_cloud_project", + "GOOGLE_CLOUD_LOCATION": "fake_google_cloud_location", + "ADK_ALLOW_WIP_FEATURES": "true", } ENV_SETUPS = { - 'GOOGLE_AI': { - 'GOOGLE_GENAI_USE_VERTEXAI': '0', + "GOOGLE_AI": { + "GOOGLE_GENAI_USE_VERTEXAI": "0", **_ENV_VARS, }, - 'VERTEX': { - 'GOOGLE_GENAI_USE_VERTEXAI': '1', + "VERTEX": { + "GOOGLE_GENAI_USE_VERTEXAI": "1", **_ENV_VARS, }, } @@ -96,8 +98,19 @@ def pytest_generate_tests(metafunc: Metafunc): def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool: - if hasattr(metafunc.function, 'pytestmark'): + if hasattr(metafunc.function, "pytestmark"): for mark in metafunc.function.pytestmark: - if mark.name == 'parametrize' and mark.args[0] == mark_name: + if mark.name == "parametrize" and mark.args[0] == mark_name: return True return False + + +@fixture +def context_manager_with_span(span=None): + span = span or mock.MagicMock() + + @contextlib.contextmanager + def _cm(): + yield span + + return _cm() diff --git a/tests/unittests/telemetry/test_telemetry_disable_agent.py b/tests/unittests/telemetry/test_telemetry_disable_agent.py index 4a91778511..5e8e87b89f 100644 --- a/tests/unittests/telemetry/test_telemetry_disable_agent.py +++ b/tests/unittests/telemetry/test_telemetry_disable_agent.py @@ -12,70 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest import mock from google.adk.agents.llm_agent import Agent from google.adk.telemetry import tracing from google.adk.utils.context_utils import Aclosing from google.genai.types import Part +import pytest - -from ..testing_utils import MockModel, TestInMemoryRunner +from ..testing_utils import MockModel +from ..testing_utils import TestInMemoryRunner @pytest.mark.asyncio async def test_disable_telemetry_prevents_span_creation(monkeypatch): - monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) - monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) - span = mock.MagicMock() - context_manager = mock.MagicMock() - context_manager.__enter__.return_value = span - context_manager.__exit__.return_value = False + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False - mock_start = mock.Mock(return_value=context_manager) - monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) - agent = Agent( - name="agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=True, - ) + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=True, + ) - runner = TestInMemoryRunner(agent) + runner = TestInMemoryRunner(agent) - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass - assert mock_start.call_count == 0 + assert mock_start.call_count == 0 @pytest.mark.asyncio async def test_enabled_telemetry_causes_span_creation(monkeypatch): - monkeypatch.setenv("OTEL_SDK_DISABLED", "false") - monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") - span = mock.MagicMock() - context_manager = mock.MagicMock() - context_manager.__enter__.return_value = span - context_manager.__exit__.return_value = False + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False - mock_start = mock.Mock(return_value=context_manager) - monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) - agent = Agent( - name="agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=False, - ) + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) - runner = TestInMemoryRunner(agent) + runner = TestInMemoryRunner(agent) - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass - assert mock_start.call_count > 0 + assert mock_start.call_count > 0 @pytest.mark.asyncio @@ -89,29 +89,31 @@ async def test_enabled_telemetry_causes_span_creation(monkeypatch): ], ) async def test_env_flag_disables_telemetry(monkeypatch, env_var, env_value): - monkeypatch.setenv(env_var, env_value) - monkeypatch.delenv( - "ADK_TELEMETRY_DISABLED" if env_var == "OTEL_SDK_DISABLED" else "OTEL_SDK_DISABLED", - raising=False, - ) - span = mock.MagicMock() - context_manager = mock.MagicMock() - context_manager.__enter__.return_value = span - context_manager.__exit__.return_value = False - - mock_start = mock.Mock(return_value=context_manager) - monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) - - agent = Agent( - name="agent", - model=MockModel.create(responses=[Part.from_text(text="ok")]), - disable_telemetry=False, - ) - - runner = TestInMemoryRunner(agent) - - async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: - async for _ in agen: - pass - - assert mock_start.call_count == 0 + monkeypatch.setenv(env_var, env_value) + monkeypatch.delenv( + "ADK_TELEMETRY_DISABLED" + if env_var == "OTEL_SDK_DISABLED" + else "OTEL_SDK_DISABLED", + raising=False, + ) + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count == 0 diff --git a/tests/unittests/telemetry/test_telemetry_disable_google_llm.py b/tests/unittests/telemetry/test_telemetry_disable_google_llm.py index df4cb16752..f63cbdaaf4 100644 --- a/tests/unittests/telemetry/test_telemetry_disable_google_llm.py +++ b/tests/unittests/telemetry/test_telemetry_disable_google_llm.py @@ -13,102 +13,110 @@ # limitations under the License. from unittest import mock -import pytest -from google.adk.models.google_llm import Gemini -from google.adk.models import llm_response as llm_response_mod from google.adk.models import gemini_context_cache_manager as cache_mod +from google.adk.models import llm_response as llm_response_mod +from google.adk.models.google_llm import Gemini +import pytest @pytest.mark.asyncio -async def test_disable_google_llm_telemetry(monkeypatch, context_manager_with_span): - monkeypatch.setenv("OTEL_SDK_DISABLED", "false") - monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") - start_span = mock.Mock(return_value=context_manager_with_span) - monkeypatch.setattr( - "google.adk.telemetry.tracing.tracer.start_as_current_span", - start_span, - ) - - gemini = Gemini(disable_telemetry=True) - - # Avoid real Client construction - fake_client = mock.MagicMock() - fake_client.vertexai = False - fake_client.aio.models.generate_content = mock.AsyncMock( - return_value=mock.MagicMock() - ) - gemini.__dict__["api_client"] = fake_client - - # Prevent cache validation code running (the bit that touches expire_time) - monkeypatch.setattr( - cache_mod.GeminiContextCacheManager, - "handle_context_caching", - mock.AsyncMock(return_value=None), - ) - - req = mock.MagicMock() - req.cache_config = object() # force the cache path - req.model = "gemini-2.5-flash" - req.contents = [] - req.config = mock.MagicMock() - req.config.tools = None - req.config.system_instruction = "" - req.config.model_dump = mock.Mock(return_value={}) - req.config.http_options = None - - monkeypatch.setattr( - llm_response_mod.LlmResponse, "create", mock.Mock(return_value=mock.MagicMock()) - ) - - async for _ in gemini.generate_content_async(req, stream=False): - break - - assert start_span.call_count == 0 +async def test_disable_google_llm_telemetry( + monkeypatch, context_manager_with_span +): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + start_span = mock.Mock(return_value=context_manager_with_span) + monkeypatch.setattr( + "google.adk.telemetry.tracing.tracer.start_as_current_span", + start_span, + ) + + gemini = Gemini(disable_telemetry=True) + + # Avoid real Client construction + fake_client = mock.MagicMock() + fake_client.vertexai = False + fake_client.aio.models.generate_content = mock.AsyncMock( + return_value=mock.MagicMock() + ) + gemini.__dict__["api_client"] = fake_client + + # Prevent cache validation code running (the bit that touches expire_time) + monkeypatch.setattr( + cache_mod.GeminiContextCacheManager, + "handle_context_caching", + mock.AsyncMock(return_value=None), + ) + + req = mock.MagicMock() + req.cache_config = object() # force the cache path + req.model = "gemini-2.5-flash" + req.contents = [] + req.config = mock.MagicMock() + req.config.tools = None + req.config.system_instruction = "" + req.config.model_dump = mock.Mock(return_value={}) + req.config.http_options = None + + monkeypatch.setattr( + llm_response_mod.LlmResponse, + "create", + mock.Mock(return_value=mock.MagicMock()), + ) + + async for _ in gemini.generate_content_async(req, stream=False): + break + + assert start_span.call_count == 0 @pytest.mark.asyncio -async def test_enable_google_llm_telemetry(monkeypatch, context_manager_with_span): - monkeypatch.setenv("OTEL_SDK_DISABLED", "false") - monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") - start_span = mock.Mock(return_value=context_manager_with_span) - monkeypatch.setattr( - "google.adk.telemetry.tracing.tracer.start_as_current_span", - start_span, - ) - - gemini = Gemini(disable_telemetry=False) - - # Avoid real Client construction - fake_client = mock.MagicMock() - fake_client.vertexai = False - fake_client.aio.models.generate_content = mock.AsyncMock( - return_value=mock.MagicMock() - ) - gemini.__dict__["api_client"] = fake_client - - # Prevent cache validation code running (the bit that touches expire_time) - monkeypatch.setattr( - cache_mod.GeminiContextCacheManager, - "handle_context_caching", - mock.AsyncMock(return_value=None), - ) - - req = mock.MagicMock() - req.cache_config = object() # force the cache path - req.model = "gemini-2.5-flash" - req.contents = [] - req.config = mock.MagicMock() - req.config.tools = None - req.config.system_instruction = "" - req.config.model_dump = mock.Mock(return_value={}) - req.config.http_options = None - - monkeypatch.setattr( - llm_response_mod.LlmResponse, "create", mock.Mock(return_value=mock.MagicMock()) - ) - - async for _ in gemini.generate_content_async(req, stream=False): - break - - assert start_span.call_count > 0 +async def test_enable_google_llm_telemetry( + monkeypatch, context_manager_with_span +): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + start_span = mock.Mock(return_value=context_manager_with_span) + monkeypatch.setattr( + "google.adk.telemetry.tracing.tracer.start_as_current_span", + start_span, + ) + + gemini = Gemini(disable_telemetry=False) + + # Avoid real Client construction + fake_client = mock.MagicMock() + fake_client.vertexai = False + fake_client.aio.models.generate_content = mock.AsyncMock( + return_value=mock.MagicMock() + ) + gemini.__dict__["api_client"] = fake_client + + # Prevent cache validation code running (the bit that touches expire_time) + monkeypatch.setattr( + cache_mod.GeminiContextCacheManager, + "handle_context_caching", + mock.AsyncMock(return_value=None), + ) + + req = mock.MagicMock() + req.cache_config = object() # force the cache path + req.model = "gemini-2.5-flash" + req.contents = [] + req.config = mock.MagicMock() + req.config.tools = None + req.config.system_instruction = "" + req.config.model_dump = mock.Mock(return_value={}) + req.config.http_options = None + + monkeypatch.setattr( + llm_response_mod.LlmResponse, + "create", + mock.Mock(return_value=mock.MagicMock()), + ) + + async for _ in gemini.generate_content_async(req, stream=False): + break + + assert start_span.call_count > 0 From ba0ce7a5796d5741e7a879366a23bd5f0ab445b7 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sat, 27 Dec 2025 23:30:54 +0000 Subject: [PATCH 05/24] fixed examples in docstring for is_telemetry_enabled which previously incorrectly said disabled --- src/google/adk/utils/telemetry_utils.py | 58 ++++++++++++------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/google/adk/utils/telemetry_utils.py b/src/google/adk/utils/telemetry_utils.py index 5d303c323a..99af4e0d73 100644 --- a/src/google/adk/utils/telemetry_utils.py +++ b/src/google/adk/utils/telemetry_utils.py @@ -23,42 +23,42 @@ from .env_utils import is_env_enabled if TYPE_CHECKING: - from ..agents.base_agent import BaseAgent + from ..agents.base_agent import BaseAgent def is_telemetry_enabled(agent: "BaseAgent") -> bool: - """Check if telemetry is enabled for the given agent. + """Check if telemetry is enabled for the given agent. - By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. + By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. - Args: - agent: The agent to check if telemetry is enabled for. + Args: + agent: The agent to check if telemetry is enabled for. - Returns: - False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. + Returns: + False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. - Examples: - >>> os.environ['OTEL_SDK_DISABLED'] = 'true' - >>> is_telemetry_disabled(my_agent) - False + Examples: + >>> os.environ['OTEL_SDK_DISABLED'] = 'true' + >>> is_telemetry_enabled(my_agent) + True - >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 - >>> is_telemetry_disabled(my_agent) - False + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 + >>> is_telemetry_enabled(my_agent) + True - >>> my_agent.disable_telemetry = True - >>> is_telemetry_disabled(my_agent) - False + >>> my_agent.disable_telemetry = True + >>> is_telemetry_enabled(my_agent) + True - >>> os.environ['OTEL_SDK_DISABLED'] = 0 - >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' - >>> my_agent.disable_telemetry = False - >>> is_telemetry_disabled(my_agent) - True - """ - telemetry_disabled = ( - is_env_enabled("OTEL_SDK_DISABLED") - or is_env_enabled("ADK_TELEMETRY_DISABLED") - or getattr(agent, "disable_telemetry", False) - ) - return not telemetry_disabled + >>> os.environ['OTEL_SDK_DISABLED'] = 1 + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' + >>> my_agent.disable_telemetry = False + >>> is_telemetry_enabled(my_agent) + False + """ + telemetry_disabled = ( + is_env_enabled("OTEL_SDK_DISABLED") + or is_env_enabled("ADK_TELEMETRY_DISABLED") + or getattr(agent, "disable_telemetry", False) + ) + return not telemetry_disabled From 849b281da1c4029a876781fcef894adce716cf68 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sat, 27 Dec 2025 23:33:41 +0000 Subject: [PATCH 06/24] moved docstring to 1 line for conciseness --- src/google/adk/agents/base_agent.py | 1137 +++++++++++++-------------- 1 file changed, 562 insertions(+), 575 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 5207897c54..a0a0560249 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -49,7 +49,7 @@ from .callback_context import CallbackContext if TYPE_CHECKING: - from .invocation_context import InvocationContext + from .invocation_context import InvocationContext logger = logging.getLogger("google_adk." + __name__) @@ -73,27 +73,27 @@ @experimental class BaseAgentState(BaseModel): - """Base class for all agent states.""" + """Base class for all agent states.""" - model_config = ConfigDict( - extra="forbid", - ) + model_config = ConfigDict( + extra="forbid", + ) AgentState = TypeVar("AgentState", bound=BaseAgentState) class BaseAgent(BaseModel): - """Base class for all agents in Agent Development Kit.""" + """Base class for all agents in Agent Development Kit.""" - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - """The pydantic model config.""" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + """The pydantic model config.""" - config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig - """The config type for this agent. + config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig + """The config type for this agent. Sub-classes should override this to specify their own config type. @@ -108,22 +108,22 @@ class MyAgent(BaseAgent): ``` """ - name: str - """The agent's name. + name: str + """The agent's name. Agent name must be a Python identifier and unique within the agent tree. Agent name cannot be "user", since it's reserved for end-user's input. """ - description: str = "" - """Description about the agent's capability. + description: str = "" + """Description about the agent's capability. The model uses this to determine whether to delegate control to the agent. One-line description is enough and preferred. """ - parent_agent: Optional[BaseAgent] = Field(default=None, init=False) - """The parent agent of this agent. + parent_agent: Optional[BaseAgent] = Field(default=None, init=False) + """The parent agent of this agent. Note that an agent can ONLY be added as sub-agent once. @@ -131,13 +131,12 @@ class MyAgent(BaseAgent): instances with identical config, but with different name and add them to the agent tree. """ - sub_agents: list[BaseAgent] = Field(default_factory=list) - """The sub-agents of this agent.""" - disable_telemetry: bool = False - """Whether to disable telemetry for this agent. - """ - before_agent_callback: Optional[BeforeAgentCallback] = None - """Callback or list of callbacks to be invoked before the agent run. + sub_agents: list[BaseAgent] = Field(default_factory=list) + """The sub-agents of this agent.""" + disable_telemetry: bool = False + """Whether to disable telemetry for this agent.""" + before_agent_callback: Optional[BeforeAgentCallback] = None + """Callback or list of callbacks to be invoked before the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -150,8 +149,8 @@ class MyAgent(BaseAgent): When the content is present, the agent run will be skipped and the provided content will be returned to user. """ - after_agent_callback: Optional[AfterAgentCallback] = None - """Callback or list of callbacks to be invoked after the agent run. + after_agent_callback: Optional[AfterAgentCallback] = None + """Callback or list of callbacks to be invoked after the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -165,562 +164,550 @@ class MyAgent(BaseAgent): will be appended to event history as an additional agent response. """ - def _load_agent_state( - self, - ctx: InvocationContext, - state_type: Type[AgentState], - ) -> Optional[AgentState]: - """Loads the agent state from the invocation context. - - Args: - ctx: The invocation context. - state_type: The type of the agent state. - - Returns: - The current state if exists; otherwise, None. - """ - if ctx.agent_states is None or self.name not in ctx.agent_states: - return None - else: - return state_type.model_validate(ctx.agent_states.get(self.name)) - - def _create_agent_state_event( - self, - ctx: InvocationContext, - ) -> Event: - """Returns an event with current agent state set in the invocation context. - - Args: - ctx: The invocation context. - - Returns: - An event with the current agent state set in the invocation context. - """ - event_actions = EventActions() - if (agent_state := ctx.agent_states.get(self.name)) is not None: - event_actions.agent_state = agent_state - if ctx.end_of_agents.get(self.name): - event_actions.end_of_agent = True - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=event_actions, - ) - - def clone( - self: SelfAgent, update: Mapping[str, Any] | None = None - ) -> SelfAgent: - """Creates a copy of this agent instance. - - Args: - update: Optional mapping of new values for the fields of the cloned agent. - The keys of the mapping are the names of the fields to be updated, and - the values are the new values for those fields. - For example: {"name": "cloned_agent"} - - Returns: - A new agent instance with identical configuration as the original - agent except for the fields specified in the update. - """ - if update is not None and "parent_agent" in update: - raise ValueError( - "Cannot update `parent_agent` field in clone. Parent agent is set" - " only when the parent agent is instantiated with the sub-agents." - ) - - # Only allow updating fields that are defined in the agent class. - allowed_fields = set(self.__class__.model_fields) - if update is not None: - invalid_fields = set(update) - allowed_fields - if invalid_fields: - raise ValueError( - f"Cannot update nonexistent fields in {self.__class__.__name__}:" - f" {invalid_fields}" + def _load_agent_state( + self, + ctx: InvocationContext, + state_type: Type[AgentState], + ) -> Optional[AgentState]: + """Loads the agent state from the invocation context. + + Args: + ctx: The invocation context. + state_type: The type of the agent state. + + Returns: + The current state if exists; otherwise, None. + """ + if ctx.agent_states is None or self.name not in ctx.agent_states: + return None + else: + return state_type.model_validate(ctx.agent_states.get(self.name)) + + def _create_agent_state_event( + self, + ctx: InvocationContext, + ) -> Event: + """Returns an event with current agent state set in the invocation context. + + Args: + ctx: The invocation context. + + Returns: + An event with the current agent state set in the invocation context. + """ + event_actions = EventActions() + if (agent_state := ctx.agent_states.get(self.name)) is not None: + event_actions.agent_state = agent_state + if ctx.end_of_agents.get(self.name): + event_actions.end_of_agent = True + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=event_actions, ) - cloned_agent = self.model_copy(update=update) - - # If any field is stored as list and not provided in the update, need to - # shallow copy it for the cloned agent to avoid sharing the same list object - # with the original agent. - for field_name in cloned_agent.__class__.model_fields: - if field_name == "sub_agents": - continue - if update is not None and field_name in update: - continue - field = getattr(cloned_agent, field_name) - if isinstance(field, list): - setattr(cloned_agent, field_name, field.copy()) - - if update is None or "sub_agents" not in update: - # If `sub_agents` is not provided in the update, need to recursively clone - # the sub-agents to avoid sharing the sub-agents with the original agent. - cloned_agent.sub_agents = [] - for sub_agent in self.sub_agents: - cloned_sub_agent = sub_agent.clone() - cloned_sub_agent.parent_agent = cloned_agent - cloned_agent.sub_agents.append(cloned_sub_agent) - else: - for sub_agent in cloned_agent.sub_agents: - sub_agent.parent_agent = cloned_agent - - # Remove the parent agent from the cloned agent to avoid sharing the parent - # agent with the cloned agent. - cloned_agent.parent_agent = None - return cloned_agent - - @final - async def run_async( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via text-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - ctx = self._create_invocation_context(parent_context) - if is_telemetry_enabled(self): - with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: - tracing.trace_agent_invocation(span, self, ctx) - async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") - ) as agen: - async for event in agen: - yield event - else: - async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") - ) as agen: - async for event in agen: - yield event - - @final - async def run_live( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via video/audio-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - ctx = self._create_invocation_context(parent_context) - if is_telemetry_enabled(self): - with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: - tracing.trace_agent_invocation(span, self, ctx) - async for event in self._run_callbacks_and_impl(ctx, mode="live"): - yield event - else: - async for event in self._run_callbacks_and_impl(ctx, mode="live"): - yield event - - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via text-based conversation. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f"_run_async_impl for {type(self)} is not implemented." - ) - yield # AsyncGenerator requires having at least one yield statement + def clone(self: SelfAgent, update: Mapping[str, Any] | None = None) -> SelfAgent: + """Creates a copy of this agent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + The keys of the mapping are the names of the fields to be updated, and + the values are the new values for those fields. + For example: {"name": "cloned_agent"} + + Returns: + A new agent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + if update is not None and "parent_agent" in update: + raise ValueError( + "Cannot update `parent_agent` field in clone. Parent agent is set" + " only when the parent agent is instantiated with the sub-agents." + ) + + # Only allow updating fields that are defined in the agent class. + allowed_fields = set(self.__class__.model_fields) + if update is not None: + invalid_fields = set(update) - allowed_fields + if invalid_fields: + raise ValueError( + f"Cannot update nonexistent fields in {self.__class__.__name__}:" + f" {invalid_fields}" + ) + + cloned_agent = self.model_copy(update=update) + + # If any field is stored as list and not provided in the update, need to + # shallow copy it for the cloned agent to avoid sharing the same list object + # with the original agent. + for field_name in cloned_agent.__class__.model_fields: + if field_name == "sub_agents": + continue + if update is not None and field_name in update: + continue + field = getattr(cloned_agent, field_name) + if isinstance(field, list): + setattr(cloned_agent, field_name, field.copy()) + + if update is None or "sub_agents" not in update: + # If `sub_agents` is not provided in the update, need to recursively clone + # the sub-agents to avoid sharing the sub-agents with the original agent. + cloned_agent.sub_agents = [] + for sub_agent in self.sub_agents: + cloned_sub_agent = sub_agent.clone() + cloned_sub_agent.parent_agent = cloned_agent + cloned_agent.sub_agents.append(cloned_sub_agent) + else: + for sub_agent in cloned_agent.sub_agents: + sub_agent.parent_agent = cloned_agent + + # Remove the parent agent from the cloned agent to avoid sharing the parent + # agent with the cloned agent. + cloned_agent.parent_agent = None + return cloned_agent + + @final + async def run_async( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via text-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + if is_telemetry_enabled(self): + with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + tracing.trace_agent_invocation(span, self, ctx) + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + else: + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + + @final + async def run_live( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via video/audio-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + if is_telemetry_enabled(self): + with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + tracing.trace_agent_invocation(span, self, ctx) + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + else: + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via text-based conversation. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_async_impl for {type(self)} is not implemented." + ) + yield # AsyncGenerator requires having at least one yield statement - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via video/audio-based conversation. + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via video/audio-based conversation. - Args: - ctx: InvocationContext, the invocation context for this agent. + Args: + ctx: InvocationContext, the invocation context for this agent. - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f"_run_live_impl for {type(self)} is not implemented." - ) - yield # AsyncGenerator requires having at least one yield statement - - async def _run_callbacks_and_impl( - self, ctx: InvocationContext, mode: str = "async" - ) -> AsyncGenerator[Event, None]: - """Runs the before and after agent callbacks around the core agent logic. - Args: - ctx: InvocationContext, the invocation context for this agent. - mode: str, either 'async' or 'live', indicating which core agent logic to run. - Yields: - Event: the events generated by the agent. - """ - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - if mode.lower() == "async": - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event - elif mode.lower() == "live": - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event - else: - raise ValueError( - f"Invalid mode: {mode}. Must be either 'async' or 'live'." - ) - - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event - - @property - def root_agent(self) -> BaseAgent: - """Gets the root agent of this agent.""" - root_agent = self - while root_agent.parent_agent is not None: - root_agent = root_agent.parent_agent - return root_agent - - def find_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent and its descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - if self.name == name: - return self - return self.find_sub_agent(name) - - def find_sub_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent's descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - for sub_agent in self.sub_agents: - if result := sub_agent.find_agent(name): - return result - return None - - def _create_invocation_context( - self, parent_context: InvocationContext - ) -> InvocationContext: - """Creates a new invocation context for this agent.""" - invocation_context = parent_context.model_copy(update={"agent": self}) - return invocation_context - - @property - def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_agent_callback: - return [] - if isinstance(self.before_agent_callback, list): - return self.before_agent_callback - return [self.before_agent_callback] - - @property - def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_agent_callback: - return [] - if isinstance(self.after_agent_callback, list): - return self.after_agent_callback - return [self.after_agent_callback] - - async def _handle_before_agent_callback( - self, ctx: InvocationContext - ) -> Optional[Event]: - """Runs the before_agent_callback if it exists. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - callback_context = CallbackContext(ctx) - - # Run callbacks from the plugins. - before_agent_callback_content = ( - await ctx.plugin_manager.run_before_agent_callback( - agent=self, callback_context=callback_context + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_live_impl for {type(self)} is not implemented." ) - ) + yield # AsyncGenerator requires having at least one yield statement + + async def _run_callbacks_and_impl( + self, ctx: InvocationContext, mode: str = "async" + ) -> AsyncGenerator[Event, None]: + """Runs the before and after agent callbacks around the core agent logic. + Args: + ctx: InvocationContext, the invocation context for this agent. + mode: str, either 'async' or 'live', indicating which core agent logic to run. + Yields: + Event: the events generated by the agent. + """ + if event := await self._handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return + if mode.lower() == "async": + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + elif mode.lower() == "live": + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + else: + raise ValueError(f"Invalid mode: {mode}. Must be either 'async' or 'live'.") + + if ctx.end_invocation: + return + + if event := await self._handle_after_agent_callback(ctx): + yield event - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if ( - not before_agent_callback_content - and self.canonical_before_agent_callbacks - ): - for callback in self.canonical_before_agent_callbacks: - before_agent_callback_content = callback( - callback_context=callback_context + @property + def root_agent(self) -> BaseAgent: + """Gets the root agent of this agent.""" + root_agent = self + while root_agent.parent_agent is not None: + root_agent = root_agent.parent_agent + return root_agent + + def find_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent and its descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + if self.name == name: + return self + return self.find_sub_agent(name) + + def find_sub_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent's descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + for sub_agent in self.sub_agents: + if result := sub_agent.find_agent(name): + return result + return None + + def _create_invocation_context( + self, parent_context: InvocationContext + ) -> InvocationContext: + """Creates a new invocation context for this agent.""" + invocation_context = parent_context.model_copy(update={"agent": self}) + return invocation_context + + @property + def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_agent_callback: + return [] + if isinstance(self.before_agent_callback, list): + return self.before_agent_callback + return [self.before_agent_callback] + + @property + def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_agent_callback: + return [] + if isinstance(self.after_agent_callback, list): + return self.after_agent_callback + return [self.after_agent_callback] + + async def _handle_before_agent_callback( + self, ctx: InvocationContext + ) -> Optional[Event]: + """Runs the before_agent_callback if it exists. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + callback_context = CallbackContext(ctx) + + # Run callbacks from the plugins. + before_agent_callback_content = ( + await ctx.plugin_manager.run_before_agent_callback( + agent=self, callback_context=callback_context + ) ) - if inspect.isawaitable(before_agent_callback_content): - before_agent_callback_content = await before_agent_callback_content + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not before_agent_callback_content and self.canonical_before_agent_callbacks: + for callback in self.canonical_before_agent_callbacks: + before_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(before_agent_callback_content): + before_agent_callback_content = await before_agent_callback_content + if before_agent_callback_content: + break + + # Process the override content if exists, and further process the state + # change if exists. if before_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. - if before_agent_callback_content: - ret_event = Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=before_agent_callback_content, - actions=callback_context._event_actions, - ) - ctx.end_invocation = True - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=callback_context._event_actions, - ) - - return None - - async def _handle_after_agent_callback( - self, invocation_context: InvocationContext - ) -> Optional[Event]: - """Runs the after_agent_callback if it exists. - - Args: - invocation_context: InvocationContext, the invocation context for this - agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - - callback_context = CallbackContext(invocation_context) - - # Run callbacks from the plugins. - after_agent_callback_content = ( - await invocation_context.plugin_manager.run_after_agent_callback( - agent=self, callback_context=callback_context + ret_event = Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=before_agent_callback_content, + actions=callback_context._event_actions, + ) + ctx.end_invocation = True + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=callback_context._event_actions, + ) + + return None + + async def _handle_after_agent_callback( + self, invocation_context: InvocationContext + ) -> Optional[Event]: + """Runs the after_agent_callback if it exists. + + Args: + invocation_context: InvocationContext, the invocation context for this + agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + + callback_context = CallbackContext(invocation_context) + + # Run callbacks from the plugins. + after_agent_callback_content = ( + await invocation_context.plugin_manager.run_after_agent_callback( + agent=self, callback_context=callback_context + ) ) - ) - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if ( - not after_agent_callback_content - and self.canonical_after_agent_callbacks - ): - for callback in self.canonical_after_agent_callbacks: - after_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(after_agent_callback_content): - after_agent_callback_content = await after_agent_callback_content + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not after_agent_callback_content and self.canonical_after_agent_callbacks: + for callback in self.canonical_after_agent_callbacks: + after_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(after_agent_callback_content): + after_agent_callback_content = await after_agent_callback_content + if after_agent_callback_content: + break + + # Process the override content if exists, and further process the state + # change if exists. if after_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. - if after_agent_callback_content: - ret_event = Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return None - - @override - def model_post_init(self, __context: Any) -> None: - self.__set_parent_agent_for_sub_agents() - - @field_validator("name", mode="after") - @classmethod - def validate_name(cls, value: str): - if not value.isidentifier(): - raise ValueError( - f"Found invalid agent name: `{value}`." - " Agent name must be a valid identifier. It should start with a" - " letter (a-z, A-Z) or an underscore (_), and can only contain" - " letters, digits (0-9), and underscores." - ) - if value == "user": - raise ValueError( - "Agent name cannot be `user`. `user` is reserved for end-user's" - " input." - ) - return value - - @field_validator("sub_agents", mode="after") - @classmethod - def validate_sub_agents_unique_names( - cls, value: list[BaseAgent] - ) -> list[BaseAgent]: - """Validates that all sub-agents have unique names. - - Args: - value: The list of sub-agents to validate. - - Returns: - The validated list of sub-agents. - - """ - if not value: - return value - - seen_names: set[str] = set() - duplicates: set[str] = set() - - for sub_agent in value: - name = sub_agent.name - if name in seen_names: - duplicates.add(name) - else: - seen_names.add(name) - - if duplicates: - duplicate_names_str = ", ".join( - f"`{name}`" for name in sorted(duplicates) - ) - logger.warning( - "Found duplicate sub-agent names: %s. " - "All sub-agents must have unique names.", - duplicate_names_str, - ) - - return value - - def __set_parent_agent_for_sub_agents(self) -> BaseAgent: - for sub_agent in self.sub_agents: - if sub_agent.parent_agent is not None: - raise ValueError( - f"Agent `{sub_agent.name}` already has a parent agent, current" - f" parent: `{sub_agent.parent_agent.name}`, trying to add:" - f" `{self.name}`" - ) - sub_agent.parent_agent = self - return self - - @final - @classmethod - @experimental - def from_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - ) -> SelfAgent: - """Creates an agent from a config. - - If sub-classes uses a custom agent config, override `_from_config_kwargs` - method to return an updated kwargs for agent constructor. - - Args: - config: The config to create the agent from. - config_abs_path: The absolute path to the config file that contains the - agent config. - - Returns: - The created agent. - """ - kwargs = cls.__create_kwargs(config, config_abs_path) - kwargs = cls._parse_config(config, config_abs_path, kwargs) - return cls(**kwargs) - - @classmethod - @experimental - def _parse_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - kwargs: Dict[str, Any], - ) -> Dict[str, Any]: - """Parses the config and returns updated kwargs to construct the agent. - - Sub-classes should override this method to use a custom agent config class. - - Args: - config: The config to parse. - config_abs_path: The absolute path to the config file that contains the - agent config. - kwargs: The keyword arguments used for agent constructor. - - Returns: - The updated keyword arguments used for agent constructor. - """ - return kwargs - - @classmethod - def __create_kwargs( - cls, - config: BaseAgentConfig, - config_abs_path: str, - ) -> Dict[str, Any]: - """Creates kwargs for the fields of BaseAgent.""" - - from .config_agent_utils import resolve_agent_reference - from .config_agent_utils import resolve_callbacks - - kwargs: Dict[str, Any] = { - "name": config.name, - "description": config.description, - } - if config.sub_agents: - sub_agents = [] - for sub_agent_config in config.sub_agents: - sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) - sub_agents.append(sub_agent) - kwargs["sub_agents"] = sub_agents - - if config.before_agent_callbacks: - kwargs["before_agent_callback"] = resolve_callbacks( - config.before_agent_callbacks - ) - if config.after_agent_callbacks: - kwargs["after_agent_callback"] = resolve_callbacks( - config.after_agent_callbacks - ) - return kwargs + ret_event = Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return None + + @override + def model_post_init(self, __context: Any) -> None: + self.__set_parent_agent_for_sub_agents() + + @field_validator("name", mode="after") + @classmethod + def validate_name(cls, value: str): + if not value.isidentifier(): + raise ValueError( + f"Found invalid agent name: `{value}`." + " Agent name must be a valid identifier. It should start with a" + " letter (a-z, A-Z) or an underscore (_), and can only contain" + " letters, digits (0-9), and underscores." + ) + if value == "user": + raise ValueError( + "Agent name cannot be `user`. `user` is reserved for end-user's" + " input." + ) + return value + + @field_validator("sub_agents", mode="after") + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ", ".join(f"`{name}`" for name in sorted(duplicates)) + logger.warning( + "Found duplicate sub-agent names: %s. " + "All sub-agents must have unique names.", + duplicate_names_str, + ) + + return value + + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: + for sub_agent in self.sub_agents: + if sub_agent.parent_agent is not None: + raise ValueError( + f"Agent `{sub_agent.name}` already has a parent agent, current" + f" parent: `{sub_agent.parent_agent.name}`, trying to add:" + f" `{self.name}`" + ) + sub_agent.parent_agent = self + return self + + @final + @classmethod + @experimental + def from_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + ) -> SelfAgent: + """Creates an agent from a config. + + If sub-classes uses a custom agent config, override `_from_config_kwargs` + method to return an updated kwargs for agent constructor. + + Args: + config: The config to create the agent from. + config_abs_path: The absolute path to the config file that contains the + agent config. + + Returns: + The created agent. + """ + kwargs = cls.__create_kwargs(config, config_abs_path) + kwargs = cls._parse_config(config, config_abs_path, kwargs) + return cls(**kwargs) + + @classmethod + @experimental + def _parse_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Parses the config and returns updated kwargs to construct the agent. + + Sub-classes should override this method to use a custom agent config class. + + Args: + config: The config to parse. + config_abs_path: The absolute path to the config file that contains the + agent config. + kwargs: The keyword arguments used for agent constructor. + + Returns: + The updated keyword arguments used for agent constructor. + """ + return kwargs + + @classmethod + def __create_kwargs( + cls, + config: BaseAgentConfig, + config_abs_path: str, + ) -> Dict[str, Any]: + """Creates kwargs for the fields of BaseAgent.""" + + from .config_agent_utils import resolve_agent_reference + from .config_agent_utils import resolve_callbacks + + kwargs: Dict[str, Any] = { + "name": config.name, + "description": config.description, + } + if config.sub_agents: + sub_agents = [] + for sub_agent_config in config.sub_agents: + sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) + sub_agents.append(sub_agent) + kwargs["sub_agents"] = sub_agents + + if config.before_agent_callbacks: + kwargs["before_agent_callback"] = resolve_callbacks( + config.before_agent_callbacks + ) + if config.after_agent_callbacks: + kwargs["after_agent_callback"] = resolve_callbacks( + config.after_agent_callbacks + ) + return kwargs From a6a70be02f6235be7928e13303af6c11c85bbfdc Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 09:43:00 +0000 Subject: [PATCH 07/24] refactor(agents): simplify span handling Use a nullcontext fallback to avoid duplicating the async run loop when telemetry is disabled. --- src/google/adk/agents/base_agent.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a0a0560249..3f88d7336f 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import contextlib import inspect import logging from typing import Any @@ -283,15 +284,13 @@ async def run_async( """ ctx = self._create_invocation_context(parent_context) + span_context = contextlib.nullcontext() if is_telemetry_enabled(self): - with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") + + with span_context as span: + if span: tracing.trace_agent_invocation(span, self, ctx) - async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") - ) as agen: - async for event in agen: - yield event - else: async with Aclosing( self._run_callbacks_and_impl(ctx, mode="async") ) as agen: From 8435ad5352235f1d56bd3d0c8fdc5e50459c0a41 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 09:53:18 +0000 Subject: [PATCH 08/24] refactor(agents): align run_live span handling Use a nullcontext fallback to keep run_live telemetry handling consistent with run_async. --- src/google/adk/agents/base_agent.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 3f88d7336f..a16d89770a 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -313,12 +313,13 @@ async def run_live( """ ctx = self._create_invocation_context(parent_context) + span_context = contextlib.nullcontext() if is_telemetry_enabled(self): - with tracer.start_as_current_span(f"invoke_agent {self.name}") as span: + span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") + + with span_context as span: + if span: tracing.trace_agent_invocation(span, self, ctx) - async for event in self._run_callbacks_and_impl(ctx, mode="live"): - yield event - else: async for event in self._run_callbacks_and_impl(ctx, mode="live"): yield event From 231c2f26351a87e43632c09b0465c4f00df1de86 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 10:06:01 +0000 Subject: [PATCH 09/24] refactor(telemetry): align span contexts Use nullcontext fallbacks for tracing in base_agent and base_llm_flow. Run autoformat on telemetry utilities. --- src/google/adk/agents/base_agent.py | 1136 +++++++++-------- .../adk/flows/llm_flows/base_llm_flow.py | 16 +- src/google/adk/utils/telemetry_utils.py | 58 +- 3 files changed, 610 insertions(+), 600 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a16d89770a..cd591e4a1a 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -50,7 +50,7 @@ from .callback_context import CallbackContext if TYPE_CHECKING: - from .invocation_context import InvocationContext + from .invocation_context import InvocationContext logger = logging.getLogger("google_adk." + __name__) @@ -74,27 +74,27 @@ @experimental class BaseAgentState(BaseModel): - """Base class for all agent states.""" + """Base class for all agent states.""" - model_config = ConfigDict( - extra="forbid", - ) + model_config = ConfigDict( + extra="forbid", + ) AgentState = TypeVar("AgentState", bound=BaseAgentState) class BaseAgent(BaseModel): - """Base class for all agents in Agent Development Kit.""" + """Base class for all agents in Agent Development Kit.""" - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - """The pydantic model config.""" + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + """The pydantic model config.""" - config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig - """The config type for this agent. + config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig + """The config type for this agent. Sub-classes should override this to specify their own config type. @@ -109,22 +109,22 @@ class MyAgent(BaseAgent): ``` """ - name: str - """The agent's name. + name: str + """The agent's name. Agent name must be a Python identifier and unique within the agent tree. Agent name cannot be "user", since it's reserved for end-user's input. """ - description: str = "" - """Description about the agent's capability. + description: str = "" + """Description about the agent's capability. The model uses this to determine whether to delegate control to the agent. One-line description is enough and preferred. """ - parent_agent: Optional[BaseAgent] = Field(default=None, init=False) - """The parent agent of this agent. + parent_agent: Optional[BaseAgent] = Field(default=None, init=False) + """The parent agent of this agent. Note that an agent can ONLY be added as sub-agent once. @@ -132,12 +132,12 @@ class MyAgent(BaseAgent): instances with identical config, but with different name and add them to the agent tree. """ - sub_agents: list[BaseAgent] = Field(default_factory=list) - """The sub-agents of this agent.""" - disable_telemetry: bool = False - """Whether to disable telemetry for this agent.""" - before_agent_callback: Optional[BeforeAgentCallback] = None - """Callback or list of callbacks to be invoked before the agent run. + sub_agents: list[BaseAgent] = Field(default_factory=list) + """The sub-agents of this agent.""" + disable_telemetry: bool = False + """Whether to disable telemetry for this agent.""" + before_agent_callback: Optional[BeforeAgentCallback] = None + """Callback or list of callbacks to be invoked before the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -150,8 +150,8 @@ class MyAgent(BaseAgent): When the content is present, the agent run will be skipped and the provided content will be returned to user. """ - after_agent_callback: Optional[AfterAgentCallback] = None - """Callback or list of callbacks to be invoked after the agent run. + after_agent_callback: Optional[AfterAgentCallback] = None + """Callback or list of callbacks to be invoked after the agent run. When a list of callbacks is provided, the callbacks will be called in the order they are listed until a callback does not return None. @@ -165,549 +165,561 @@ class MyAgent(BaseAgent): will be appended to event history as an additional agent response. """ - def _load_agent_state( - self, - ctx: InvocationContext, - state_type: Type[AgentState], - ) -> Optional[AgentState]: - """Loads the agent state from the invocation context. - - Args: - ctx: The invocation context. - state_type: The type of the agent state. - - Returns: - The current state if exists; otherwise, None. - """ - if ctx.agent_states is None or self.name not in ctx.agent_states: - return None - else: - return state_type.model_validate(ctx.agent_states.get(self.name)) - - def _create_agent_state_event( - self, - ctx: InvocationContext, - ) -> Event: - """Returns an event with current agent state set in the invocation context. - - Args: - ctx: The invocation context. - - Returns: - An event with the current agent state set in the invocation context. - """ - event_actions = EventActions() - if (agent_state := ctx.agent_states.get(self.name)) is not None: - event_actions.agent_state = agent_state - if ctx.end_of_agents.get(self.name): - event_actions.end_of_agent = True - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=event_actions, - ) + def _load_agent_state( + self, + ctx: InvocationContext, + state_type: Type[AgentState], + ) -> Optional[AgentState]: + """Loads the agent state from the invocation context. + + Args: + ctx: The invocation context. + state_type: The type of the agent state. + + Returns: + The current state if exists; otherwise, None. + """ + if ctx.agent_states is None or self.name not in ctx.agent_states: + return None + else: + return state_type.model_validate(ctx.agent_states.get(self.name)) + + def _create_agent_state_event( + self, + ctx: InvocationContext, + ) -> Event: + """Returns an event with current agent state set in the invocation context. + + Args: + ctx: The invocation context. + + Returns: + An event with the current agent state set in the invocation context. + """ + event_actions = EventActions() + if (agent_state := ctx.agent_states.get(self.name)) is not None: + event_actions.agent_state = agent_state + if ctx.end_of_agents.get(self.name): + event_actions.end_of_agent = True + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=event_actions, + ) - def clone(self: SelfAgent, update: Mapping[str, Any] | None = None) -> SelfAgent: - """Creates a copy of this agent instance. - - Args: - update: Optional mapping of new values for the fields of the cloned agent. - The keys of the mapping are the names of the fields to be updated, and - the values are the new values for those fields. - For example: {"name": "cloned_agent"} - - Returns: - A new agent instance with identical configuration as the original - agent except for the fields specified in the update. - """ - if update is not None and "parent_agent" in update: - raise ValueError( - "Cannot update `parent_agent` field in clone. Parent agent is set" - " only when the parent agent is instantiated with the sub-agents." - ) - - # Only allow updating fields that are defined in the agent class. - allowed_fields = set(self.__class__.model_fields) - if update is not None: - invalid_fields = set(update) - allowed_fields - if invalid_fields: - raise ValueError( - f"Cannot update nonexistent fields in {self.__class__.__name__}:" - f" {invalid_fields}" - ) - - cloned_agent = self.model_copy(update=update) - - # If any field is stored as list and not provided in the update, need to - # shallow copy it for the cloned agent to avoid sharing the same list object - # with the original agent. - for field_name in cloned_agent.__class__.model_fields: - if field_name == "sub_agents": - continue - if update is not None and field_name in update: - continue - field = getattr(cloned_agent, field_name) - if isinstance(field, list): - setattr(cloned_agent, field_name, field.copy()) - - if update is None or "sub_agents" not in update: - # If `sub_agents` is not provided in the update, need to recursively clone - # the sub-agents to avoid sharing the sub-agents with the original agent. - cloned_agent.sub_agents = [] - for sub_agent in self.sub_agents: - cloned_sub_agent = sub_agent.clone() - cloned_sub_agent.parent_agent = cloned_agent - cloned_agent.sub_agents.append(cloned_sub_agent) - else: - for sub_agent in cloned_agent.sub_agents: - sub_agent.parent_agent = cloned_agent - - # Remove the parent agent from the cloned agent to avoid sharing the parent - # agent with the cloned agent. - cloned_agent.parent_agent = None - return cloned_agent - - @final - async def run_async( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via text-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - ctx = self._create_invocation_context(parent_context) - span_context = contextlib.nullcontext() - if is_telemetry_enabled(self): - span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") - - with span_context as span: - if span: - tracing.trace_agent_invocation(span, self, ctx) - async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") - ) as agen: - async for event in agen: - yield event - - @final - async def run_live( - self, - parent_context: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Entry method to run an agent via video/audio-based conversation. - - Args: - parent_context: InvocationContext, the invocation context of the parent - agent. - - Yields: - Event: the events generated by the agent. - """ - - ctx = self._create_invocation_context(parent_context) - span_context = contextlib.nullcontext() - if is_telemetry_enabled(self): - span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") - - with span_context as span: - if span: - tracing.trace_agent_invocation(span, self, ctx) - async for event in self._run_callbacks_and_impl(ctx, mode="live"): - yield event - - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via text-based conversation. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f"_run_async_impl for {type(self)} is not implemented." + def clone( + self: SelfAgent, update: Mapping[str, Any] | None = None + ) -> SelfAgent: + """Creates a copy of this agent instance. + + Args: + update: Optional mapping of new values for the fields of the cloned agent. + The keys of the mapping are the names of the fields to be updated, and + the values are the new values for those fields. + For example: {"name": "cloned_agent"} + + Returns: + A new agent instance with identical configuration as the original + agent except for the fields specified in the update. + """ + if update is not None and "parent_agent" in update: + raise ValueError( + "Cannot update `parent_agent` field in clone. Parent agent is set" + " only when the parent agent is instantiated with the sub-agents." + ) + + # Only allow updating fields that are defined in the agent class. + allowed_fields = set(self.__class__.model_fields) + if update is not None: + invalid_fields = set(update) - allowed_fields + if invalid_fields: + raise ValueError( + f"Cannot update nonexistent fields in {self.__class__.__name__}:" + f" {invalid_fields}" ) - yield # AsyncGenerator requires having at least one yield statement - async def _run_live_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - """Core logic to run this agent via video/audio-based conversation. + cloned_agent = self.model_copy(update=update) + + # If any field is stored as list and not provided in the update, need to + # shallow copy it for the cloned agent to avoid sharing the same list object + # with the original agent. + for field_name in cloned_agent.__class__.model_fields: + if field_name == "sub_agents": + continue + if update is not None and field_name in update: + continue + field = getattr(cloned_agent, field_name) + if isinstance(field, list): + setattr(cloned_agent, field_name, field.copy()) + + if update is None or "sub_agents" not in update: + # If `sub_agents` is not provided in the update, need to recursively clone + # the sub-agents to avoid sharing the sub-agents with the original agent. + cloned_agent.sub_agents = [] + for sub_agent in self.sub_agents: + cloned_sub_agent = sub_agent.clone() + cloned_sub_agent.parent_agent = cloned_agent + cloned_agent.sub_agents.append(cloned_sub_agent) + else: + for sub_agent in cloned_agent.sub_agents: + sub_agent.parent_agent = cloned_agent + + # Remove the parent agent from the cloned agent to avoid sharing the parent + # agent with the cloned agent. + cloned_agent.parent_agent = None + return cloned_agent + + @final + async def run_async( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via text-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + span_context = contextlib.nullcontext() + if is_telemetry_enabled(self): + span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") + + with span_context as span: + if span: + tracing.trace_agent_invocation(span, self, ctx) + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode="async") + ) as agen: + async for event in agen: + yield event + + @final + async def run_live( + self, + parent_context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Entry method to run an agent via video/audio-based conversation. + + Args: + parent_context: InvocationContext, the invocation context of the parent + agent. + + Yields: + Event: the events generated by the agent. + """ + + ctx = self._create_invocation_context(parent_context) + span_context = contextlib.nullcontext() + if is_telemetry_enabled(self): + span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") + + with span_context as span: + if span: + tracing.trace_agent_invocation(span, self, ctx) + async for event in self._run_callbacks_and_impl(ctx, mode="live"): + yield event + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via text-based conversation. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_async_impl for {type(self)} is not implemented." + ) + yield # AsyncGenerator requires having at least one yield statement - Args: - ctx: InvocationContext, the invocation context for this agent. + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic to run this agent via video/audio-based conversation. - Yields: - Event: the events generated by the agent. - """ - raise NotImplementedError( - f"_run_live_impl for {type(self)} is not implemented." - ) - yield # AsyncGenerator requires having at least one yield statement - - async def _run_callbacks_and_impl( - self, ctx: InvocationContext, mode: str = "async" - ) -> AsyncGenerator[Event, None]: - """Runs the before and after agent callbacks around the core agent logic. - Args: - ctx: InvocationContext, the invocation context for this agent. - mode: str, either 'async' or 'live', indicating which core agent logic to run. - Yields: - Event: the events generated by the agent. - """ - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - if mode.lower() == "async": - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - yield event - elif mode.lower() == "live": - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event - else: - raise ValueError(f"Invalid mode: {mode}. Must be either 'async' or 'live'.") - - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event - - @property - def root_agent(self) -> BaseAgent: - """Gets the root agent of this agent.""" - root_agent = self - while root_agent.parent_agent is not None: - root_agent = root_agent.parent_agent - return root_agent - - def find_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent and its descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - if self.name == name: - return self - return self.find_sub_agent(name) - - def find_sub_agent(self, name: str) -> Optional[BaseAgent]: - """Finds the agent with the given name in this agent's descendants. - - Args: - name: The name of the agent to find. - - Returns: - The agent with the matching name, or None if no such agent is found. - """ - for sub_agent in self.sub_agents: - if result := sub_agent.find_agent(name): - return result - return None - - def _create_invocation_context( - self, parent_context: InvocationContext - ) -> InvocationContext: - """Creates a new invocation context for this agent.""" - invocation_context = parent_context.model_copy(update={"agent": self}) - return invocation_context - - @property - def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_agent_callback: - return [] - if isinstance(self.before_agent_callback, list): - return self.before_agent_callback - return [self.before_agent_callback] - - @property - def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_agent_callback: - return [] - if isinstance(self.after_agent_callback, list): - return self.after_agent_callback - return [self.after_agent_callback] - - async def _handle_before_agent_callback( - self, ctx: InvocationContext - ) -> Optional[Event]: - """Runs the before_agent_callback if it exists. - - Args: - ctx: InvocationContext, the invocation context for this agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - callback_context = CallbackContext(ctx) - - # Run callbacks from the plugins. - before_agent_callback_content = ( - await ctx.plugin_manager.run_before_agent_callback( - agent=self, callback_context=callback_context - ) + Args: + ctx: InvocationContext, the invocation context for this agent. + + Yields: + Event: the events generated by the agent. + """ + raise NotImplementedError( + f"_run_live_impl for {type(self)} is not implemented." + ) + yield # AsyncGenerator requires having at least one yield statement + + async def _run_callbacks_and_impl( + self, ctx: InvocationContext, mode: str = "async" + ) -> AsyncGenerator[Event, None]: + """Runs the before and after agent callbacks around the core agent logic. + Args: + ctx: InvocationContext, the invocation context for this agent. + mode: str, either 'async' or 'live', indicating which core agent logic to run. + Yields: + Event: the events generated by the agent. + """ + if event := await self._handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return + if mode.lower() == "async": + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + elif mode.lower() == "live": + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + else: + raise ValueError( + f"Invalid mode: {mode}. Must be either 'async' or 'live'." + ) + + if ctx.end_invocation: + return + + if event := await self._handle_after_agent_callback(ctx): + yield event + + @property + def root_agent(self) -> BaseAgent: + """Gets the root agent of this agent.""" + root_agent = self + while root_agent.parent_agent is not None: + root_agent = root_agent.parent_agent + return root_agent + + def find_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent and its descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + if self.name == name: + return self + return self.find_sub_agent(name) + + def find_sub_agent(self, name: str) -> Optional[BaseAgent]: + """Finds the agent with the given name in this agent's descendants. + + Args: + name: The name of the agent to find. + + Returns: + The agent with the matching name, or None if no such agent is found. + """ + for sub_agent in self.sub_agents: + if result := sub_agent.find_agent(name): + return result + return None + + def _create_invocation_context( + self, parent_context: InvocationContext + ) -> InvocationContext: + """Creates a new invocation context for this agent.""" + invocation_context = parent_context.model_copy(update={"agent": self}) + return invocation_context + + @property + def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_agent_callback: + return [] + if isinstance(self.before_agent_callback, list): + return self.before_agent_callback + return [self.before_agent_callback] + + @property + def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_agent_callback: + return [] + if isinstance(self.after_agent_callback, list): + return self.after_agent_callback + return [self.after_agent_callback] + + async def _handle_before_agent_callback( + self, ctx: InvocationContext + ) -> Optional[Event]: + """Runs the before_agent_callback if it exists. + + Args: + ctx: InvocationContext, the invocation context for this agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + callback_context = CallbackContext(ctx) + + # Run callbacks from the plugins. + before_agent_callback_content = ( + await ctx.plugin_manager.run_before_agent_callback( + agent=self, callback_context=callback_context ) + ) - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not before_agent_callback_content and self.canonical_before_agent_callbacks: - for callback in self.canonical_before_agent_callbacks: - before_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(before_agent_callback_content): - before_agent_callback_content = await before_agent_callback_content - if before_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if ( + not before_agent_callback_content + and self.canonical_before_agent_callbacks + ): + for callback in self.canonical_before_agent_callbacks: + before_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(before_agent_callback_content): + before_agent_callback_content = await before_agent_callback_content if before_agent_callback_content: - ret_event = Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=before_agent_callback_content, - actions=callback_context._event_actions, - ) - ctx.end_invocation = True - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - actions=callback_context._event_actions, - ) - - return None - - async def _handle_after_agent_callback( - self, invocation_context: InvocationContext - ) -> Optional[Event]: - """Runs the after_agent_callback if it exists. - - Args: - invocation_context: InvocationContext, the invocation context for this - agent. - - Returns: - Optional[Event]: an event if callback provides content or changed state. - """ - - callback_context = CallbackContext(invocation_context) - - # Run callbacks from the plugins. - after_agent_callback_content = ( - await invocation_context.plugin_manager.run_after_agent_callback( - agent=self, callback_context=callback_context - ) + break + + # Process the override content if exists, and further process the state + # change if exists. + if before_agent_callback_content: + ret_event = Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=before_agent_callback_content, + actions=callback_context._event_actions, + ) + ctx.end_invocation = True + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + actions=callback_context._event_actions, + ) + + return None + + async def _handle_after_agent_callback( + self, invocation_context: InvocationContext + ) -> Optional[Event]: + """Runs the after_agent_callback if it exists. + + Args: + invocation_context: InvocationContext, the invocation context for this + agent. + + Returns: + Optional[Event]: an event if callback provides content or changed state. + """ + + callback_context = CallbackContext(invocation_context) + + # Run callbacks from the plugins. + after_agent_callback_content = ( + await invocation_context.plugin_manager.run_after_agent_callback( + agent=self, callback_context=callback_context ) + ) - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not after_agent_callback_content and self.canonical_after_agent_callbacks: - for callback in self.canonical_after_agent_callbacks: - after_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(after_agent_callback_content): - after_agent_callback_content = await after_agent_callback_content - if after_agent_callback_content: - break - - # Process the override content if exists, and further process the state - # change if exists. + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if ( + not after_agent_callback_content + and self.canonical_after_agent_callbacks + ): + for callback in self.canonical_after_agent_callbacks: + after_agent_callback_content = callback( + callback_context=callback_context + ) + if inspect.isawaitable(after_agent_callback_content): + after_agent_callback_content = await after_agent_callback_content if after_agent_callback_content: - ret_event = Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return ret_event - - if callback_context.state.has_delta(): - return Event( - invocation_id=invocation_context.invocation_id, - author=self.name, - branch=invocation_context.branch, - content=after_agent_callback_content, - actions=callback_context._event_actions, - ) - return None - - @override - def model_post_init(self, __context: Any) -> None: - self.__set_parent_agent_for_sub_agents() - - @field_validator("name", mode="after") - @classmethod - def validate_name(cls, value: str): - if not value.isidentifier(): - raise ValueError( - f"Found invalid agent name: `{value}`." - " Agent name must be a valid identifier. It should start with a" - " letter (a-z, A-Z) or an underscore (_), and can only contain" - " letters, digits (0-9), and underscores." - ) - if value == "user": - raise ValueError( - "Agent name cannot be `user`. `user` is reserved for end-user's" - " input." - ) - return value - - @field_validator("sub_agents", mode="after") - @classmethod - def validate_sub_agents_unique_names( - cls, value: list[BaseAgent] - ) -> list[BaseAgent]: - """Validates that all sub-agents have unique names. - - Args: - value: The list of sub-agents to validate. - - Returns: - The validated list of sub-agents. - - """ - if not value: - return value - - seen_names: set[str] = set() - duplicates: set[str] = set() - - for sub_agent in value: - name = sub_agent.name - if name in seen_names: - duplicates.add(name) - else: - seen_names.add(name) - - if duplicates: - duplicate_names_str = ", ".join(f"`{name}`" for name in sorted(duplicates)) - logger.warning( - "Found duplicate sub-agent names: %s. " - "All sub-agents must have unique names.", - duplicate_names_str, - ) - - return value - - def __set_parent_agent_for_sub_agents(self) -> BaseAgent: - for sub_agent in self.sub_agents: - if sub_agent.parent_agent is not None: - raise ValueError( - f"Agent `{sub_agent.name}` already has a parent agent, current" - f" parent: `{sub_agent.parent_agent.name}`, trying to add:" - f" `{self.name}`" - ) - sub_agent.parent_agent = self - return self - - @final - @classmethod - @experimental - def from_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - ) -> SelfAgent: - """Creates an agent from a config. - - If sub-classes uses a custom agent config, override `_from_config_kwargs` - method to return an updated kwargs for agent constructor. - - Args: - config: The config to create the agent from. - config_abs_path: The absolute path to the config file that contains the - agent config. - - Returns: - The created agent. - """ - kwargs = cls.__create_kwargs(config, config_abs_path) - kwargs = cls._parse_config(config, config_abs_path, kwargs) - return cls(**kwargs) - - @classmethod - @experimental - def _parse_config( - cls: Type[SelfAgent], - config: BaseAgentConfig, - config_abs_path: str, - kwargs: Dict[str, Any], - ) -> Dict[str, Any]: - """Parses the config and returns updated kwargs to construct the agent. - - Sub-classes should override this method to use a custom agent config class. - - Args: - config: The config to parse. - config_abs_path: The absolute path to the config file that contains the - agent config. - kwargs: The keyword arguments used for agent constructor. - - Returns: - The updated keyword arguments used for agent constructor. - """ - return kwargs - - @classmethod - def __create_kwargs( - cls, - config: BaseAgentConfig, - config_abs_path: str, - ) -> Dict[str, Any]: - """Creates kwargs for the fields of BaseAgent.""" - - from .config_agent_utils import resolve_agent_reference - from .config_agent_utils import resolve_callbacks - - kwargs: Dict[str, Any] = { - "name": config.name, - "description": config.description, - } - if config.sub_agents: - sub_agents = [] - for sub_agent_config in config.sub_agents: - sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) - sub_agents.append(sub_agent) - kwargs["sub_agents"] = sub_agents - - if config.before_agent_callbacks: - kwargs["before_agent_callback"] = resolve_callbacks( - config.before_agent_callbacks - ) - if config.after_agent_callbacks: - kwargs["after_agent_callback"] = resolve_callbacks( - config.after_agent_callbacks - ) - return kwargs + break + + # Process the override content if exists, and further process the state + # change if exists. + if after_agent_callback_content: + ret_event = Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return ret_event + + if callback_context.state.has_delta(): + return Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return None + + @override + def model_post_init(self, __context: Any) -> None: + self.__set_parent_agent_for_sub_agents() + + @field_validator("name", mode="after") + @classmethod + def validate_name(cls, value: str): + if not value.isidentifier(): + raise ValueError( + f"Found invalid agent name: `{value}`." + " Agent name must be a valid identifier. It should start with a" + " letter (a-z, A-Z) or an underscore (_), and can only contain" + " letters, digits (0-9), and underscores." + ) + if value == "user": + raise ValueError( + "Agent name cannot be `user`. `user` is reserved for end-user's" + " input." + ) + return value + + @field_validator("sub_agents", mode="after") + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ", ".join( + f"`{name}`" for name in sorted(duplicates) + ) + logger.warning( + "Found duplicate sub-agent names: %s. " + "All sub-agents must have unique names.", + duplicate_names_str, + ) + + return value + + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: + for sub_agent in self.sub_agents: + if sub_agent.parent_agent is not None: + raise ValueError( + f"Agent `{sub_agent.name}` already has a parent agent, current" + f" parent: `{sub_agent.parent_agent.name}`, trying to add:" + f" `{self.name}`" + ) + sub_agent.parent_agent = self + return self + + @final + @classmethod + @experimental + def from_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + ) -> SelfAgent: + """Creates an agent from a config. + + If sub-classes uses a custom agent config, override `_from_config_kwargs` + method to return an updated kwargs for agent constructor. + + Args: + config: The config to create the agent from. + config_abs_path: The absolute path to the config file that contains the + agent config. + + Returns: + The created agent. + """ + kwargs = cls.__create_kwargs(config, config_abs_path) + kwargs = cls._parse_config(config, config_abs_path, kwargs) + return cls(**kwargs) + + @classmethod + @experimental + def _parse_config( + cls: Type[SelfAgent], + config: BaseAgentConfig, + config_abs_path: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Parses the config and returns updated kwargs to construct the agent. + + Sub-classes should override this method to use a custom agent config class. + + Args: + config: The config to parse. + config_abs_path: The absolute path to the config file that contains the + agent config. + kwargs: The keyword arguments used for agent constructor. + + Returns: + The updated keyword arguments used for agent constructor. + """ + return kwargs + + @classmethod + def __create_kwargs( + cls, + config: BaseAgentConfig, + config_abs_path: str, + ) -> Dict[str, Any]: + """Creates kwargs for the fields of BaseAgent.""" + + from .config_agent_utils import resolve_agent_reference + from .config_agent_utils import resolve_callbacks + + kwargs: Dict[str, Any] = { + "name": config.name, + "description": config.description, + } + if config.sub_agents: + sub_agents = [] + for sub_agent_config in config.sub_agents: + sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) + sub_agents.append(sub_agent) + kwargs["sub_agents"] = sub_agents + + if config.before_agent_callbacks: + kwargs["before_agent_callback"] = resolve_callbacks( + config.before_agent_callbacks + ) + if config.after_agent_callbacks: + kwargs["after_agent_callback"] = resolve_callbacks( + config.after_agent_callbacks + ) + return kwargs diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 4d37e66db5..89411bf1b4 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -16,6 +16,7 @@ from abc import ABC import asyncio +import contextlib import datetime import inspect import logging @@ -130,19 +131,16 @@ async def run_live( async with llm.connect(llm_request) as llm_connection: if llm_request.contents: # Sends the conversation history to the model. + logger.debug("Sending history to model: %s", llm_request.contents) + span_context = contextlib.nullcontext() if is_telemetry_enabled(invocation_context.agent): - with tracer.start_as_current_span("send_data"): - # Combine regular contents with audio/transcription from session - logger.debug( - "Sending history to model: %s", llm_request.contents - ) - await llm_connection.send_history(llm_request.contents) + span_context = tracer.start_as_current_span("send_data") + with span_context as span: + await llm_connection.send_history(llm_request.contents) + if span: trace_send_data( invocation_context, event_id, llm_request.contents ) - else: - logger.debug("Sending history to model: %s", llm_request.contents) - await llm_connection.send_history(llm_request.contents) send_task = asyncio.create_task( self._send_to_model(llm_connection, invocation_context) diff --git a/src/google/adk/utils/telemetry_utils.py b/src/google/adk/utils/telemetry_utils.py index 99af4e0d73..fb5b3e8594 100644 --- a/src/google/adk/utils/telemetry_utils.py +++ b/src/google/adk/utils/telemetry_utils.py @@ -23,42 +23,42 @@ from .env_utils import is_env_enabled if TYPE_CHECKING: - from ..agents.base_agent import BaseAgent + from ..agents.base_agent import BaseAgent def is_telemetry_enabled(agent: "BaseAgent") -> bool: - """Check if telemetry is enabled for the given agent. + """Check if telemetry is enabled for the given agent. - By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. + By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. - Args: - agent: The agent to check if telemetry is enabled for. + Args: + agent: The agent to check if telemetry is enabled for. - Returns: - False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. + Returns: + False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. - Examples: - >>> os.environ['OTEL_SDK_DISABLED'] = 'true' - >>> is_telemetry_enabled(my_agent) - True + Examples: + >>> os.environ['OTEL_SDK_DISABLED'] = 'true' + >>> is_telemetry_enabled(my_agent) + True - >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 - >>> is_telemetry_enabled(my_agent) - True + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 + >>> is_telemetry_enabled(my_agent) + True - >>> my_agent.disable_telemetry = True - >>> is_telemetry_enabled(my_agent) - True + >>> my_agent.disable_telemetry = True + >>> is_telemetry_enabled(my_agent) + True - >>> os.environ['OTEL_SDK_DISABLED'] = 1 - >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' - >>> my_agent.disable_telemetry = False - >>> is_telemetry_enabled(my_agent) - False - """ - telemetry_disabled = ( - is_env_enabled("OTEL_SDK_DISABLED") - or is_env_enabled("ADK_TELEMETRY_DISABLED") - or getattr(agent, "disable_telemetry", False) - ) - return not telemetry_disabled + >>> os.environ['OTEL_SDK_DISABLED'] = 1 + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' + >>> my_agent.disable_telemetry = False + >>> is_telemetry_enabled(my_agent) + False + """ + telemetry_disabled = ( + is_env_enabled("OTEL_SDK_DISABLED") + or is_env_enabled("ADK_TELEMETRY_DISABLED") + or getattr(agent, "disable_telemetry", False) + ) + return not telemetry_disabled From 628658bcbdb4193f6c34b9e1518c1fb77cb29664 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 10:20:03 +0000 Subject: [PATCH 10/24] refactor(flows): simplify call_llm tracing Use a nullcontext fallback in _call_llm_with_optional_tracing to avoid branching. --- src/google/adk/flows/llm_flows/base_llm_flow.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 89411bf1b4..962a718a83 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -825,12 +825,10 @@ async def _call_llm_body() -> AsyncGenerator[LlmResponse, None]: async def _call_llm_with_optional_tracing() -> ( AsyncGenerator[LlmResponse, None] ): + span_context = contextlib.nullcontext() if is_telemetry_enabled(invocation_context.agent): - with tracer.start_as_current_span("call_llm"): - async with Aclosing(_call_llm_body()) as agen: - async for r in agen: - yield r - else: + span_context = tracer.start_as_current_span("call_llm") + with span_context: async with Aclosing(_call_llm_body()) as agen: async for r in agen: yield r From 2a2c7f259b6b618aa18c5e25f1ab2cb67ef3b78f Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 10:39:35 +0000 Subject: [PATCH 11/24] refactor(flows): simplify execute_tool tracing Use early return when telemetry is disabled to avoid branching in execute_tool spans. --- src/google/adk/flows/llm_flows/functions.py | 62 ++++++++++----------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 999e85293b..10edc3b410 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -428,23 +428,22 @@ async def _run_with_trace(): ) return function_response_event - if is_telemetry_enabled(agent): - with tracer.start_as_current_span(f"execute_tool {tool.name}"): - try: - function_response_event = await _run_with_trace() - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise - else: + if not is_telemetry_enabled(agent): return await _run_with_trace() + with tracer.start_as_current_span(f"execute_tool {tool.name}"): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise async def handle_function_calls_live( @@ -582,23 +581,22 @@ async def _run_with_trace(): ) return function_response_event - if is_telemetry_enabled(agent): - with tracer.start_as_current_span(f"execute_tool {tool.name}"): - try: - function_response_event = await _run_with_trace() - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise - else: + if not is_telemetry_enabled(agent): return await _run_with_trace() + with tracer.start_as_current_span(f"execute_tool {tool.name}"): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise async def _process_function_live_helper( From 75b375cd7037e3e7f76a01b4c89fda94fe0e6ab5 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 14:40:39 +0000 Subject: [PATCH 12/24] refactor(telemetry): unify cache span handling Use a nullcontext fallback for cache creation tracing. Update GEPA sample formatting. --- contributing/samples/gepa/experiment.py | 1051 +++++++++-------- contributing/samples/gepa/run_experiment.py | 190 ++- .../models/gemini_context_cache_manager.py | 16 +- 3 files changed, 625 insertions(+), 632 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f68b349d9c..8d05507ed9 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,6 +43,7 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib + import utils @@ -52,202 +53,204 @@ def run_tau_bench_rollouts( system_instruction: str | None = None, rater: rater_lib.Rater | None = None, ) -> list[EnvRunResult]: - """Runs a set of tau-bench tasks with a given agent configuration. - - This is a customized version of the standard tau-bench run function, adapted - for this experiment's needs. It handles environment setup, agent creation, - task execution in parallel, and result aggregation. - - Args: - config: A RunConfig object specifying the environment, models, and other - parameters for the run. - print_results: If True, prints the result of each task as it completes. - system_instruction: An optional system instruction to use for the agent, - overriding the default. - rater: An optional rater to evaluate the agent's performance. - - Returns: - A list of EnvRunResult objects, one for each completed task. - """ - if config.env not in ['retail', 'airline']: - raise ValueError('Only retail and airline envs are supported') - if config.model_provider not in provider_list: - raise ValueError('Invalid model provider') - if config.user_model_provider not in provider_list: - raise ValueError('Invalid user model provider') - if config.agent_strategy not in ['tool-calling', 'act', 'react', 'few-shot']: - raise ValueError('Invalid agent strategy') - if config.task_split not in ['train', 'test', 'dev']: - raise ValueError('Invalid task split') - if config.user_strategy not in [item.value for item in UserStrategy]: - raise ValueError('Invalid user strategy') - - random.seed(config.seed) - time_str = datetime.now().strftime('%m%d%H%M%S') - model_name = config.model.split('/')[-1] - ckpt_filename = ( - f'{config.agent_strategy}-{model_name}-{config.temperature}_range_' - f'{config.start_index}-{config.end_index}_user-{config.user_model}-' - f'{config.user_strategy}_{time_str}.json' - ) - ckpt_path = os.path.join(config.log_dir, ckpt_filename) - if not os.path.exists(config.log_dir): - os.makedirs(config.log_dir) - - print(f'Loading user with strategy: {config.user_strategy}') - env = get_env( - config.env, - user_strategy=config.user_strategy, - user_model=config.user_model, - user_provider=config.user_model_provider, - task_split=config.task_split, - ) - if system_instruction: - env.wiki = system_instruction - agent = tau_bench_agent_lib.adk_agent_factory( - tools_info=env.tools_info, - wiki=env.wiki, - config=config, - ) - if config.end_index == -1: - end_index = len(env.tasks) - else: - end_index = min(config.end_index, len(env.tasks)) - results: list[EnvRunResult] = [] - lock = multiprocessing.Lock() - if config.task_ids: - print(f'Running tasks {config.task_ids} (checkpoint path: {ckpt_path})') - else: - print( - f'Running tasks {config.start_index} to {end_index} ' - f'(checkpoint path: {ckpt_path})' + """Runs a set of tau-bench tasks with a given agent configuration. + + This is a customized version of the standard tau-bench run function, adapted + for this experiment's needs. It handles environment setup, agent creation, + task execution in parallel, and result aggregation. + + Args: + config: A RunConfig object specifying the environment, models, and other + parameters for the run. + print_results: If True, prints the result of each task as it completes. + system_instruction: An optional system instruction to use for the agent, + overriding the default. + rater: An optional rater to evaluate the agent's performance. + + Returns: + A list of EnvRunResult objects, one for each completed task. + """ + if config.env not in ["retail", "airline"]: + raise ValueError("Only retail and airline envs are supported") + if config.model_provider not in provider_list: + raise ValueError("Invalid model provider") + if config.user_model_provider not in provider_list: + raise ValueError("Invalid user model provider") + if config.agent_strategy not in ["tool-calling", "act", "react", "few-shot"]: + raise ValueError("Invalid agent strategy") + if config.task_split not in ["train", "test", "dev"]: + raise ValueError("Invalid task split") + if config.user_strategy not in [item.value for item in UserStrategy]: + raise ValueError("Invalid user strategy") + + random.seed(config.seed) + time_str = datetime.now().strftime("%m%d%H%M%S") + model_name = config.model.split("/")[-1] + ckpt_filename = ( + f"{config.agent_strategy}-{model_name}-{config.temperature}_range_" + f"{config.start_index}-{config.end_index}_user-{config.user_model}-" + f"{config.user_strategy}_{time_str}.json" + ) + ckpt_path = os.path.join(config.log_dir, ckpt_filename) + if not os.path.exists(config.log_dir): + os.makedirs(config.log_dir) + + print(f"Loading user with strategy: {config.user_strategy}") + env = get_env( + config.env, + user_strategy=config.user_strategy, + user_model=config.user_model, + user_provider=config.user_model_provider, + task_split=config.task_split, + ) + if system_instruction: + env.wiki = system_instruction + agent = tau_bench_agent_lib.adk_agent_factory( + tools_info=env.tools_info, + wiki=env.wiki, + config=config, ) - for i in range(config.num_trials): + if config.end_index == -1: + end_index = len(env.tasks) + else: + end_index = min(config.end_index, len(env.tasks)) + results: list[EnvRunResult] = [] + lock = multiprocessing.Lock() if config.task_ids: - idxs = config.task_ids + print(f"Running tasks {config.task_ids} (checkpoint path: {ckpt_path})") else: - idxs = list(range(config.start_index, end_index)) - if config.shuffle: - random.shuffle(idxs) - - @retry(tries=3, delay=10, backoff=2) - def _run_with_retry(idx: int) -> EnvRunResult: - isolated_env = get_env( - config.env, - user_strategy=config.user_strategy, - user_model=config.user_model, - task_split=config.task_split, - user_provider=config.user_model_provider, - task_index=idx, - ) - if print_results: - print(f'Running task {idx}') - res = agent.solve( - env=isolated_env, - task_index=idx, - ) - - rating = ( - rater(res.messages[1:] if len(res.messages) > 1 else res.messages) - if rater - else None - ) - info = dict(res.info) - info['metrics'] = dict(rating=rating, reward=res.reward) - - if rater: - score = rating['score'] - feedback = {k: v for k, v in rating.items() if k != 'score'} - else: - score = res.reward - feedback = ( - 'The agent successfully resolved all customer issues' - if score > 0 - else 'The agent failed to resolve all customer issues correctly' - ) - - info['feedback'] = feedback - return EnvRunResult( - task_id=idx, - reward=score, - info=info, - traj=res.messages, - trial=i, - ) - - def _run(idx: int) -> EnvRunResult: - try: - result = _run_with_retry(idx) - except Exception as e: - logging.warning('Inference error: %s', str(e)) - result = EnvRunResult( - task_id=idx, - reward=0.0, - info={ - 'error': str(e), - 'traceback': traceback.format_exc(), - 'metrics': dict(reward=0.0), - }, - traj=[], - trial=i, - ) - - if print_results: print( - '✅' if result.reward == 1 else '❌', - f'task_id={idx}', + f"Running tasks {config.start_index} to {end_index} " + f"(checkpoint path: {ckpt_path})" ) - print('-----') - with lock: - data = [] - if os.path.exists(ckpt_path): - with open(ckpt_path, 'r') as f: - data = json.load(f) - with open(ckpt_path, 'w') as f: - json.dump(data + [result.model_dump()], f, indent=2) - return result - - with ThreadPoolExecutor(max_workers=config.max_concurrency) as executor: - res = list(executor.map(_run, idxs)) - results.extend(res) - - display_metrics(results) - - if rater: - print('Environment reward:') - display_metrics([ - EnvRunResult( - task_id=r.task_id, - reward=r.info['metrics']['reward'], - info={}, - traj=[], - trial=r.trial, + for i in range(config.num_trials): + if config.task_ids: + idxs = config.task_ids + else: + idxs = list(range(config.start_index, end_index)) + if config.shuffle: + random.shuffle(idxs) + + @retry(tries=3, delay=10, backoff=2) + def _run_with_retry(idx: int) -> EnvRunResult: + isolated_env = get_env( + config.env, + user_strategy=config.user_strategy, + user_model=config.user_model, + task_split=config.task_split, + user_provider=config.user_model_provider, + task_index=idx, + ) + if print_results: + print(f"Running task {idx}") + res = agent.solve( + env=isolated_env, + task_index=idx, + ) + + rating = ( + rater(res.messages[1:] if len(res.messages) > 1 else res.messages) + if rater + else None + ) + info = dict(res.info) + info["metrics"] = dict(rating=rating, reward=res.reward) + + if rater: + score = rating["score"] + feedback = {k: v for k, v in rating.items() if k != "score"} + else: + score = res.reward + feedback = ( + "The agent successfully resolved all customer issues" + if score > 0 + else "The agent failed to resolve all customer issues correctly" + ) + + info["feedback"] = feedback + return EnvRunResult( + task_id=idx, + reward=score, + info=info, + traj=res.messages, + trial=i, + ) + + def _run(idx: int) -> EnvRunResult: + try: + result = _run_with_retry(idx) + except Exception as e: + logging.warning("Inference error: %s", str(e)) + result = EnvRunResult( + task_id=idx, + reward=0.0, + info={ + "error": str(e), + "traceback": traceback.format_exc(), + "metrics": dict(reward=0.0), + }, + traj=[], + trial=i, + ) + + if print_results: + print( + "✅" if result.reward == 1 else "❌", + f"task_id={idx}", + ) + print("-----") + with lock: + data = [] + if os.path.exists(ckpt_path): + with open(ckpt_path, "r") as f: + data = json.load(f) + with open(ckpt_path, "w") as f: + json.dump(data + [result.model_dump()], f, indent=2) + return result + + with ThreadPoolExecutor(max_workers=config.max_concurrency) as executor: + res = list(executor.map(_run, idxs)) + results.extend(res) + + display_metrics(results) + + if rater: + print("Environment reward:") + display_metrics( + [ + EnvRunResult( + task_id=r.task_id, + reward=r.info["metrics"]["reward"], + info={}, + traj=[], + trial=r.trial, + ) + for r in results + ] ) - for r in results - ]) - with open(ckpt_path, 'w') as f: - json.dump([result.model_dump() for result in results], f, indent=2) - print(f'\n📄 Results saved to {ckpt_path}\n') - return results + with open(ckpt_path, "w") as f: + json.dump([result.model_dump() for result in results], f, indent=2) + print(f"\n📄 Results saved to {ckpt_path}\n") + return results class TauBenchDataInst(TypedDict): - env: str - task_id: int - task_split: str + env: str + task_id: int + task_split: str class TauBenchTrajectory(TypedDict): - result_traj: list[dict[str, Any]] + result_traj: list[dict[str, Any]] class TauBenchRolloutOutput(TypedDict): - env: str - task_id: int - reward: float - task_info: dict[str, Any] + env: str + task_id: int + reward: float + task_info: dict[str, Any] class TauBenchAdapter( @@ -257,383 +260,381 @@ class TauBenchAdapter( TauBenchRolloutOutput, ] ): - """A GEPA adapter for evaluating agent performance on tau-bench benchmark.""" - - def __init__( - self, - env_name: str, - agent_model: str = 'gemini-2.5-flash', - agent_model_provider: str = 'vertex_ai', - user_model: str = 'gemini-2.5-pro', - user_model_provider: str = 'vertex_ai', - agent_strategy: str = 'tool-calling', - user_strategy: str = 'llm', - system_instruction_name: str = 'system_instruction', - max_concurrency: int = 4, - rater: rater_lib.Rater | None = None, - log_dir: str | None = None, - ): - """Initializes the TauBenchAdapter. - - Args: - env_name: environment - agent_model: The model to use for the agent. - agent_model_provider: The provider for the agent model. - user_model: The model to use for simulating the user. - user_model_provider: The provider for the user model. - agent_strategy: The agent strategy to use (e.g., 'tool-calling'). - user_strategy: The user simulation strategy (e.g., 'llm'). - system_instruction_name: The key in the candidate dictionary that holds - the system instruction. - max_concurrency: The maximum number of tasks to run in parallel. - rater: An optional rater to evaluate the agent's performance. - log_dir: The directory to save traces and other logs. - """ - self._env_name = env_name - self._agent_model = agent_model - self._agent_model_provider = agent_model_provider - self._user_model = user_model - self._user_model_provider = user_model_provider - self._agent_strategy = agent_strategy - self._user_strategy = user_strategy - self._max_concurrency = max_concurrency - self._system_instruction_name = system_instruction_name - self._rater = rater - self._log_dir = log_dir - - def evaluate( - self, - batch: list[TauBenchDataInst], - candidate: dict[str, str], - capture_traces: bool = False, - ) -> EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput]: - """Evaluates a candidate prompt on a batch of tau-bench tasks. - - This method is called by GEPA during the optimization loop. It takes a - candidate prompt, runs it against the specified tasks from tau-bench, and - returns the results. - - Args: - batch: A list of task instances to evaluate on. Each instance specifies - the environment and task ID. - candidate: A dictionary containing the components to be evaluated, - including the system instruction. - capture_traces: (Not used in this adapter) Whether to capture detailed - traces. - - Returns: - An EvaluationBatch object containing scores, outputs, and trajectories for - each task in the batch. - """ - del capture_traces # Not used. - env = batch[0]['env'] - task_ids = [inst['task_id'] for inst in batch] - tau_bench_run_config = RunConfig( - env=env, - model=self._agent_model, - model_provider=self._agent_model_provider, - user_model=self._user_model, - user_model_provider=self._user_model_provider, - agent_strategy=self._agent_strategy, - user_strategy=self._user_strategy, - max_concurrency=self._max_concurrency, - task_ids=task_ids, - log_dir=self._log_dir, - task_split=batch[0]['task_split'], - ) - tau_bench_results = run_tau_bench_rollouts( - tau_bench_run_config, - system_instruction=candidate.get(self._system_instruction_name), - rater=self._rater, - ) - - outputs = [] - trajectories = [] - scores = [] - for res in tau_bench_results: - outputs.append( - TauBenchRolloutOutput( - env=env, - task_id=res.task_id, - reward=res.reward, - task_info=res.info, - ) - ) - result_traj = res.traj - trajectories.append(TauBenchTrajectory(result_traj=result_traj)) - scores.append(res.reward) - - return EvaluationBatch( - scores=scores, outputs=outputs, trajectories=trajectories - ) - - def make_reflective_dataset( - self, - candidate: dict[str, str], - eval_batch: EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput], - components_to_update: list[str], - ) -> dict[str, list[dict[str, Any]]]: - """Creates a dataset for reflection based on evaluation results. - - This method transforms the trajectories and scores from an evaluation run - into a structured format that a reflection model can use to generate - suggestions for improving the prompt. - - Args: - candidate: The candidate that was evaluated. - eval_batch: The results of the evaluation. - components_to_update: A list of component names that the reflection should - focus on improving. + """A GEPA adapter for evaluating agent performance on tau-bench benchmark.""" + + def __init__( + self, + env_name: str, + agent_model: str = "gemini-2.5-flash", + agent_model_provider: str = "vertex_ai", + user_model: str = "gemini-2.5-pro", + user_model_provider: str = "vertex_ai", + agent_strategy: str = "tool-calling", + user_strategy: str = "llm", + system_instruction_name: str = "system_instruction", + max_concurrency: int = 4, + rater: rater_lib.Rater | None = None, + log_dir: str | None = None, + ): + """Initializes the TauBenchAdapter. + + Args: + env_name: environment + agent_model: The model to use for the agent. + agent_model_provider: The provider for the agent model. + user_model: The model to use for simulating the user. + user_model_provider: The provider for the user model. + agent_strategy: The agent strategy to use (e.g., 'tool-calling'). + user_strategy: The user simulation strategy (e.g., 'llm'). + system_instruction_name: The key in the candidate dictionary that holds + the system instruction. + max_concurrency: The maximum number of tasks to run in parallel. + rater: An optional rater to evaluate the agent's performance. + log_dir: The directory to save traces and other logs. + """ + self._env_name = env_name + self._agent_model = agent_model + self._agent_model_provider = agent_model_provider + self._user_model = user_model + self._user_model_provider = user_model_provider + self._agent_strategy = agent_strategy + self._user_strategy = user_strategy + self._max_concurrency = max_concurrency + self._system_instruction_name = system_instruction_name + self._rater = rater + self._log_dir = log_dir + + def evaluate( + self, + batch: list[TauBenchDataInst], + candidate: dict[str, str], + capture_traces: bool = False, + ) -> EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput]: + """Evaluates a candidate prompt on a batch of tau-bench tasks. + + This method is called by GEPA during the optimization loop. It takes a + candidate prompt, runs it against the specified tasks from tau-bench, and + returns the results. + + Args: + batch: A list of task instances to evaluate on. Each instance specifies + the environment and task ID. + candidate: A dictionary containing the components to be evaluated, + including the system instruction. + capture_traces: (Not used in this adapter) Whether to capture detailed + traces. + + Returns: + An EvaluationBatch object containing scores, outputs, and trajectories for + each task in the batch. + """ + del capture_traces # Not used. + env = batch[0]["env"] + task_ids = [inst["task_id"] for inst in batch] + tau_bench_run_config = RunConfig( + env=env, + model=self._agent_model, + model_provider=self._agent_model_provider, + user_model=self._user_model, + user_model_provider=self._user_model_provider, + agent_strategy=self._agent_strategy, + user_strategy=self._user_strategy, + max_concurrency=self._max_concurrency, + task_ids=task_ids, + log_dir=self._log_dir, + task_split=batch[0]["task_split"], + ) + tau_bench_results = run_tau_bench_rollouts( + tau_bench_run_config, + system_instruction=candidate.get(self._system_instruction_name), + rater=self._rater, + ) - Returns: - A dictionary where keys are component names and values are lists of - data instances for reflection. - """ - system_instruction = candidate[self._system_instruction_name] + outputs = [] + trajectories = [] + scores = [] + for res in tau_bench_results: + outputs.append( + TauBenchRolloutOutput( + env=env, + task_id=res.task_id, + reward=res.reward, + task_info=res.info, + ) + ) + result_traj = res.traj + trajectories.append(TauBenchTrajectory(result_traj=result_traj)) + scores.append(res.reward) + + return EvaluationBatch( + scores=scores, outputs=outputs, trajectories=trajectories + ) - env = get_env( - self._env_name, - user_strategy=self._user_strategy, - user_model=self._user_model, - user_provider=self._user_model_provider, - task_split='train', - ) + def make_reflective_dataset( + self, + candidate: dict[str, str], + eval_batch: EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput], + components_to_update: list[str], + ) -> dict[str, list[dict[str, Any]]]: + """Creates a dataset for reflection based on evaluation results. + + This method transforms the trajectories and scores from an evaluation run + into a structured format that a reflection model can use to generate + suggestions for improving the prompt. + + Args: + candidate: The candidate that was evaluated. + eval_batch: The results of the evaluation. + components_to_update: A list of component names that the reflection should + focus on improving. + + Returns: + A dictionary where keys are component names and values are lists of + data instances for reflection. + """ + system_instruction = candidate[self._system_instruction_name] + + env = get_env( + self._env_name, + user_strategy=self._user_strategy, + user_model=self._user_model, + user_provider=self._user_model_provider, + task_split="train", + ) - tool_definitions = json.dumps( - env.tools_info, - indent=2, - default=str, - ) - inputs = '\n\n'.join([ - f'# System Instruction\n{system_instruction}', - f'# Tool Definitions\n{tool_definitions}', - ]) - ret_d: dict[str, list[dict[str, Any]]] = {} - for comp in components_to_update: - items: list[dict[str, Any]] = [] - trace_instances = list( - zip( - eval_batch.trajectories, - eval_batch.scores, - eval_batch.outputs, - strict=True, - ) - ) - for trace_instance in trace_instances: - traj, _, rollout = trace_instance - messages = traj['result_traj'] - # Remove instructions. - if len(messages) > 1: - messages = messages[1:] - d = { - 'Inputs': inputs, - 'Generated Outputs': json.dumps(messages, indent=2, default=str), - 'Feedback': json.dumps( - rollout['task_info']['feedback'], indent=2, default=str - ), - } - items.append(d) - if items: - ret_d[comp] = items - assert ret_d, ( - 'empty reflective dataset for components ' - f'{[comp for comp in components_to_update]}' - ) - return ret_d + tool_definitions = json.dumps( + env.tools_info, + indent=2, + default=str, + ) + inputs = "\n\n".join( + [ + f"# System Instruction\n{system_instruction}", + f"# Tool Definitions\n{tool_definitions}", + ] + ) + ret_d: dict[str, list[dict[str, Any]]] = {} + for comp in components_to_update: + items: list[dict[str, Any]] = [] + trace_instances = list( + zip( + eval_batch.trajectories, + eval_batch.scores, + eval_batch.outputs, + strict=True, + ) + ) + for trace_instance in trace_instances: + traj, _, rollout = trace_instance + messages = traj["result_traj"] + # Remove instructions. + if len(messages) > 1: + messages = messages[1:] + d = { + "Inputs": inputs, + "Generated Outputs": json.dumps(messages, indent=2, default=str), + "Feedback": json.dumps( + rollout["task_info"]["feedback"], indent=2, default=str + ), + } + items.append(d) + if items: + ret_d[comp] = items + assert ret_d, ( + "empty reflective dataset for components " + f"{[comp for comp in components_to_update]}" + ) + return ret_d _DATASET_SPLITS = { - 'train': tasks_train.TASKS_TRAIN, - 'dev': tasks_dev.TASKS_DEV, - 'test': tasks_test.TASKS_TEST, + "train": tasks_train.TASKS_TRAIN, + "dev": tasks_dev.TASKS_DEV, + "test": tasks_test.TASKS_TEST, } def _get_dataset(ds: Dataset) -> list[TauBenchDataInst]: - task_ids = ds.indexes or list(range(len(_DATASET_SPLITS[ds.split]))) - if ds.max_size is not None: - task_ids = task_ids[: ds.max_size] - random.shuffle(task_ids) - return task_ids + task_ids = ds.indexes or list(range(len(_DATASET_SPLITS[ds.split]))) + if ds.max_size is not None: + task_ids = task_ids[: ds.max_size] + random.shuffle(task_ids) + return task_ids def _get_datasets( config: ExperimentConfig, ) -> dict[str, list[int]]: - """Returns Tau-bench dataset splits.""" - random.seed(config.rnd_seed) - train_task_ids = _get_dataset(config.feedback_dataset) - eval_task_ids = _get_dataset(config.pareto_dataset) - test_task_ids = _get_dataset(config.eval_dataset) - logging.info( - 'Using datasets of size: train=%d, eval=%d, test=%d', - len(train_task_ids), - len(eval_task_ids), - len(test_task_ids), - ) - return dict( - train=train_task_ids, - dev=eval_task_ids, - test=test_task_ids, - ) + """Returns Tau-bench dataset splits.""" + random.seed(config.rnd_seed) + train_task_ids = _get_dataset(config.feedback_dataset) + eval_task_ids = _get_dataset(config.pareto_dataset) + test_task_ids = _get_dataset(config.eval_dataset) + logging.info( + "Using datasets of size: train=%d, eval=%d, test=%d", + len(train_task_ids), + len(eval_task_ids), + len(test_task_ids), + ) + return dict( + train=train_task_ids, + dev=eval_task_ids, + test=test_task_ids, + ) SEED_SYSTEM_INSTRUCTION = ( - 'you are a customer support agent helping customers resolve their ' - 'issues by using the right tools' + "you are a customer support agent helping customers resolve their " + "issues by using the right tools" ) @dataclasses.dataclass(frozen=True) class Dataset: - split: str - indexes: list[int] | None = None - max_size: int = None + split: str + indexes: list[int] | None = None + max_size: int = None @dataclasses.dataclass class ExperimentConfig: - """Configures a GEPA experiment on Tau-bench.""" - - tau_bench_env: str - agent_model: str - agent_model_provider: str - user_model: str - user_model_provider: str - max_concurrency: int - num_eval_trials: int - rnd_seed: int - max_metric_calls: int - reflection_model: str - reflection_minibatch_size: int - use_rater: bool - feedback_dataset: Dataset - pareto_dataset: Dataset - eval_dataset: Dataset + """Configures a GEPA experiment on Tau-bench.""" + + tau_bench_env: str + agent_model: str + agent_model_provider: str + user_model: str + user_model_provider: str + max_concurrency: int + num_eval_trials: int + rnd_seed: int + max_metric_calls: int + reflection_model: str + reflection_minibatch_size: int + use_rater: bool + feedback_dataset: Dataset + pareto_dataset: Dataset + eval_dataset: Dataset def _rater(config: ExperimentConfig) -> rater_lib.Rater: - env = get_env( - config.tau_bench_env, - user_strategy='llm', - user_model=config.user_model, - user_provider=config.user_model_provider, - task_split='train', - ) - return rater_lib.Rater(json.dumps(env.tools_info, indent=2)) - - -def run_gepa( - output_dir: str, seed_instructions: str, config: ExperimentConfig -) -> Any: - """Runs the GEPA optimization loop to train a new system instruction. - - Args: - output_dir: The directory to save experiment results and artifacts. - seed_instructions: Agent instructions to initialize the agent with. - config: The experiment configuration. - - Returns: - The results of the GEPA optimization. - """ - # This section sets up and runs the GEPA optimization experiment. - # Here we define all the parameters for the tau-bench environment, the GEPA - # optimization loop, and the models to be used. - datasets = _get_datasets(config) - training_set = [ - TauBenchDataInst( - env=config.tau_bench_env, - task_id=task_id, - task_split=config.feedback_dataset.split, - ) - for task_id in datasets['train'] - ] - eval_set = [ - TauBenchDataInst( - env=config.tau_bench_env, - task_id=task_id, - task_split=config.pareto_dataset.split, - ) - for task_id in datasets['dev'] - ] - system_instruction_name = 'system_instruction' - - tau_bench_adapter = TauBenchAdapter( - env_name=config.tau_bench_env, - agent_model=config.agent_model, - agent_model_provider=config.agent_model_provider, - user_model=config.user_model, - user_model_provider=config.user_model_provider, - agent_strategy='tool-calling', - user_strategy='llm', - system_instruction_name=system_instruction_name, - max_concurrency=config.max_concurrency, - rater=_rater(config) if config.use_rater else None, - log_dir=os.path.join(output_dir, 'traces'), - ) - - gepa_results = gepa.optimize( - seed_candidate={ - system_instruction_name: seed_instructions, - }, - trainset=training_set, - valset=eval_set, - task_lm=None, # this must be None when a custom adapter is used - adapter=tau_bench_adapter, - max_metric_calls=config.max_metric_calls, - reflection_lm=utils.reflection_inference_fn(config.reflection_model), - reflection_minibatch_size=config.reflection_minibatch_size, - run_dir=output_dir, - ) - json.dump( - gepa_results.to_dict(), - open(os.path.join(output_dir, 'results.json'), 'w'), - ) - return gepa_results + env = get_env( + config.tau_bench_env, + user_strategy="llm", + user_model=config.user_model, + user_provider=config.user_model_provider, + task_split="train", + ) + return rater_lib.Rater(json.dumps(env.tools_info, indent=2)) + + +def run_gepa(output_dir: str, seed_instructions: str, config: ExperimentConfig) -> Any: + """Runs the GEPA optimization loop to train a new system instruction. + + Args: + output_dir: The directory to save experiment results and artifacts. + seed_instructions: Agent instructions to initialize the agent with. + config: The experiment configuration. + + Returns: + The results of the GEPA optimization. + """ + # This section sets up and runs the GEPA optimization experiment. + # Here we define all the parameters for the tau-bench environment, the GEPA + # optimization loop, and the models to be used. + datasets = _get_datasets(config) + training_set = [ + TauBenchDataInst( + env=config.tau_bench_env, + task_id=task_id, + task_split=config.feedback_dataset.split, + ) + for task_id in datasets["train"] + ] + eval_set = [ + TauBenchDataInst( + env=config.tau_bench_env, + task_id=task_id, + task_split=config.pareto_dataset.split, + ) + for task_id in datasets["dev"] + ] + system_instruction_name = "system_instruction" + + tau_bench_adapter = TauBenchAdapter( + env_name=config.tau_bench_env, + agent_model=config.agent_model, + agent_model_provider=config.agent_model_provider, + user_model=config.user_model, + user_model_provider=config.user_model_provider, + agent_strategy="tool-calling", + user_strategy="llm", + system_instruction_name=system_instruction_name, + max_concurrency=config.max_concurrency, + rater=_rater(config) if config.use_rater else None, + log_dir=os.path.join(output_dir, "traces"), + ) + + gepa_results = gepa.optimize( + seed_candidate={ + system_instruction_name: seed_instructions, + }, + trainset=training_set, + valset=eval_set, + task_lm=None, # this must be None when a custom adapter is used + adapter=tau_bench_adapter, + max_metric_calls=config.max_metric_calls, + reflection_lm=utils.reflection_inference_fn(config.reflection_model), + reflection_minibatch_size=config.reflection_minibatch_size, + run_dir=output_dir, + ) + json.dump( + gepa_results.to_dict(), + open(os.path.join(output_dir, "results.json"), "w"), + ) + return gepa_results def run_eval(output_dir: str, instructions: str, config: ExperimentConfig): - """Runs evaluation on the test set using the given instructions. - - Args: - output_dir: The directory to save evaluation results. - instructions: The system instructions to evaluate. - config: The experiment configuration. - """ - eval_dataset = _get_dataset(config.eval_dataset) - tau_bench_run_config = RunConfig( - env=config.tau_bench_env, - model=config.agent_model, - model_provider=config.agent_model_provider, - user_model=config.user_model, - user_model_provider=config.user_model_provider, - agent_strategy='tool-calling', - user_strategy='llm', - max_concurrency=config.max_concurrency, - num_trials=config.num_eval_trials, - task_ids=eval_dataset, - log_dir=output_dir, - task_split=config.eval_dataset.split, - ) - with open(os.path.join(output_dir, 'prompt.txt'), 'w') as f: - f.write(instructions) - - json.dump( - tau_bench_run_config.model_dump(), - open(os.path.join(output_dir, 'run_config.json'), 'w'), - ) - tau_bench_results = run_tau_bench_rollouts( - tau_bench_run_config, - system_instruction=instructions, - rater=_rater(config) if config.use_rater else None, - ) - total = len(tau_bench_results) - numerator = sum(1 for res in tau_bench_results if res.reward == 1) - print( - f'average reward (total={total}): {numerator/total if total > 0 else 0}' - ) - json.dump( - dict(results=[r.model_dump() for r in tau_bench_results]), - open(os.path.join(output_dir, 'results.json'), 'w'), - ) + """Runs evaluation on the test set using the given instructions. + + Args: + output_dir: The directory to save evaluation results. + instructions: The system instructions to evaluate. + config: The experiment configuration. + """ + eval_dataset = _get_dataset(config.eval_dataset) + tau_bench_run_config = RunConfig( + env=config.tau_bench_env, + model=config.agent_model, + model_provider=config.agent_model_provider, + user_model=config.user_model, + user_model_provider=config.user_model_provider, + agent_strategy="tool-calling", + user_strategy="llm", + max_concurrency=config.max_concurrency, + num_trials=config.num_eval_trials, + task_ids=eval_dataset, + log_dir=output_dir, + task_split=config.eval_dataset.split, + ) + with open(os.path.join(output_dir, "prompt.txt"), "w") as f: + f.write(instructions) + + json.dump( + tau_bench_run_config.model_dump(), + open(os.path.join(output_dir, "run_config.json"), "w"), + ) + tau_bench_results = run_tau_bench_rollouts( + tau_bench_run_config, + system_instruction=instructions, + rater=_rater(config) if config.use_rater else None, + ) + total = len(tau_bench_results) + numerator = sum(1 for res in tau_bench_results if res.reward == 1) + print(f"average reward (total={total}): {numerator/total if total > 0 else 0}") + json.dump( + dict(results=[r.model_dump() for r in tau_bench_results]), + open(os.path.join(output_dir, "results.json"), "w"), + ) diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index 1bc4ee58c8..642a7e9bd3 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,143 +25,137 @@ from absl import flags import experiment from google.genai import types + import utils _OUTPUT_DIR = flags.DEFINE_string( - 'output_dir', + "output_dir", None, - 'Directory to save experiment results and artifacts.', + "Directory to save experiment results and artifacts.", required=True, ) _EVAL_SET_SIZE = flags.DEFINE_integer( - 'eval_set_size', + "eval_set_size", None, - 'Size of the dev set to use for Pareto frontier evaluation in GEPA. If' - ' None, uses all available dev tasks. A few tens of examples might' - ' suffice more simpler tasks and up to a few hundreds for ' - ' more complex and variable tasks. Increase the size to mitigate effect of' - ' variability at greater cost.', + "Size of the dev set to use for Pareto frontier evaluation in GEPA. If" + " None, uses all available dev tasks. A few tens of examples might" + " suffice more simpler tasks and up to a few hundreds for " + " more complex and variable tasks. Increase the size to mitigate effect of" + " variability at greater cost.", ) _MAX_METRIC_CALLS = flags.DEFINE_integer( - 'max_metric_calls', + "max_metric_calls", 500, - 'Total budget for GEPA prompt evaluations. This is the main control for' - ' runtime/cost. One could start with 100 and increase to 500+ for further' - ' optimization.', + "Total budget for GEPA prompt evaluations. This is the main control for" + " runtime/cost. One could start with 100 and increase to 500+ for further" + " optimization.", ) _NUM_TEST_RECORDS = flags.DEFINE_integer( - 'num_test_records', + "num_test_records", None, - 'Size of the test set for final evaluation of the optimized prompt. If' - ' None, uses all available test tasks.', + "Size of the test set for final evaluation of the optimized prompt. If" + " None, uses all available test tasks.", ) _NUM_EVAL_TRIALS = flags.DEFINE_integer( - 'num_eval_trials', + "num_eval_trials", 4, - 'Number of times each task is run during evaluation. Higher values give' - ' more stable evaluation metrics but increase runtime. Recommended: 4-8.', + "Number of times each task is run during evaluation. Higher values give" + " more stable evaluation metrics but increase runtime. Recommended: 4-8.", ) _MAX_CONCURRENCY = flags.DEFINE_integer( - 'max_concurrency', + "max_concurrency", 8, - 'Maximum number of parallel agent-environment interactions. Increase if' - ' you have sufficient API quota.', + "Maximum number of parallel agent-environment interactions. Increase if" + " you have sufficient API quota.", ) _EVAL_MODE = flags.DEFINE_bool( - 'eval_mode', + "eval_mode", False, - 'If set, run evaluation only using the seed prompt, skipping GEPA' - ' optimization.', + "If set, run evaluation only using the seed prompt, skipping GEPA" " optimization.", ) _USE_RATER = flags.DEFINE_bool( - 'use_rater', + "use_rater", False, - 'If set, use an LLM rater to score trajectories.', + "If set, use an LLM rater to score trajectories.", ) _TRAIN_BATCH_SIZE = flags.DEFINE_integer( - 'train_batch_size', + "train_batch_size", 3, - 'Number of trajectories sampled from rollouts to be used by the' - ' reflection model in each GEPA step to generate prompt improvements.' - ' Increasing the batch size may help provide a more stable signal and' - ' estimate of a prompt quality but entails higher cost. One can start with' - ' a low value and increase the size if significant variations are' - ' observed.', + "Number of trajectories sampled from rollouts to be used by the" + " reflection model in each GEPA step to generate prompt improvements." + " Increasing the batch size may help provide a more stable signal and" + " estimate of a prompt quality but entails higher cost. One can start with" + " a low value and increase the size if significant variations are" + " observed.", ) def main(argv: Sequence[str]) -> None: - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + # Get a list of all existing loggers + # logging.root.manager.loggerDict contains all named loggers + # logging.getLogger(name) retrieves the logger object + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] - # Get a list of all existing loggers - # logging.root.manager.loggerDict contains all named loggers - # logging.getLogger(name) retrieves the logger object - loggers = [ - logging.getLogger(name) for name in logging.root.manager.loggerDict - ] + # Iterate through the loggers and set their level to WARNING + for logger in loggers: + logger.setLevel(logging.WARNING) - # Iterate through the loggers and set their level to WARNING - for logger in loggers: - logger.setLevel(logging.WARNING) + types.logger.addFilter(utils.FilterInferenceWarnings()) + output_dir = os.path.join( + _OUTPUT_DIR.value, datetime.now().strftime("%Y%m%d%H%M%S%f") + ) + os.makedirs(output_dir) + logging.info("Writing to output_dir=%s", output_dir) + config = experiment.ExperimentConfig( + tau_bench_env="retail", + agent_model="gemini-2.5-flash", + agent_model_provider="vertex_ai", + user_model="gemini-2.5-flash", + user_model_provider="vertex_ai", + max_concurrency=_MAX_CONCURRENCY.value, + num_eval_trials=_NUM_EVAL_TRIALS.value, + rnd_seed=42, + max_metric_calls=_MAX_METRIC_CALLS.value, + reflection_model="gemini-2.5-pro", + reflection_minibatch_size=_TRAIN_BATCH_SIZE.value, + use_rater=_USE_RATER.value, + feedback_dataset=experiment.Dataset(split="train"), + pareto_dataset=experiment.Dataset(split="dev", max_size=_EVAL_SET_SIZE.value), + eval_dataset=experiment.Dataset(split="test", max_size=_NUM_TEST_RECORDS.value), + ) + json.dump( + dataclasses.asdict(config), + open(os.path.join(output_dir, "config.json"), "w"), + ) + logging.info("Using config=%s", config) - types.logger.addFilter(utils.FilterInferenceWarnings()) - output_dir = os.path.join( - _OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f') - ) - os.makedirs(output_dir) - logging.info('Writing to output_dir=%s', output_dir) - config = experiment.ExperimentConfig( - tau_bench_env='retail', - agent_model='gemini-2.5-flash', - agent_model_provider='vertex_ai', - user_model='gemini-2.5-flash', - user_model_provider='vertex_ai', - max_concurrency=_MAX_CONCURRENCY.value, - num_eval_trials=_NUM_EVAL_TRIALS.value, - rnd_seed=42, - max_metric_calls=_MAX_METRIC_CALLS.value, - reflection_model='gemini-2.5-pro', - reflection_minibatch_size=_TRAIN_BATCH_SIZE.value, - use_rater=_USE_RATER.value, - feedback_dataset=experiment.Dataset(split='train'), - pareto_dataset=experiment.Dataset( - split='dev', max_size=_EVAL_SET_SIZE.value - ), - eval_dataset=experiment.Dataset( - split='test', max_size=_NUM_TEST_RECORDS.value - ), - ) - json.dump( - dataclasses.asdict(config), - open(os.path.join(output_dir, 'config.json'), 'w'), - ) - logging.info('Using config=%s', config) + if _EVAL_MODE.value: + return experiment.run_eval( + output_dir=output_dir, + instructions=experiment.SEED_SYSTEM_INSTRUCTION, + config=config, + ) - if _EVAL_MODE.value: - return experiment.run_eval( - output_dir=output_dir, - instructions=experiment.SEED_SYSTEM_INSTRUCTION, + results = experiment.run_gepa( config=config, + seed_instructions=experiment.SEED_SYSTEM_INSTRUCTION, + output_dir=output_dir, ) + print(list(enumerate(results.val_aggregate_scores))) - results = experiment.run_gepa( - config=config, - seed_instructions=experiment.SEED_SYSTEM_INSTRUCTION, - output_dir=output_dir, - ) - print(list(enumerate(results.val_aggregate_scores))) - - eval_dir = os.path.join( - output_dir, 'evals', datetime.now().strftime('%Y%m%d%H%M%S%f') - ) - os.makedirs(eval_dir) - experiment.run_eval( - output_dir=eval_dir, - instructions=results.best_candidate['system_instruction'], - config=config, - ) + eval_dir = os.path.join( + output_dir, "evals", datetime.now().strftime("%Y%m%d%H%M%S%f") + ) + os.makedirs(eval_dir) + experiment.run_eval( + output_dir=eval_dir, + instructions=results.best_candidate["system_instruction"], + config=config, + ) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index 9747c1043c..81da84a022 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -15,7 +15,7 @@ """Manages context cache lifecycle for Gemini models.""" from __future__ import annotations - +import contextlib import hashlib import json import logging @@ -362,18 +362,16 @@ async def _create_gemini_cache_with_optional_tracing( Cache metadata with precise creation timestamp """ + span_context = contextlib.nullcontext() if not self.disable_telemetry: from ..telemetry.tracing import tracer + span_context = tracer.start_as_current_span("create_cache") - with tracer.start_as_current_span("create_cache") as span: - return await self._create_gemini_cache_body( - llm_request=llm_request, - cache_contents_count=cache_contents_count, - span=span, - ) - else: + with span_context as span: return await self._create_gemini_cache_body( - llm_request=llm_request, cache_contents_count=cache_contents_count + llm_request=llm_request, + cache_contents_count=cache_contents_count, + span=span, ) async def _create_gemini_cache_body( From 6a3a388347d75bc0510d01e419f8c949f6a5d149 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 14:47:54 +0000 Subject: [PATCH 13/24] chore(samples): revert GEPA formatting Restore GEPA sample files to upstream main to avoid unrelated formatting changes. --- contributing/samples/gepa/experiment.py | 1050 +++++++++---------- contributing/samples/gepa/run_experiment.py | 189 ++-- 2 files changed, 623 insertions(+), 616 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 8d05507ed9..2f5d03a772 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -53,204 +53,202 @@ def run_tau_bench_rollouts( system_instruction: str | None = None, rater: rater_lib.Rater | None = None, ) -> list[EnvRunResult]: - """Runs a set of tau-bench tasks with a given agent configuration. - - This is a customized version of the standard tau-bench run function, adapted - for this experiment's needs. It handles environment setup, agent creation, - task execution in parallel, and result aggregation. - - Args: - config: A RunConfig object specifying the environment, models, and other - parameters for the run. - print_results: If True, prints the result of each task as it completes. - system_instruction: An optional system instruction to use for the agent, - overriding the default. - rater: An optional rater to evaluate the agent's performance. - - Returns: - A list of EnvRunResult objects, one for each completed task. - """ - if config.env not in ["retail", "airline"]: - raise ValueError("Only retail and airline envs are supported") - if config.model_provider not in provider_list: - raise ValueError("Invalid model provider") - if config.user_model_provider not in provider_list: - raise ValueError("Invalid user model provider") - if config.agent_strategy not in ["tool-calling", "act", "react", "few-shot"]: - raise ValueError("Invalid agent strategy") - if config.task_split not in ["train", "test", "dev"]: - raise ValueError("Invalid task split") - if config.user_strategy not in [item.value for item in UserStrategy]: - raise ValueError("Invalid user strategy") - - random.seed(config.seed) - time_str = datetime.now().strftime("%m%d%H%M%S") - model_name = config.model.split("/")[-1] - ckpt_filename = ( - f"{config.agent_strategy}-{model_name}-{config.temperature}_range_" - f"{config.start_index}-{config.end_index}_user-{config.user_model}-" - f"{config.user_strategy}_{time_str}.json" - ) - ckpt_path = os.path.join(config.log_dir, ckpt_filename) - if not os.path.exists(config.log_dir): - os.makedirs(config.log_dir) - - print(f"Loading user with strategy: {config.user_strategy}") - env = get_env( - config.env, - user_strategy=config.user_strategy, - user_model=config.user_model, - user_provider=config.user_model_provider, - task_split=config.task_split, + """Runs a set of tau-bench tasks with a given agent configuration. + + This is a customized version of the standard tau-bench run function, adapted + for this experiment's needs. It handles environment setup, agent creation, + task execution in parallel, and result aggregation. + + Args: + config: A RunConfig object specifying the environment, models, and other + parameters for the run. + print_results: If True, prints the result of each task as it completes. + system_instruction: An optional system instruction to use for the agent, + overriding the default. + rater: An optional rater to evaluate the agent's performance. + + Returns: + A list of EnvRunResult objects, one for each completed task. + """ + if config.env not in ['retail', 'airline']: + raise ValueError('Only retail and airline envs are supported') + if config.model_provider not in provider_list: + raise ValueError('Invalid model provider') + if config.user_model_provider not in provider_list: + raise ValueError('Invalid user model provider') + if config.agent_strategy not in ['tool-calling', 'act', 'react', 'few-shot']: + raise ValueError('Invalid agent strategy') + if config.task_split not in ['train', 'test', 'dev']: + raise ValueError('Invalid task split') + if config.user_strategy not in [item.value for item in UserStrategy]: + raise ValueError('Invalid user strategy') + + random.seed(config.seed) + time_str = datetime.now().strftime('%m%d%H%M%S') + model_name = config.model.split('/')[-1] + ckpt_filename = ( + f'{config.agent_strategy}-{model_name}-{config.temperature}_range_' + f'{config.start_index}-{config.end_index}_user-{config.user_model}-' + f'{config.user_strategy}_{time_str}.json' + ) + ckpt_path = os.path.join(config.log_dir, ckpt_filename) + if not os.path.exists(config.log_dir): + os.makedirs(config.log_dir) + + print(f'Loading user with strategy: {config.user_strategy}') + env = get_env( + config.env, + user_strategy=config.user_strategy, + user_model=config.user_model, + user_provider=config.user_model_provider, + task_split=config.task_split, + ) + if system_instruction: + env.wiki = system_instruction + agent = tau_bench_agent_lib.adk_agent_factory( + tools_info=env.tools_info, + wiki=env.wiki, + config=config, + ) + if config.end_index == -1: + end_index = len(env.tasks) + else: + end_index = min(config.end_index, len(env.tasks)) + results: list[EnvRunResult] = [] + lock = multiprocessing.Lock() + if config.task_ids: + print(f'Running tasks {config.task_ids} (checkpoint path: {ckpt_path})') + else: + print( + f'Running tasks {config.start_index} to {end_index} ' + f'(checkpoint path: {ckpt_path})' ) - if system_instruction: - env.wiki = system_instruction - agent = tau_bench_agent_lib.adk_agent_factory( - tools_info=env.tools_info, - wiki=env.wiki, - config=config, - ) - if config.end_index == -1: - end_index = len(env.tasks) - else: - end_index = min(config.end_index, len(env.tasks)) - results: list[EnvRunResult] = [] - lock = multiprocessing.Lock() + for i in range(config.num_trials): if config.task_ids: - print(f"Running tasks {config.task_ids} (checkpoint path: {ckpt_path})") + idxs = config.task_ids else: + idxs = list(range(config.start_index, end_index)) + if config.shuffle: + random.shuffle(idxs) + + @retry(tries=3, delay=10, backoff=2) + def _run_with_retry(idx: int) -> EnvRunResult: + isolated_env = get_env( + config.env, + user_strategy=config.user_strategy, + user_model=config.user_model, + task_split=config.task_split, + user_provider=config.user_model_provider, + task_index=idx, + ) + if print_results: + print(f'Running task {idx}') + res = agent.solve( + env=isolated_env, + task_index=idx, + ) + + rating = ( + rater(res.messages[1:] if len(res.messages) > 1 else res.messages) + if rater + else None + ) + info = dict(res.info) + info['metrics'] = dict(rating=rating, reward=res.reward) + + if rater: + score = rating['score'] + feedback = {k: v for k, v in rating.items() if k != 'score'} + else: + score = res.reward + feedback = ( + 'The agent successfully resolved all customer issues' + if score > 0 + else 'The agent failed to resolve all customer issues correctly' + ) + + info['feedback'] = feedback + return EnvRunResult( + task_id=idx, + reward=score, + info=info, + traj=res.messages, + trial=i, + ) + + def _run(idx: int) -> EnvRunResult: + try: + result = _run_with_retry(idx) + except Exception as e: + logging.warning('Inference error: %s', str(e)) + result = EnvRunResult( + task_id=idx, + reward=0.0, + info={ + 'error': str(e), + 'traceback': traceback.format_exc(), + 'metrics': dict(reward=0.0), + }, + traj=[], + trial=i, + ) + + if print_results: print( - f"Running tasks {config.start_index} to {end_index} " - f"(checkpoint path: {ckpt_path})" + '✅' if result.reward == 1 else '❌', + f'task_id={idx}', ) - for i in range(config.num_trials): - if config.task_ids: - idxs = config.task_ids - else: - idxs = list(range(config.start_index, end_index)) - if config.shuffle: - random.shuffle(idxs) - - @retry(tries=3, delay=10, backoff=2) - def _run_with_retry(idx: int) -> EnvRunResult: - isolated_env = get_env( - config.env, - user_strategy=config.user_strategy, - user_model=config.user_model, - task_split=config.task_split, - user_provider=config.user_model_provider, - task_index=idx, - ) - if print_results: - print(f"Running task {idx}") - res = agent.solve( - env=isolated_env, - task_index=idx, - ) - - rating = ( - rater(res.messages[1:] if len(res.messages) > 1 else res.messages) - if rater - else None - ) - info = dict(res.info) - info["metrics"] = dict(rating=rating, reward=res.reward) - - if rater: - score = rating["score"] - feedback = {k: v for k, v in rating.items() if k != "score"} - else: - score = res.reward - feedback = ( - "The agent successfully resolved all customer issues" - if score > 0 - else "The agent failed to resolve all customer issues correctly" - ) - - info["feedback"] = feedback - return EnvRunResult( - task_id=idx, - reward=score, - info=info, - traj=res.messages, - trial=i, - ) - - def _run(idx: int) -> EnvRunResult: - try: - result = _run_with_retry(idx) - except Exception as e: - logging.warning("Inference error: %s", str(e)) - result = EnvRunResult( - task_id=idx, - reward=0.0, - info={ - "error": str(e), - "traceback": traceback.format_exc(), - "metrics": dict(reward=0.0), - }, - traj=[], - trial=i, - ) - - if print_results: - print( - "✅" if result.reward == 1 else "❌", - f"task_id={idx}", - ) - print("-----") - with lock: - data = [] - if os.path.exists(ckpt_path): - with open(ckpt_path, "r") as f: - data = json.load(f) - with open(ckpt_path, "w") as f: - json.dump(data + [result.model_dump()], f, indent=2) - return result - - with ThreadPoolExecutor(max_workers=config.max_concurrency) as executor: - res = list(executor.map(_run, idxs)) - results.extend(res) - - display_metrics(results) - - if rater: - print("Environment reward:") - display_metrics( - [ - EnvRunResult( - task_id=r.task_id, - reward=r.info["metrics"]["reward"], - info={}, - traj=[], - trial=r.trial, - ) - for r in results - ] + print('-----') + with lock: + data = [] + if os.path.exists(ckpt_path): + with open(ckpt_path, 'r') as f: + data = json.load(f) + with open(ckpt_path, 'w') as f: + json.dump(data + [result.model_dump()], f, indent=2) + return result + + with ThreadPoolExecutor(max_workers=config.max_concurrency) as executor: + res = list(executor.map(_run, idxs)) + results.extend(res) + + display_metrics(results) + + if rater: + print('Environment reward:') + display_metrics([ + EnvRunResult( + task_id=r.task_id, + reward=r.info['metrics']['reward'], + info={}, + traj=[], + trial=r.trial, ) + for r in results + ]) - with open(ckpt_path, "w") as f: - json.dump([result.model_dump() for result in results], f, indent=2) - print(f"\n📄 Results saved to {ckpt_path}\n") - return results + with open(ckpt_path, 'w') as f: + json.dump([result.model_dump() for result in results], f, indent=2) + print(f'\n📄 Results saved to {ckpt_path}\n') + return results class TauBenchDataInst(TypedDict): - env: str - task_id: int - task_split: str + env: str + task_id: int + task_split: str class TauBenchTrajectory(TypedDict): - result_traj: list[dict[str, Any]] + result_traj: list[dict[str, Any]] class TauBenchRolloutOutput(TypedDict): - env: str - task_id: int - reward: float - task_info: dict[str, Any] + env: str + task_id: int + reward: float + task_info: dict[str, Any] class TauBenchAdapter( @@ -260,381 +258,383 @@ class TauBenchAdapter( TauBenchRolloutOutput, ] ): - """A GEPA adapter for evaluating agent performance on tau-bench benchmark.""" - - def __init__( - self, - env_name: str, - agent_model: str = "gemini-2.5-flash", - agent_model_provider: str = "vertex_ai", - user_model: str = "gemini-2.5-pro", - user_model_provider: str = "vertex_ai", - agent_strategy: str = "tool-calling", - user_strategy: str = "llm", - system_instruction_name: str = "system_instruction", - max_concurrency: int = 4, - rater: rater_lib.Rater | None = None, - log_dir: str | None = None, - ): - """Initializes the TauBenchAdapter. - - Args: - env_name: environment - agent_model: The model to use for the agent. - agent_model_provider: The provider for the agent model. - user_model: The model to use for simulating the user. - user_model_provider: The provider for the user model. - agent_strategy: The agent strategy to use (e.g., 'tool-calling'). - user_strategy: The user simulation strategy (e.g., 'llm'). - system_instruction_name: The key in the candidate dictionary that holds - the system instruction. - max_concurrency: The maximum number of tasks to run in parallel. - rater: An optional rater to evaluate the agent's performance. - log_dir: The directory to save traces and other logs. - """ - self._env_name = env_name - self._agent_model = agent_model - self._agent_model_provider = agent_model_provider - self._user_model = user_model - self._user_model_provider = user_model_provider - self._agent_strategy = agent_strategy - self._user_strategy = user_strategy - self._max_concurrency = max_concurrency - self._system_instruction_name = system_instruction_name - self._rater = rater - self._log_dir = log_dir - - def evaluate( - self, - batch: list[TauBenchDataInst], - candidate: dict[str, str], - capture_traces: bool = False, - ) -> EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput]: - """Evaluates a candidate prompt on a batch of tau-bench tasks. - - This method is called by GEPA during the optimization loop. It takes a - candidate prompt, runs it against the specified tasks from tau-bench, and - returns the results. - - Args: - batch: A list of task instances to evaluate on. Each instance specifies - the environment and task ID. - candidate: A dictionary containing the components to be evaluated, - including the system instruction. - capture_traces: (Not used in this adapter) Whether to capture detailed - traces. - - Returns: - An EvaluationBatch object containing scores, outputs, and trajectories for - each task in the batch. - """ - del capture_traces # Not used. - env = batch[0]["env"] - task_ids = [inst["task_id"] for inst in batch] - tau_bench_run_config = RunConfig( - env=env, - model=self._agent_model, - model_provider=self._agent_model_provider, - user_model=self._user_model, - user_model_provider=self._user_model_provider, - agent_strategy=self._agent_strategy, - user_strategy=self._user_strategy, - max_concurrency=self._max_concurrency, - task_ids=task_ids, - log_dir=self._log_dir, - task_split=batch[0]["task_split"], - ) - tau_bench_results = run_tau_bench_rollouts( - tau_bench_run_config, - system_instruction=candidate.get(self._system_instruction_name), - rater=self._rater, - ) + """A GEPA adapter for evaluating agent performance on tau-bench benchmark.""" + + def __init__( + self, + env_name: str, + agent_model: str = 'gemini-2.5-flash', + agent_model_provider: str = 'vertex_ai', + user_model: str = 'gemini-2.5-pro', + user_model_provider: str = 'vertex_ai', + agent_strategy: str = 'tool-calling', + user_strategy: str = 'llm', + system_instruction_name: str = 'system_instruction', + max_concurrency: int = 4, + rater: rater_lib.Rater | None = None, + log_dir: str | None = None, + ): + """Initializes the TauBenchAdapter. - outputs = [] - trajectories = [] - scores = [] - for res in tau_bench_results: - outputs.append( - TauBenchRolloutOutput( - env=env, - task_id=res.task_id, - reward=res.reward, - task_info=res.info, - ) - ) - result_traj = res.traj - trajectories.append(TauBenchTrajectory(result_traj=result_traj)) - scores.append(res.reward) - - return EvaluationBatch( - scores=scores, outputs=outputs, trajectories=trajectories - ) + Args: + env_name: environment + agent_model: The model to use for the agent. + agent_model_provider: The provider for the agent model. + user_model: The model to use for simulating the user. + user_model_provider: The provider for the user model. + agent_strategy: The agent strategy to use (e.g., 'tool-calling'). + user_strategy: The user simulation strategy (e.g., 'llm'). + system_instruction_name: The key in the candidate dictionary that holds + the system instruction. + max_concurrency: The maximum number of tasks to run in parallel. + rater: An optional rater to evaluate the agent's performance. + log_dir: The directory to save traces and other logs. + """ + self._env_name = env_name + self._agent_model = agent_model + self._agent_model_provider = agent_model_provider + self._user_model = user_model + self._user_model_provider = user_model_provider + self._agent_strategy = agent_strategy + self._user_strategy = user_strategy + self._max_concurrency = max_concurrency + self._system_instruction_name = system_instruction_name + self._rater = rater + self._log_dir = log_dir + + def evaluate( + self, + batch: list[TauBenchDataInst], + candidate: dict[str, str], + capture_traces: bool = False, + ) -> EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput]: + """Evaluates a candidate prompt on a batch of tau-bench tasks. + + This method is called by GEPA during the optimization loop. It takes a + candidate prompt, runs it against the specified tasks from tau-bench, and + returns the results. - def make_reflective_dataset( - self, - candidate: dict[str, str], - eval_batch: EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput], - components_to_update: list[str], - ) -> dict[str, list[dict[str, Any]]]: - """Creates a dataset for reflection based on evaluation results. - - This method transforms the trajectories and scores from an evaluation run - into a structured format that a reflection model can use to generate - suggestions for improving the prompt. - - Args: - candidate: The candidate that was evaluated. - eval_batch: The results of the evaluation. - components_to_update: A list of component names that the reflection should - focus on improving. - - Returns: - A dictionary where keys are component names and values are lists of - data instances for reflection. - """ - system_instruction = candidate[self._system_instruction_name] - - env = get_env( - self._env_name, - user_strategy=self._user_strategy, - user_model=self._user_model, - user_provider=self._user_model_provider, - task_split="train", - ) + Args: + batch: A list of task instances to evaluate on. Each instance specifies + the environment and task ID. + candidate: A dictionary containing the components to be evaluated, + including the system instruction. + capture_traces: (Not used in this adapter) Whether to capture detailed + traces. - tool_definitions = json.dumps( - env.tools_info, - indent=2, - default=str, - ) - inputs = "\n\n".join( - [ - f"# System Instruction\n{system_instruction}", - f"# Tool Definitions\n{tool_definitions}", - ] - ) - ret_d: dict[str, list[dict[str, Any]]] = {} - for comp in components_to_update: - items: list[dict[str, Any]] = [] - trace_instances = list( - zip( - eval_batch.trajectories, - eval_batch.scores, - eval_batch.outputs, - strict=True, - ) - ) - for trace_instance in trace_instances: - traj, _, rollout = trace_instance - messages = traj["result_traj"] - # Remove instructions. - if len(messages) > 1: - messages = messages[1:] - d = { - "Inputs": inputs, - "Generated Outputs": json.dumps(messages, indent=2, default=str), - "Feedback": json.dumps( - rollout["task_info"]["feedback"], indent=2, default=str - ), - } - items.append(d) - if items: - ret_d[comp] = items - assert ret_d, ( - "empty reflective dataset for components " - f"{[comp for comp in components_to_update]}" - ) - return ret_d + Returns: + An EvaluationBatch object containing scores, outputs, and trajectories for + each task in the batch. + """ + del capture_traces # Not used. + env = batch[0]['env'] + task_ids = [inst['task_id'] for inst in batch] + tau_bench_run_config = RunConfig( + env=env, + model=self._agent_model, + model_provider=self._agent_model_provider, + user_model=self._user_model, + user_model_provider=self._user_model_provider, + agent_strategy=self._agent_strategy, + user_strategy=self._user_strategy, + max_concurrency=self._max_concurrency, + task_ids=task_ids, + log_dir=self._log_dir, + task_split=batch[0]['task_split'], + ) + tau_bench_results = run_tau_bench_rollouts( + tau_bench_run_config, + system_instruction=candidate.get(self._system_instruction_name), + rater=self._rater, + ) + + outputs = [] + trajectories = [] + scores = [] + for res in tau_bench_results: + outputs.append( + TauBenchRolloutOutput( + env=env, + task_id=res.task_id, + reward=res.reward, + task_info=res.info, + ) + ) + result_traj = res.traj + trajectories.append(TauBenchTrajectory(result_traj=result_traj)) + scores.append(res.reward) + + return EvaluationBatch( + scores=scores, outputs=outputs, trajectories=trajectories + ) + + def make_reflective_dataset( + self, + candidate: dict[str, str], + eval_batch: EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput], + components_to_update: list[str], + ) -> dict[str, list[dict[str, Any]]]: + """Creates a dataset for reflection based on evaluation results. + + This method transforms the trajectories and scores from an evaluation run + into a structured format that a reflection model can use to generate + suggestions for improving the prompt. + + Args: + candidate: The candidate that was evaluated. + eval_batch: The results of the evaluation. + components_to_update: A list of component names that the reflection should + focus on improving. + + Returns: + A dictionary where keys are component names and values are lists of + data instances for reflection. + """ + system_instruction = candidate[self._system_instruction_name] + + env = get_env( + self._env_name, + user_strategy=self._user_strategy, + user_model=self._user_model, + user_provider=self._user_model_provider, + task_split='train', + ) + + tool_definitions = json.dumps( + env.tools_info, + indent=2, + default=str, + ) + inputs = '\n\n'.join([ + f'# System Instruction\n{system_instruction}', + f'# Tool Definitions\n{tool_definitions}', + ]) + ret_d: dict[str, list[dict[str, Any]]] = {} + for comp in components_to_update: + items: list[dict[str, Any]] = [] + trace_instances = list( + zip( + eval_batch.trajectories, + eval_batch.scores, + eval_batch.outputs, + strict=True, + ) + ) + for trace_instance in trace_instances: + traj, _, rollout = trace_instance + messages = traj['result_traj'] + # Remove instructions. + if len(messages) > 1: + messages = messages[1:] + d = { + 'Inputs': inputs, + 'Generated Outputs': json.dumps(messages, indent=2, default=str), + 'Feedback': json.dumps( + rollout['task_info']['feedback'], indent=2, default=str + ), + } + items.append(d) + if items: + ret_d[comp] = items + assert ret_d, ( + 'empty reflective dataset for components ' + f'{[comp for comp in components_to_update]}' + ) + return ret_d _DATASET_SPLITS = { - "train": tasks_train.TASKS_TRAIN, - "dev": tasks_dev.TASKS_DEV, - "test": tasks_test.TASKS_TEST, + 'train': tasks_train.TASKS_TRAIN, + 'dev': tasks_dev.TASKS_DEV, + 'test': tasks_test.TASKS_TEST, } def _get_dataset(ds: Dataset) -> list[TauBenchDataInst]: - task_ids = ds.indexes or list(range(len(_DATASET_SPLITS[ds.split]))) - if ds.max_size is not None: - task_ids = task_ids[: ds.max_size] - random.shuffle(task_ids) - return task_ids + task_ids = ds.indexes or list(range(len(_DATASET_SPLITS[ds.split]))) + if ds.max_size is not None: + task_ids = task_ids[: ds.max_size] + random.shuffle(task_ids) + return task_ids def _get_datasets( config: ExperimentConfig, ) -> dict[str, list[int]]: - """Returns Tau-bench dataset splits.""" - random.seed(config.rnd_seed) - train_task_ids = _get_dataset(config.feedback_dataset) - eval_task_ids = _get_dataset(config.pareto_dataset) - test_task_ids = _get_dataset(config.eval_dataset) - logging.info( - "Using datasets of size: train=%d, eval=%d, test=%d", - len(train_task_ids), - len(eval_task_ids), - len(test_task_ids), - ) - return dict( - train=train_task_ids, - dev=eval_task_ids, - test=test_task_ids, - ) + """Returns Tau-bench dataset splits.""" + random.seed(config.rnd_seed) + train_task_ids = _get_dataset(config.feedback_dataset) + eval_task_ids = _get_dataset(config.pareto_dataset) + test_task_ids = _get_dataset(config.eval_dataset) + logging.info( + 'Using datasets of size: train=%d, eval=%d, test=%d', + len(train_task_ids), + len(eval_task_ids), + len(test_task_ids), + ) + return dict( + train=train_task_ids, + dev=eval_task_ids, + test=test_task_ids, + ) SEED_SYSTEM_INSTRUCTION = ( - "you are a customer support agent helping customers resolve their " - "issues by using the right tools" + 'you are a customer support agent helping customers resolve their ' + 'issues by using the right tools' ) @dataclasses.dataclass(frozen=True) class Dataset: - split: str - indexes: list[int] | None = None - max_size: int = None + split: str + indexes: list[int] | None = None + max_size: int = None @dataclasses.dataclass class ExperimentConfig: - """Configures a GEPA experiment on Tau-bench.""" - - tau_bench_env: str - agent_model: str - agent_model_provider: str - user_model: str - user_model_provider: str - max_concurrency: int - num_eval_trials: int - rnd_seed: int - max_metric_calls: int - reflection_model: str - reflection_minibatch_size: int - use_rater: bool - feedback_dataset: Dataset - pareto_dataset: Dataset - eval_dataset: Dataset + """Configures a GEPA experiment on Tau-bench.""" + + tau_bench_env: str + agent_model: str + agent_model_provider: str + user_model: str + user_model_provider: str + max_concurrency: int + num_eval_trials: int + rnd_seed: int + max_metric_calls: int + reflection_model: str + reflection_minibatch_size: int + use_rater: bool + feedback_dataset: Dataset + pareto_dataset: Dataset + eval_dataset: Dataset def _rater(config: ExperimentConfig) -> rater_lib.Rater: - env = get_env( - config.tau_bench_env, - user_strategy="llm", - user_model=config.user_model, - user_provider=config.user_model_provider, - task_split="train", - ) - return rater_lib.Rater(json.dumps(env.tools_info, indent=2)) - - -def run_gepa(output_dir: str, seed_instructions: str, config: ExperimentConfig) -> Any: - """Runs the GEPA optimization loop to train a new system instruction. - - Args: - output_dir: The directory to save experiment results and artifacts. - seed_instructions: Agent instructions to initialize the agent with. - config: The experiment configuration. - - Returns: - The results of the GEPA optimization. - """ - # This section sets up and runs the GEPA optimization experiment. - # Here we define all the parameters for the tau-bench environment, the GEPA - # optimization loop, and the models to be used. - datasets = _get_datasets(config) - training_set = [ - TauBenchDataInst( - env=config.tau_bench_env, - task_id=task_id, - task_split=config.feedback_dataset.split, - ) - for task_id in datasets["train"] - ] - eval_set = [ - TauBenchDataInst( - env=config.tau_bench_env, - task_id=task_id, - task_split=config.pareto_dataset.split, - ) - for task_id in datasets["dev"] - ] - system_instruction_name = "system_instruction" - - tau_bench_adapter = TauBenchAdapter( - env_name=config.tau_bench_env, - agent_model=config.agent_model, - agent_model_provider=config.agent_model_provider, - user_model=config.user_model, - user_model_provider=config.user_model_provider, - agent_strategy="tool-calling", - user_strategy="llm", - system_instruction_name=system_instruction_name, - max_concurrency=config.max_concurrency, - rater=_rater(config) if config.use_rater else None, - log_dir=os.path.join(output_dir, "traces"), - ) - - gepa_results = gepa.optimize( - seed_candidate={ - system_instruction_name: seed_instructions, - }, - trainset=training_set, - valset=eval_set, - task_lm=None, # this must be None when a custom adapter is used - adapter=tau_bench_adapter, - max_metric_calls=config.max_metric_calls, - reflection_lm=utils.reflection_inference_fn(config.reflection_model), - reflection_minibatch_size=config.reflection_minibatch_size, - run_dir=output_dir, - ) - json.dump( - gepa_results.to_dict(), - open(os.path.join(output_dir, "results.json"), "w"), - ) - return gepa_results + env = get_env( + config.tau_bench_env, + user_strategy='llm', + user_model=config.user_model, + user_provider=config.user_model_provider, + task_split='train', + ) + return rater_lib.Rater(json.dumps(env.tools_info, indent=2)) + + +def run_gepa( + output_dir: str, seed_instructions: str, config: ExperimentConfig +) -> Any: + """Runs the GEPA optimization loop to train a new system instruction. + + Args: + output_dir: The directory to save experiment results and artifacts. + seed_instructions: Agent instructions to initialize the agent with. + config: The experiment configuration. + + Returns: + The results of the GEPA optimization. + """ + # This section sets up and runs the GEPA optimization experiment. + # Here we define all the parameters for the tau-bench environment, the GEPA + # optimization loop, and the models to be used. + datasets = _get_datasets(config) + training_set = [ + TauBenchDataInst( + env=config.tau_bench_env, + task_id=task_id, + task_split=config.feedback_dataset.split, + ) + for task_id in datasets['train'] + ] + eval_set = [ + TauBenchDataInst( + env=config.tau_bench_env, + task_id=task_id, + task_split=config.pareto_dataset.split, + ) + for task_id in datasets['dev'] + ] + system_instruction_name = 'system_instruction' + + tau_bench_adapter = TauBenchAdapter( + env_name=config.tau_bench_env, + agent_model=config.agent_model, + agent_model_provider=config.agent_model_provider, + user_model=config.user_model, + user_model_provider=config.user_model_provider, + agent_strategy='tool-calling', + user_strategy='llm', + system_instruction_name=system_instruction_name, + max_concurrency=config.max_concurrency, + rater=_rater(config) if config.use_rater else None, + log_dir=os.path.join(output_dir, 'traces'), + ) + + gepa_results = gepa.optimize( + seed_candidate={ + system_instruction_name: seed_instructions, + }, + trainset=training_set, + valset=eval_set, + task_lm=None, # this must be None when a custom adapter is used + adapter=tau_bench_adapter, + max_metric_calls=config.max_metric_calls, + reflection_lm=utils.reflection_inference_fn(config.reflection_model), + reflection_minibatch_size=config.reflection_minibatch_size, + run_dir=output_dir, + ) + json.dump( + gepa_results.to_dict(), + open(os.path.join(output_dir, 'results.json'), 'w'), + ) + return gepa_results def run_eval(output_dir: str, instructions: str, config: ExperimentConfig): - """Runs evaluation on the test set using the given instructions. - - Args: - output_dir: The directory to save evaluation results. - instructions: The system instructions to evaluate. - config: The experiment configuration. - """ - eval_dataset = _get_dataset(config.eval_dataset) - tau_bench_run_config = RunConfig( - env=config.tau_bench_env, - model=config.agent_model, - model_provider=config.agent_model_provider, - user_model=config.user_model, - user_model_provider=config.user_model_provider, - agent_strategy="tool-calling", - user_strategy="llm", - max_concurrency=config.max_concurrency, - num_trials=config.num_eval_trials, - task_ids=eval_dataset, - log_dir=output_dir, - task_split=config.eval_dataset.split, - ) - with open(os.path.join(output_dir, "prompt.txt"), "w") as f: - f.write(instructions) - - json.dump( - tau_bench_run_config.model_dump(), - open(os.path.join(output_dir, "run_config.json"), "w"), - ) - tau_bench_results = run_tau_bench_rollouts( - tau_bench_run_config, - system_instruction=instructions, - rater=_rater(config) if config.use_rater else None, - ) - total = len(tau_bench_results) - numerator = sum(1 for res in tau_bench_results if res.reward == 1) - print(f"average reward (total={total}): {numerator/total if total > 0 else 0}") - json.dump( - dict(results=[r.model_dump() for r in tau_bench_results]), - open(os.path.join(output_dir, "results.json"), "w"), - ) + """Runs evaluation on the test set using the given instructions. + + Args: + output_dir: The directory to save evaluation results. + instructions: The system instructions to evaluate. + config: The experiment configuration. + """ + eval_dataset = _get_dataset(config.eval_dataset) + tau_bench_run_config = RunConfig( + env=config.tau_bench_env, + model=config.agent_model, + model_provider=config.agent_model_provider, + user_model=config.user_model, + user_model_provider=config.user_model_provider, + agent_strategy='tool-calling', + user_strategy='llm', + max_concurrency=config.max_concurrency, + num_trials=config.num_eval_trials, + task_ids=eval_dataset, + log_dir=output_dir, + task_split=config.eval_dataset.split, + ) + with open(os.path.join(output_dir, 'prompt.txt'), 'w') as f: + f.write(instructions) + + json.dump( + tau_bench_run_config.model_dump(), + open(os.path.join(output_dir, 'run_config.json'), 'w'), + ) + tau_bench_results = run_tau_bench_rollouts( + tau_bench_run_config, + system_instruction=instructions, + rater=_rater(config) if config.use_rater else None, + ) + total = len(tau_bench_results) + numerator = sum(1 for res in tau_bench_results if res.reward == 1) + print( + f'average reward (total={total}): {numerator/total if total > 0 else 0}' + ) + json.dump( + dict(results=[r.model_dump() for r in tau_bench_results]), + open(os.path.join(output_dir, 'results.json'), 'w'), + ) diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index 642a7e9bd3..cfd850b3a3 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -29,133 +29,140 @@ import utils _OUTPUT_DIR = flags.DEFINE_string( - "output_dir", + 'output_dir', None, - "Directory to save experiment results and artifacts.", + 'Directory to save experiment results and artifacts.', required=True, ) _EVAL_SET_SIZE = flags.DEFINE_integer( - "eval_set_size", + 'eval_set_size', None, - "Size of the dev set to use for Pareto frontier evaluation in GEPA. If" - " None, uses all available dev tasks. A few tens of examples might" - " suffice more simpler tasks and up to a few hundreds for " - " more complex and variable tasks. Increase the size to mitigate effect of" - " variability at greater cost.", + 'Size of the dev set to use for Pareto frontier evaluation in GEPA. If' + ' None, uses all available dev tasks. A few tens of examples might' + ' suffice more simpler tasks and up to a few hundreds for ' + ' more complex and variable tasks. Increase the size to mitigate effect of' + ' variability at greater cost.', ) _MAX_METRIC_CALLS = flags.DEFINE_integer( - "max_metric_calls", + 'max_metric_calls', 500, - "Total budget for GEPA prompt evaluations. This is the main control for" - " runtime/cost. One could start with 100 and increase to 500+ for further" - " optimization.", + 'Total budget for GEPA prompt evaluations. This is the main control for' + ' runtime/cost. One could start with 100 and increase to 500+ for further' + ' optimization.', ) _NUM_TEST_RECORDS = flags.DEFINE_integer( - "num_test_records", + 'num_test_records', None, - "Size of the test set for final evaluation of the optimized prompt. If" - " None, uses all available test tasks.", + 'Size of the test set for final evaluation of the optimized prompt. If' + ' None, uses all available test tasks.', ) _NUM_EVAL_TRIALS = flags.DEFINE_integer( - "num_eval_trials", + 'num_eval_trials', 4, - "Number of times each task is run during evaluation. Higher values give" - " more stable evaluation metrics but increase runtime. Recommended: 4-8.", + 'Number of times each task is run during evaluation. Higher values give' + ' more stable evaluation metrics but increase runtime. Recommended: 4-8.', ) _MAX_CONCURRENCY = flags.DEFINE_integer( - "max_concurrency", + 'max_concurrency', 8, - "Maximum number of parallel agent-environment interactions. Increase if" - " you have sufficient API quota.", + 'Maximum number of parallel agent-environment interactions. Increase if' + ' you have sufficient API quota.', ) _EVAL_MODE = flags.DEFINE_bool( - "eval_mode", + 'eval_mode', False, - "If set, run evaluation only using the seed prompt, skipping GEPA" " optimization.", + 'If set, run evaluation only using the seed prompt, skipping GEPA' + ' optimization.', ) _USE_RATER = flags.DEFINE_bool( - "use_rater", + 'use_rater', False, - "If set, use an LLM rater to score trajectories.", + 'If set, use an LLM rater to score trajectories.', ) _TRAIN_BATCH_SIZE = flags.DEFINE_integer( - "train_batch_size", + 'train_batch_size', 3, - "Number of trajectories sampled from rollouts to be used by the" - " reflection model in each GEPA step to generate prompt improvements." - " Increasing the batch size may help provide a more stable signal and" - " estimate of a prompt quality but entails higher cost. One can start with" - " a low value and increase the size if significant variations are" - " observed.", + 'Number of trajectories sampled from rollouts to be used by the' + ' reflection model in each GEPA step to generate prompt improvements.' + ' Increasing the batch size may help provide a more stable signal and' + ' estimate of a prompt quality but entails higher cost. One can start with' + ' a low value and increase the size if significant variations are' + ' observed.', ) def main(argv: Sequence[str]) -> None: - if len(argv) > 1: - raise app.UsageError("Too many command-line arguments.") + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') - # Get a list of all existing loggers - # logging.root.manager.loggerDict contains all named loggers - # logging.getLogger(name) retrieves the logger object - loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + # Get a list of all existing loggers + # logging.root.manager.loggerDict contains all named loggers + # logging.getLogger(name) retrieves the logger object + loggers = [ + logging.getLogger(name) for name in logging.root.manager.loggerDict + ] - # Iterate through the loggers and set their level to WARNING - for logger in loggers: - logger.setLevel(logging.WARNING) + # Iterate through the loggers and set their level to WARNING + for logger in loggers: + logger.setLevel(logging.WARNING) - types.logger.addFilter(utils.FilterInferenceWarnings()) - output_dir = os.path.join( - _OUTPUT_DIR.value, datetime.now().strftime("%Y%m%d%H%M%S%f") - ) - os.makedirs(output_dir) - logging.info("Writing to output_dir=%s", output_dir) - config = experiment.ExperimentConfig( - tau_bench_env="retail", - agent_model="gemini-2.5-flash", - agent_model_provider="vertex_ai", - user_model="gemini-2.5-flash", - user_model_provider="vertex_ai", - max_concurrency=_MAX_CONCURRENCY.value, - num_eval_trials=_NUM_EVAL_TRIALS.value, - rnd_seed=42, - max_metric_calls=_MAX_METRIC_CALLS.value, - reflection_model="gemini-2.5-pro", - reflection_minibatch_size=_TRAIN_BATCH_SIZE.value, - use_rater=_USE_RATER.value, - feedback_dataset=experiment.Dataset(split="train"), - pareto_dataset=experiment.Dataset(split="dev", max_size=_EVAL_SET_SIZE.value), - eval_dataset=experiment.Dataset(split="test", max_size=_NUM_TEST_RECORDS.value), - ) - json.dump( - dataclasses.asdict(config), - open(os.path.join(output_dir, "config.json"), "w"), - ) - logging.info("Using config=%s", config) + types.logger.addFilter(utils.FilterInferenceWarnings()) + output_dir = os.path.join( + _OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f') + ) + os.makedirs(output_dir) + logging.info('Writing to output_dir=%s', output_dir) + config = experiment.ExperimentConfig( + tau_bench_env='retail', + agent_model='gemini-2.5-flash', + agent_model_provider='vertex_ai', + user_model='gemini-2.5-flash', + user_model_provider='vertex_ai', + max_concurrency=_MAX_CONCURRENCY.value, + num_eval_trials=_NUM_EVAL_TRIALS.value, + rnd_seed=42, + max_metric_calls=_MAX_METRIC_CALLS.value, + reflection_model='gemini-2.5-pro', + reflection_minibatch_size=_TRAIN_BATCH_SIZE.value, + use_rater=_USE_RATER.value, + feedback_dataset=experiment.Dataset(split='train'), + pareto_dataset=experiment.Dataset( + split='dev', max_size=_EVAL_SET_SIZE.value + ), + eval_dataset=experiment.Dataset( + split='test', max_size=_NUM_TEST_RECORDS.value + ), + ) + json.dump( + dataclasses.asdict(config), + open(os.path.join(output_dir, 'config.json'), 'w'), + ) + logging.info('Using config=%s', config) - if _EVAL_MODE.value: - return experiment.run_eval( - output_dir=output_dir, - instructions=experiment.SEED_SYSTEM_INSTRUCTION, - config=config, - ) - - results = experiment.run_gepa( - config=config, - seed_instructions=experiment.SEED_SYSTEM_INSTRUCTION, + if _EVAL_MODE.value: + return experiment.run_eval( output_dir=output_dir, - ) - print(list(enumerate(results.val_aggregate_scores))) - - eval_dir = os.path.join( - output_dir, "evals", datetime.now().strftime("%Y%m%d%H%M%S%f") - ) - os.makedirs(eval_dir) - experiment.run_eval( - output_dir=eval_dir, - instructions=results.best_candidate["system_instruction"], + instructions=experiment.SEED_SYSTEM_INSTRUCTION, config=config, ) + results = experiment.run_gepa( + config=config, + seed_instructions=experiment.SEED_SYSTEM_INSTRUCTION, + output_dir=output_dir, + ) + print(list(enumerate(results.val_aggregate_scores))) + + eval_dir = os.path.join( + output_dir, 'evals', datetime.now().strftime('%Y%m%d%H%M%S%f') + ) + os.makedirs(eval_dir) + experiment.run_eval( + output_dir=eval_dir, + instructions=results.best_candidate['system_instruction'], + config=config, + ) + -if __name__ == "__main__": - app.run(main) +if __name__ == '__main__': + app.run(main) From 276c5ba46afd965f0b6cd1dd8095d47110dfa82c Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 14:52:41 +0000 Subject: [PATCH 14/24] refactor(agents): reduce telemetry boilerplate Reuse callback runner with a nullcontext span and avoid quote-only churn. --- src/google/adk/agents/base_agent.py | 87 +++++++++++++++-------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index cd591e4a1a..23240bf6f7 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -52,7 +52,7 @@ if TYPE_CHECKING: from .invocation_context import InvocationContext -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) _SingleAgentCallback: TypeAlias = Callable[ [CallbackContext], @@ -69,7 +69,7 @@ list[_SingleAgentCallback], ] -SelfAgent = TypeVar("SelfAgent", bound="BaseAgent") +SelfAgent = TypeVar('SelfAgent', bound='BaseAgent') @experimental @@ -77,11 +77,11 @@ class BaseAgentState(BaseModel): """Base class for all agent states.""" model_config = ConfigDict( - extra="forbid", + extra='forbid', ) -AgentState = TypeVar("AgentState", bound=BaseAgentState) +AgentState = TypeVar('AgentState', bound=BaseAgentState) class BaseAgent(BaseModel): @@ -89,7 +89,7 @@ class BaseAgent(BaseModel): model_config = ConfigDict( arbitrary_types_allowed=True, - extra="forbid", + extra='forbid', ) """The pydantic model config.""" @@ -116,7 +116,7 @@ class MyAgent(BaseAgent): Agent name cannot be "user", since it's reserved for end-user's input. """ - description: str = "" + description: str = '' """Description about the agent's capability. The model uses this to determine whether to delegate control to the agent. @@ -136,6 +136,7 @@ class MyAgent(BaseAgent): """The sub-agents of this agent.""" disable_telemetry: bool = False """Whether to disable telemetry for this agent.""" + before_agent_callback: Optional[BeforeAgentCallback] = None """Callback or list of callbacks to be invoked before the agent run. @@ -223,10 +224,10 @@ def clone( A new agent instance with identical configuration as the original agent except for the fields specified in the update. """ - if update is not None and "parent_agent" in update: + if update is not None and 'parent_agent' in update: raise ValueError( - "Cannot update `parent_agent` field in clone. Parent agent is set" - " only when the parent agent is instantiated with the sub-agents." + 'Cannot update `parent_agent` field in clone. Parent agent is set' + ' only when the parent agent is instantiated with the sub-agents.' ) # Only allow updating fields that are defined in the agent class. @@ -235,8 +236,8 @@ def clone( invalid_fields = set(update) - allowed_fields if invalid_fields: raise ValueError( - f"Cannot update nonexistent fields in {self.__class__.__name__}:" - f" {invalid_fields}" + f'Cannot update nonexistent fields in {self.__class__.__name__}:' + f' {invalid_fields}' ) cloned_agent = self.model_copy(update=update) @@ -245,7 +246,7 @@ def clone( # shallow copy it for the cloned agent to avoid sharing the same list object # with the original agent. for field_name in cloned_agent.__class__.model_fields: - if field_name == "sub_agents": + if field_name == 'sub_agents': continue if update is not None and field_name in update: continue @@ -253,7 +254,7 @@ def clone( if isinstance(field, list): setattr(cloned_agent, field_name, field.copy()) - if update is None or "sub_agents" not in update: + if update is None or 'sub_agents' not in update: # If `sub_agents` is not provided in the update, need to recursively clone # the sub-agents to avoid sharing the sub-agents with the original agent. cloned_agent.sub_agents = [] @@ -288,13 +289,13 @@ async def run_async( ctx = self._create_invocation_context(parent_context) span_context = contextlib.nullcontext() if is_telemetry_enabled(self): - span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") + span_context = tracer.start_as_current_span(f'invoke_agent {self.name}') with span_context as span: if span: tracing.trace_agent_invocation(span, self, ctx) async with Aclosing( - self._run_callbacks_and_impl(ctx, mode="async") + self._run_callbacks_and_impl(ctx, mode='async') ) as agen: async for event in agen: yield event @@ -317,12 +318,12 @@ async def run_live( ctx = self._create_invocation_context(parent_context) span_context = contextlib.nullcontext() if is_telemetry_enabled(self): - span_context = tracer.start_as_current_span(f"invoke_agent {self.name}") + span_context = tracer.start_as_current_span(f'invoke_agent {self.name}') with span_context as span: if span: tracing.trace_agent_invocation(span, self, ctx) - async for event in self._run_callbacks_and_impl(ctx, mode="live"): + async for event in self._run_callbacks_and_impl(ctx, mode='live'): yield event async def _run_async_impl( @@ -337,7 +338,7 @@ async def _run_async_impl( Event: the events generated by the agent. """ raise NotImplementedError( - f"_run_async_impl for {type(self)} is not implemented." + f'_run_async_impl for {type(self)} is not implemented.' ) yield # AsyncGenerator requires having at least one yield statement @@ -353,12 +354,12 @@ async def _run_live_impl( Event: the events generated by the agent. """ raise NotImplementedError( - f"_run_live_impl for {type(self)} is not implemented." + f'_run_live_impl for {type(self)} is not implemented.' ) yield # AsyncGenerator requires having at least one yield statement async def _run_callbacks_and_impl( - self, ctx: InvocationContext, mode: str = "async" + self, ctx: InvocationContext, mode: str = 'async' ) -> AsyncGenerator[Event, None]: """Runs the before and after agent callbacks around the core agent logic. Args: @@ -371,11 +372,11 @@ async def _run_callbacks_and_impl( yield event if ctx.end_invocation: return - if mode.lower() == "async": + if mode.lower() == 'async': async with Aclosing(self._run_async_impl(ctx)) as agen: async for event in agen: yield event - elif mode.lower() == "live": + elif mode.lower() == 'live': async with Aclosing(self._run_live_impl(ctx)) as agen: async for event in agen: yield event @@ -429,7 +430,7 @@ def _create_invocation_context( self, parent_context: InvocationContext ) -> InvocationContext: """Creates a new invocation context for this agent.""" - invocation_context = parent_context.model_copy(update={"agent": self}) + invocation_context = parent_context.model_copy(update={'agent': self}) return invocation_context @property @@ -577,24 +578,24 @@ async def _handle_after_agent_callback( def model_post_init(self, __context: Any) -> None: self.__set_parent_agent_for_sub_agents() - @field_validator("name", mode="after") + @field_validator('name', mode='after') @classmethod def validate_name(cls, value: str): if not value.isidentifier(): raise ValueError( - f"Found invalid agent name: `{value}`." - " Agent name must be a valid identifier. It should start with a" - " letter (a-z, A-Z) or an underscore (_), and can only contain" - " letters, digits (0-9), and underscores." + f'Found invalid agent name: `{value}`.' + ' Agent name must be a valid identifier. It should start with a' + ' letter (a-z, A-Z) or an underscore (_), and can only contain' + ' letters, digits (0-9), and underscores.' ) - if value == "user": + if value == 'user': raise ValueError( "Agent name cannot be `user`. `user` is reserved for end-user's" - " input." + ' input.' ) return value - @field_validator("sub_agents", mode="after") + @field_validator('sub_agents', mode='after') @classmethod def validate_sub_agents_unique_names( cls, value: list[BaseAgent] @@ -622,12 +623,12 @@ def validate_sub_agents_unique_names( seen_names.add(name) if duplicates: - duplicate_names_str = ", ".join( - f"`{name}`" for name in sorted(duplicates) + duplicate_names_str = ', '.join( + f'`{name}`' for name in sorted(duplicates) ) logger.warning( - "Found duplicate sub-agent names: %s. " - "All sub-agents must have unique names.", + 'Found duplicate sub-agent names: %s. ' + 'All sub-agents must have unique names.', duplicate_names_str, ) @@ -637,9 +638,9 @@ def __set_parent_agent_for_sub_agents(self) -> BaseAgent: for sub_agent in self.sub_agents: if sub_agent.parent_agent is not None: raise ValueError( - f"Agent `{sub_agent.name}` already has a parent agent, current" - f" parent: `{sub_agent.parent_agent.name}`, trying to add:" - f" `{self.name}`" + f'Agent `{sub_agent.name}` already has a parent agent, current' + f' parent: `{sub_agent.parent_agent.name}`, trying to add:' + f' `{self.name}`' ) sub_agent.parent_agent = self return self @@ -704,22 +705,22 @@ def __create_kwargs( from .config_agent_utils import resolve_callbacks kwargs: Dict[str, Any] = { - "name": config.name, - "description": config.description, + 'name': config.name, + 'description': config.description, } if config.sub_agents: sub_agents = [] for sub_agent_config in config.sub_agents: sub_agent = resolve_agent_reference(sub_agent_config, config_abs_path) sub_agents.append(sub_agent) - kwargs["sub_agents"] = sub_agents + kwargs['sub_agents'] = sub_agents if config.before_agent_callbacks: - kwargs["before_agent_callback"] = resolve_callbacks( + kwargs['before_agent_callback'] = resolve_callbacks( config.before_agent_callbacks ) if config.after_agent_callbacks: - kwargs["after_agent_callback"] = resolve_callbacks( + kwargs['after_agent_callback'] = resolve_callbacks( config.after_agent_callbacks ) return kwargs From ea1d30cc6666dc78766c90b58ee77864b4d72d74 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 14:58:24 +0000 Subject: [PATCH 15/24] refactor(agents): keep telemetry flag in validator Reapply only the model telemetry toggle without quote-style churn. --- src/google/adk/agents/llm_agent.py | 114 ++++++++++++++--------------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5862397f58..ac3156fd19 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -64,7 +64,7 @@ from .llm_agent_config import LlmAgentConfig from .readonly_context import ReadonlyContext -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) _SingleBeforeModelCallback: TypeAlias = Callable[ [CallbackContext, LlmRequest], @@ -184,7 +184,7 @@ async def _convert_tool_union_to_tools( class LlmAgent(BaseAgent): """LLM-based Agent.""" - model: Union[str, BaseLlm] = "" + model: Union[str, BaseLlm] = '' """The model to use for the agent. When not set, the agent will inherit the model from its ancestor. @@ -193,7 +193,7 @@ class LlmAgent(BaseAgent): config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig """The config type for this agent.""" - instruction: Union[str, InstructionProvider] = "" + instruction: Union[str, InstructionProvider] = '' """Dynamic instructions for the LLM model, guiding the agent's behavior. These instructions can contain placeholders like {variable_name} that will be @@ -207,7 +207,7 @@ class LlmAgent(BaseAgent): comes first in the prompt, followed by dynamic content (instruction). """ - global_instruction: Union[str, InstructionProvider] = "" + global_instruction: Union[str, InstructionProvider] = '' """Instructions for all the agents in the entire agent tree. DEPRECATED: This field is deprecated and will be removed in a future version. @@ -297,7 +297,7 @@ class LlmAgent(BaseAgent): """Disallows LLM-controlled transferring to the peer agents.""" # LLM-based agent transfer configs - End - include_contents: Literal["default", "none"] = "default" + include_contents: Literal['default', 'none'] = 'default' """Controls content inclusion in model requests. Options: @@ -504,7 +504,7 @@ def canonical_model(self) -> BaseLlm: if isinstance(ancestor_agent, LlmAgent): return ancestor_agent.canonical_model ancestor_agent = ancestor_agent.parent_agent - raise ValueError(f"No model found for {self.name}.") + raise ValueError(f'No model found for {self.name}.') async def canonical_instruction( self, ctx: ReadonlyContext @@ -549,9 +549,9 @@ async def canonical_global_instruction( # Issue deprecation warning if global_instruction is being used if self.global_instruction: warnings.warn( - "global_instruction field is deprecated and will be removed in a" - " future version. Use GlobalInstructionPlugin instead for the same" - " functionality at the App level. See migration guide for details.", + 'global_instruction field is deprecated and will be removed in a' + ' future version. Use GlobalInstructionPlugin instead for the same' + ' functionality at the App level. See migration guide for details.', DeprecationWarning, stacklevel=2, ) @@ -702,12 +702,12 @@ def _get_subagent_to_resume( return self.__get_transfer_to_agent_or_none(last_event, self.name) # Last event is from user or another agent. - if last_event.author == "user": + if last_event.author == 'user': function_call_event = ctx._find_matching_function_call(last_event) if not function_call_event: raise ValueError( - "No agent to transfer to for resuming agent from function response" - f" {self.name}" + 'No agent to transfer to for resuming agent from function response' + f' {self.name}' ) if function_call_event.author == self.name: # User is responding to a tool call from the current agent. @@ -730,14 +730,14 @@ def __get_agent_to_run(self, agent_name: str) -> BaseAgent: error_msg = ( f"Agent '{agent_name}' not found.\n" f"Available agents: {', '.join(available)}\n\n" - "Possible causes:\n" - " 1. Agent not registered before being referenced\n" - " 2. Agent name mismatch (typo or case sensitivity)\n" - " 3. Timing issue (agent referenced before creation)\n\n" - "Suggested fixes:\n" - " - Verify agent is registered with root agent\n" - " - Check agent name spelling and case\n" - " - Ensure agents are created before being referenced" + 'Possible causes:\n' + ' 1. Agent not registered before being referenced\n' + ' 2. Agent name mismatch (typo or case sensitivity)\n' + ' 3. Timing issue (agent referenced before creation)\n\n' + 'Suggested fixes:\n' + ' - Verify agent is registered with root agent\n' + ' - Check agent name spelling and case\n' + ' - Ensure agents are created before being referenced' ) raise ValueError(error_msg) return agent_to_run @@ -756,7 +756,7 @@ def _get_available_agent_names(self) -> list[str]: def collect_agents(agent): agents.append(agent.name) - if hasattr(agent, "sub_agents") and agent.sub_agents: + if hasattr(agent, 'sub_agents') and agent.sub_agents: for sub_agent in agent.sub_agents: collect_agents(sub_agent) @@ -772,7 +772,7 @@ def __get_transfer_to_agent_or_none( return None for function_response in function_responses: if ( - function_response.name == "transfer_to_agent" + function_response.name == 'transfer_to_agent' and event.author == from_agent and event.actions.transfer_to_agent != from_agent ): @@ -785,7 +785,7 @@ def __maybe_save_output_to_state(self, event: Event): # transferred to another agent) if event.author != self.name: logger.debug( - "Skipping output save for agent %s: event authored by %s", + 'Skipping output save for agent %s: event authored by %s', self.name, event.author, ) @@ -797,7 +797,7 @@ def __maybe_save_output_to_state(self, event: Event): and event.content.parts ): - result = "".join( + result = ''.join( part.text for part in event.content.parts if part.text and not part.thought @@ -813,15 +813,15 @@ def __maybe_save_output_to_state(self, event: Event): ) event.actions.state_delta[self.output_key] = result - @model_validator(mode="after") + @model_validator(mode='after') def __model_validator_after(self) -> LlmAgent: - root_agent = getattr(self, "root_agent", None) or self + root_agent = getattr(self, 'root_agent', None) or self disable_telemetry: bool = not is_telemetry_enabled(root_agent) - if hasattr(self.model, "disable_telemetry"): + if hasattr(self.model, 'disable_telemetry'): self.model.disable_telemetry = disable_telemetry return self - @field_validator("generate_content_config", mode="after") + @field_validator('generate_content_config', mode='after') @classmethod def validate_generate_content_config( cls, generate_content_config: Optional[types.GenerateContentConfig] @@ -829,16 +829,16 @@ def validate_generate_content_config( if not generate_content_config: return types.GenerateContentConfig() if generate_content_config.thinking_config: - raise ValueError("Thinking config should be set via LlmAgent.planner.") + raise ValueError('Thinking config should be set via LlmAgent.planner.') if generate_content_config.tools: - raise ValueError("All tools must be set via LlmAgent.tools.") + raise ValueError('All tools must be set via LlmAgent.tools.') if generate_content_config.system_instruction: raise ValueError( - "System instruction must be set via LlmAgent.instruction." + 'System instruction must be set via LlmAgent.instruction.' ) if generate_content_config.response_schema: raise ValueError( - "Response schema must be set via LlmAgent.output_schema." + 'Response schema must be set via LlmAgent.output_schema.' ) return generate_content_config @@ -859,26 +859,26 @@ def _resolve_tools( resolved_tools = [] for tool_config in tool_configs: - if "." not in tool_config.name: + if '.' not in tool_config.name: # ADK built-in tools - module = importlib.import_module("google.adk.tools") + module = importlib.import_module('google.adk.tools') obj = getattr(module, tool_config.name) else: # User-defined tools - module_path, obj_name = tool_config.name.rsplit(".", 1) + module_path, obj_name = tool_config.name.rsplit('.', 1) module = importlib.import_module(module_path) obj = getattr(module, obj_name) if isinstance(obj, BaseTool) or isinstance(obj, BaseToolset): logger.debug( - "Tool %s is an instance of BaseTool/BaseToolset.", tool_config.name + 'Tool %s is an instance of BaseTool/BaseToolset.', tool_config.name ) resolved_tools.append(obj) elif inspect.isclass(obj) and ( issubclass(obj, BaseTool) or issubclass(obj, BaseToolset) ): logger.debug( - "Tool %s is a sub-class of BaseTool/BaseToolset.", tool_config.name + 'Tool %s is a sub-class of BaseTool/BaseToolset.', tool_config.name ) resolved_tools.append( obj.from_config(tool_config.args, config_abs_path) @@ -886,17 +886,17 @@ def _resolve_tools( elif callable(obj): if tool_config.args: logger.debug( - "Tool %s is a user-defined tool-generating function.", + 'Tool %s is a user-defined tool-generating function.', tool_config.name, ) resolved_tools.append(obj(tool_config.args)) else: logger.debug( - "Tool %s is a user-defined function tool.", tool_config.name + 'Tool %s is a user-defined function tool.', tool_config.name ) resolved_tools.append(obj) else: - raise ValueError(f"Invalid tool YAML config: {tool_config}.") + raise ValueError(f'Invalid tool YAML config: {tool_config}.') return resolved_tools @@ -913,45 +913,45 @@ def _parse_config( from .config_agent_utils import resolve_code_reference if config.model_code: - kwargs["model"] = resolve_code_reference(config.model_code) + kwargs['model'] = resolve_code_reference(config.model_code) elif config.model: - kwargs["model"] = config.model + kwargs['model'] = config.model if config.instruction: - kwargs["instruction"] = config.instruction + kwargs['instruction'] = config.instruction if config.static_instruction: - kwargs["static_instruction"] = config.static_instruction + kwargs['static_instruction'] = config.static_instruction if config.disallow_transfer_to_parent: - kwargs["disallow_transfer_to_parent"] = config.disallow_transfer_to_parent + kwargs['disallow_transfer_to_parent'] = config.disallow_transfer_to_parent if config.disallow_transfer_to_peers: - kwargs["disallow_transfer_to_peers"] = config.disallow_transfer_to_peers - if config.include_contents != "default": - kwargs["include_contents"] = config.include_contents + kwargs['disallow_transfer_to_peers'] = config.disallow_transfer_to_peers + if config.include_contents != 'default': + kwargs['include_contents'] = config.include_contents if config.input_schema: - kwargs["input_schema"] = resolve_code_reference(config.input_schema) + kwargs['input_schema'] = resolve_code_reference(config.input_schema) if config.output_schema: - kwargs["output_schema"] = resolve_code_reference(config.output_schema) + kwargs['output_schema'] = resolve_code_reference(config.output_schema) if config.output_key: - kwargs["output_key"] = config.output_key + kwargs['output_key'] = config.output_key if config.tools: - kwargs["tools"] = cls._resolve_tools(config.tools, config_abs_path) + kwargs['tools'] = cls._resolve_tools(config.tools, config_abs_path) if config.before_model_callbacks: - kwargs["before_model_callback"] = resolve_callbacks( + kwargs['before_model_callback'] = resolve_callbacks( config.before_model_callbacks ) if config.after_model_callbacks: - kwargs["after_model_callback"] = resolve_callbacks( + kwargs['after_model_callback'] = resolve_callbacks( config.after_model_callbacks ) if config.before_tool_callbacks: - kwargs["before_tool_callback"] = resolve_callbacks( + kwargs['before_tool_callback'] = resolve_callbacks( config.before_tool_callbacks ) if config.after_tool_callbacks: - kwargs["after_tool_callback"] = resolve_callbacks( + kwargs['after_tool_callback'] = resolve_callbacks( config.after_tool_callbacks ) if config.generate_content_config: - kwargs["generate_content_config"] = config.generate_content_config + kwargs['generate_content_config'] = config.generate_content_config return kwargs From cfca1f3c451672735023f13ca000b9bbf70dcc7c Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 15:05:46 +0000 Subject: [PATCH 16/24] refactor(flows): gate tracing on telemetry Use nullcontext for live send_data/call_llm spans and skip tool tracing when telemetry is disabled. Add a shared context_manager_with_span fixture for telemetry tests. --- .../adk/flows/llm_flows/base_llm_flow.py | 199 ++++++++---------- src/google/adk/flows/llm_flows/functions.py | 85 ++++---- tests/unittests/conftest.py | 20 +- 3 files changed, 144 insertions(+), 160 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 962a718a83..65bed9d55a 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -60,9 +60,9 @@ from ._base_llm_processor import BaseLlmRequestProcessor from ._base_llm_processor import BaseLlmResponseProcessor -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) -_ADK_AGENT_NAME_LABEL_KEY = "adk_agent_name" +_ADK_AGENT_NAME_LABEL_KEY = 'adk_agent_name' # Timing configuration DEFAULT_REQUEST_QUEUE_TIMEOUT = 0.25 @@ -105,7 +105,7 @@ async def run_live( llm = self.__get_llm(invocation_context) logger.debug( - "Establishing live connection for agent: %s with llm request: %s", + 'Establishing live connection for agent: %s with llm request: %s', invocation_context.agent.name, llm_request, ) @@ -115,7 +115,7 @@ async def run_live( try: # On subsequent attempts, use the saved token to reconnect if invocation_context.live_session_resumption_handle: - logger.info("Attempting to reconnect (Attempt %s)...", attempt) + logger.info('Attempting to reconnect (Attempt %s)...', attempt) attempt += 1 if not llm_request.live_connect_config: llm_request.live_connect_config = types.LiveConnectConfig() @@ -125,16 +125,16 @@ async def run_live( llm_request.live_connect_config.session_resumption.transparent = True logger.info( - "Establishing live connection for agent: %s", + 'Establishing live connection for agent: %s', invocation_context.agent.name, ) async with llm.connect(llm_request) as llm_connection: if llm_request.contents: # Sends the conversation history to the model. - logger.debug("Sending history to model: %s", llm_request.contents) + logger.debug('Sending history to model: %s', llm_request.contents) span_context = contextlib.nullcontext() if is_telemetry_enabled(invocation_context.agent): - span_context = tracer.start_as_current_span("send_data") + span_context = tracer.start_as_current_span('send_data') with span_context as span: await llm_connection.send_history(llm_request.contents) if span: @@ -159,13 +159,12 @@ async def run_live( # Empty event means the queue is closed. if not event: break - logger.debug("Receive new event: %s", event) + logger.debug('Receive new event: %s', event) yield event # send back the function response to models if event.get_function_responses(): logger.debug( - "Sending back last function response event: %s", - event, + 'Sending back last function response event: %s', event ) invocation_context.live_request_queue.send_content( event.content @@ -185,21 +184,18 @@ async def run_live( and event.content.parts and event.content.parts[0].function_response and event.content.parts[0].function_response.name - == "transfer_to_agent" + == 'transfer_to_agent' ): await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) # cancel the tasks that belongs to the closed connection. send_task.cancel() - logger.debug("Closing live connection") + logger.debug('Closing live connection') await llm_connection.close() - logger.debug("Live connection closed.") + logger.debug('Live connection closed.') # transfer to the sub agent. transfer_to_agent = event.actions.transfer_to_agent if transfer_to_agent: - logger.debug( - "Transferring to agent: %s", - transfer_to_agent, - ) + logger.debug('Transferring to agent: %s', transfer_to_agent) agent_to_run = self._get_agent_to_run( invocation_context, transfer_to_agent ) @@ -213,7 +209,7 @@ async def run_live( and event.content.parts and event.content.parts[0].function_response and event.content.parts[0].function_response.name - == "task_completed" + == 'task_completed' ): # this is used for sequential agent to signal the end of the agent. await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY) @@ -231,11 +227,11 @@ async def run_live( except (ConnectionClosed, ConnectionClosedOK) as e: # when the session timeout, it will just close and not throw exception. # so this is for bad cases - logger.error("Connection closed: %s.", e) + logger.error('Connection closed: %s.', e) raise except Exception as e: logger.error( - "An unexpected error occurred in live flow: %s", e, exc_info=True + 'An unexpected error occurred in live flow: %s', e, exc_info=True ) raise @@ -257,7 +253,7 @@ async def _send_to_model( ) # duplicate the live_request to all the active streams logger.debug( - "Sending live request %s to active streams: %s", + 'Sending live request %s to active streams: %s', live_request, invocation_context.active_streaming_tools, ) @@ -281,7 +277,7 @@ async def _send_to_model( elif live_request.blob: # Cache input audio chunks before flushing self.audio_cache_manager.cache_audio( - invocation_context, live_request.blob, cache_type="input" + invocation_context, live_request.blob, cache_type='input' ) await llm_connection.send_realtime(live_request.blob) @@ -310,9 +306,9 @@ def get_author_for_event(llm_response): if ( llm_response and llm_response.content - and llm_response.content.role == "user" + and llm_response.content.role == 'user' ): - return "user" + return 'user' else: return invocation_context.agent.name @@ -323,8 +319,8 @@ def get_author_for_event(llm_response): async for llm_response in agen: if llm_response.live_session_resumption_update: logger.info( - "Update session resumption handle:" - f" {llm_response.live_session_resumption_update}." + 'Update session resumption handle:' + f' {llm_response.live_session_resumption_update}.' ) invocation_context.live_session_resumption_handle = ( llm_response.live_session_resumption_update.new_handle @@ -352,7 +348,7 @@ def get_author_for_event(llm_response): and event.content.parts and event.content.parts[0].inline_data and event.content.parts[0].inline_data.mime_type.startswith( - "audio/" + 'audio/' ) ): audio_blob = types.Blob( @@ -360,9 +356,7 @@ def get_author_for_event(llm_response): mime_type=event.content.parts[0].inline_data.mime_type, ) self.audio_cache_manager.cache_audio( - invocation_context, - audio_blob, - cache_type="output", + invocation_context, audio_blob, cache_type='output' ) yield event @@ -383,7 +377,7 @@ async def run_async( yield event if not last_event or last_event.is_final_response() or last_event.partial: if last_event and last_event.partial: - logger.warning("The last event is partial, which is not expected.") + logger.warning('The last event is partial, which is not expected.') break async def _run_one_step_async( @@ -479,7 +473,7 @@ async def _preprocess_async( agent = invocation_context.agent if not isinstance(agent, LlmAgent): raise TypeError( - f"Expected agent to be an LlmAgent, but got {type(agent)}" + f'Expected agent to be an LlmAgent, but got {type(agent)}' ) # Runs processors. @@ -733,7 +727,7 @@ def _get_agent_to_run( root_agent = invocation_context.agent.root_agent agent_to_run = root_agent.find_agent(agent_name) if not agent_to_run: - raise ValueError(f"Agent {agent_name} not found in the agent tree.") + raise ValueError(f'Agent {agent_name} not found in the agent tree.') return agent_to_run async def _call_llm_async( @@ -762,78 +756,71 @@ async def _call_llm_async( # Calls the LLM. llm = self.__get_llm(invocation_context) - async def _call_llm_body() -> AsyncGenerator[LlmResponse, None]: - if invocation_context.run_config.support_cfc: - invocation_context.live_request_queue = LiveRequestQueue() - responses_generator = self.run_live(invocation_context) - async with Aclosing( - self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ) - ) as agen: - async for llm_response in agen: - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - # only yield partial response in SSE streaming mode - if ( - invocation_context.run_config.streaming_mode - == StreamingMode.SSE - or not llm_response.partial - ): - yield llm_response - if llm_response.turn_complete: - invocation_context.live_request_queue.close() - else: - # Check if we can make this llm call or not. If the current call - # pushes the counter beyond the max set value, then the execution is - # stopped right here, and exception is thrown. - invocation_context.increment_llm_call_count() - responses_generator = llm.generate_content_async( - llm_request, - stream=invocation_context.run_config.streaming_mode - == StreamingMode.SSE, - ) - async with Aclosing( - self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ) - ) as agen: - async for llm_response in agen: - trace_call_llm( - invocation_context, - model_response_event.id, - llm_request, - llm_response, - ) - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - - yield llm_response - - async def _call_llm_with_optional_tracing() -> ( - AsyncGenerator[LlmResponse, None] - ): + async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: span_context = contextlib.nullcontext() if is_telemetry_enabled(invocation_context.agent): - span_context = tracer.start_as_current_span("call_llm") + span_context = tracer.start_as_current_span('call_llm') with span_context: - async with Aclosing(_call_llm_body()) as agen: - async for r in agen: - yield r + if invocation_context.run_config.support_cfc: + invocation_context.live_request_queue = LiveRequestQueue() + responses_generator = self.run_live(invocation_context) + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, + ) + ) as agen: + async for llm_response in agen: + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response + # only yield partial response in SSE streaming mode + if ( + invocation_context.run_config.streaming_mode + == StreamingMode.SSE + or not llm_response.partial + ): + yield llm_response + if llm_response.turn_complete: + invocation_context.live_request_queue.close() + else: + # Check if we can make this llm call or not. If the current call + # pushes the counter beyond the max set value, then the execution is + # stopped right here, and exception is thrown. + invocation_context.increment_llm_call_count() + responses_generator = llm.generate_content_async( + llm_request, + stream=invocation_context.run_config.streaming_mode + == StreamingMode.SSE, + ) + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, + ) + ) as agen: + async for llm_response in agen: + trace_call_llm( + invocation_context, + model_response_event.id, + llm_request, + llm_response, + ) + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response + + yield llm_response - async with Aclosing(_call_llm_with_optional_tracing()) as agen: + async with Aclosing(_call_llm_with_tracing()) as agen: async for event in agen: yield event @@ -894,10 +881,10 @@ async def _maybe_add_grounding_metadata( tools = await agent.canonical_tools(readonly_context) invocation_context.canonical_tools_cache = tools - if not any(tool.name == "google_search_agent" for tool in tools): + if not any(tool.name == 'google_search_agent' for tool in tools): return response ground_metadata = invocation_context.session.state.get( - "temp:_adk_grounding_metadata", None + 'temp:_adk_grounding_metadata', None ) if not ground_metadata: return response @@ -974,7 +961,7 @@ async def _handle_control_event_flush( # Log cache statistics if enabled if DEFAULT_ENABLE_CACHE_STATISTICS: stats = self.audio_cache_manager.get_cache_stats(invocation_context) - logger.debug("Audio cache stats: %s", stats) + logger.debug('Audio cache stats: %s', stats) if llm_response.interrupted: # user interrupts so the model will stop. we can flush model audio here @@ -990,7 +977,7 @@ async def _handle_control_event_flush( flush_user_audio=True, flush_model_audio=True, ) - elif getattr(llm_response, "generation_complete", False): + elif getattr(llm_response, 'generation_complete', False): # model generation complete so we can flush model audio return await self.audio_cache_manager.flush_caches( invocation_context, @@ -1023,7 +1010,7 @@ async def _run_and_handle_error( agent = invocation_context.agent if not isinstance(agent, LlmAgent): raise TypeError( - f"Expected agent to be an LlmAgent, but got {type(agent)}" + f'Expected agent to be an LlmAgent, but got {type(agent)}' ) async def _run_on_model_error_callbacks( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 10edc3b410..597e5c2236 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -47,15 +47,15 @@ if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent -AF_FUNCTION_CALL_ID_PREFIX = "adk-" -REQUEST_EUC_FUNCTION_CALL_NAME = "adk_request_credential" -REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation" +AF_FUNCTION_CALL_ID_PREFIX = 'adk-' +REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential' +REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation' -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) def generate_client_function_call_id() -> str: - return f"{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}" + return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}' def populate_client_function_call_id(model_response_event: Event) -> None: @@ -164,10 +164,10 @@ def generate_request_confirmation_event( request_confirmation_function_call = types.FunctionCall( name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, args={ - "originalFunctionCall": original_function_call.model_dump( + 'originalFunctionCall': original_function_call.model_dump( exclude_none=True, by_alias=True ), - "toolConfirmation": tool_confirmation.model_dump( + 'toolConfirmation': tool_confirmation.model_dump( by_alias=True, exclude_none=True ), }, @@ -233,11 +233,9 @@ async def handle_function_call_list_async( function_call, tools_dict, agent, - ( - tool_confirmation_dict[function_call.id] - if tool_confirmation_dict - else None - ), + tool_confirmation_dict[function_call.id] + if tool_confirmation_dict + else None, ) ) for function_call in filtered_calls @@ -262,7 +260,7 @@ async def handle_function_call_list_async( # this is needed for debug traces of parallel calls # individual response with tool.name is traced in __build_response_event # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span("execute_tool (merged)"): + with tracer.start_as_current_span('execute_tool (merged)'): trace_merged_tool_calls( response_event_id=merged_event.id, function_response_event=merged_event, @@ -326,7 +324,7 @@ async def _run_on_tool_error_callbacks( try: tool = _get_tool(function_call, tools_dict) except ValueError as tool_error: - tool = BaseTool(name=function_call.name, description="Tool not found") + tool = BaseTool(name=function_call.name, description='Tool not found') error_response = await _run_on_tool_error_callbacks( tool=tool, tool_args=function_args, @@ -430,7 +428,7 @@ async def _run_with_trace(): if not is_telemetry_enabled(agent): return await _run_with_trace() - with tracer.start_as_current_span(f"execute_tool {tool.name}"): + with tracer.start_as_current_span(f'execute_tool {tool.name}'): try: function_response_event = await _run_with_trace() trace_tool_call( @@ -491,12 +489,11 @@ async def handle_function_calls_live( merged_event = merge_parallel_function_response_events( function_response_events ) - if len(function_response_events) > 1 and is_telemetry_enabled(agent): # this is needed for debug traces of parallel calls # individual response with tool.name is traced in __build_response_event # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span("execute_tool (merged)"): + with tracer.start_as_current_span('execute_tool (merged)'): trace_merged_tool_calls( response_event_id=merged_event.id, function_response_event=merged_event, @@ -583,7 +580,7 @@ async def _run_with_trace(): if not is_telemetry_enabled(agent): return await _run_with_trace() - with tracer.start_as_current_span(f"execute_tool {tool.name}"): + with tracer.start_as_current_span(f'execute_tool {tool.name}'): try: function_response_event = await _run_with_trace() trace_tool_call( @@ -610,10 +607,10 @@ async def _process_function_live_helper( function_response = None # Check if this is a stop_streaming function call if ( - function_call.name == "stop_streaming" - and "function_name" in function_args + function_call.name == 'stop_streaming' + and 'function_name' in function_args ): - function_name = function_args["function_name"] + function_name = function_args['function_name'] # Thread-safe access to active_streaming_tools async with streaming_lock: active_tasks = invocation_context.active_streaming_tools @@ -635,16 +632,16 @@ async def _process_function_live_helper( except (asyncio.CancelledError, asyncio.TimeoutError): # Log the specific condition if task.cancelled(): - logging.info("Task %s was cancelled successfully", function_name) + logging.info('Task %s was cancelled successfully', function_name) elif task.done(): - logging.info("Task %s completed during cancellation", function_name) + logging.info('Task %s completed during cancellation', function_name) else: logging.warning( - "Task %s might still be running after cancellation timeout", + 'Task %s might still be running after cancellation timeout', function_name, ) function_response = { - "status": f"The task is not cancelled yet for {function_name}." + 'status': f'The task is not cancelled yet for {function_name}.' } if not function_response: # Clean up the reference under lock @@ -656,13 +653,13 @@ async def _process_function_live_helper( invocation_context.active_streaming_tools[function_name].task = None function_response = { - "status": f"Successfully stopped streaming function {function_name}" + 'status': f'Successfully stopped streaming function {function_name}' } else: function_response = { - "status": f"No active streaming function named {function_name} found" + 'status': f'No active streaming function named {function_name} found' } - elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func): + elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func): # for streaming tool use case # we require the function to be an async generator function async def run_tool_and_update_queue(tool, function_args, tool_context): @@ -677,10 +674,10 @@ async def run_tool_and_update_queue(tool, function_args, tool_context): ) as agen: async for result in agen: updated_content = types.Content( - role="user", + role='user', parts=[ types.Part.from_text( - text=f"Function {tool.name} returned: {result}" + text=f'Function {tool.name} returned: {result}' ) ], ) @@ -707,9 +704,9 @@ async def run_tool_and_update_queue(tool, function_args, tool_context): # Immediately return a pending response. # This is required by current live model. function_response = { - "status": ( - "The function is running asynchronously and the results are" - " pending." + 'status': ( + 'The function is running asynchronously and the results are' + ' pending.' ) } else: @@ -728,11 +725,11 @@ def _get_tool( error_msg = ( f"Tool '{function_call.name}' not found.\nAvailable tools:" f" {', '.join(available)}\n\nPossible causes:\n 1. LLM hallucinated" - " the function name - review agent instruction clarity\n 2. Tool not" - " registered - verify agent.tools list\n 3. Name mismatch - check for" - " typos\n\nSuggested fixes:\n - Review agent instruction to ensure" - " tool usage is clear\n - Verify tool is included in agent.tools" - " list\n - Check for typos in function name" + ' the function name - review agent instruction clarity\n 2. Tool not' + ' registered - verify agent.tools list\n 3. Name mismatch - check for' + ' typos\n\nSuggested fixes:\n - Review agent instruction to ensure' + ' tool usage is clear\n - Verify tool is included in agent.tools' + ' list\n - Check for typos in function name' ) raise ValueError(error_msg) @@ -804,7 +801,7 @@ def __build_response_event( ) -> Event: # Specs requires the result to be a dict. if not isinstance(function_result, dict): - function_result = {"result": function_result} + function_result = {'result': function_result} part_function_response = types.Part.from_function_response( name=tool.name, response=function_result @@ -812,7 +809,7 @@ def __build_response_event( part_function_response.function_response.id = tool_context.function_call_id content = types.Content( - role="user", + role='user', parts=[part_function_response], ) @@ -838,10 +835,10 @@ def deep_merge_dicts(d1: dict, d2: dict) -> dict: def merge_parallel_function_response_events( - function_response_events: list["Event"], -) -> "Event": + function_response_events: list['Event'], +) -> 'Event': if not function_response_events: - raise ValueError("No function response events provided.") + raise ValueError('No function response events provided.') if len(function_response_events) == 1: return function_response_events[0] @@ -871,7 +868,7 @@ def merge_parallel_function_response_events( invocation_id=base_event.invocation_id, author=base_event.author, branch=base_event.branch, - content=types.Content(role="user", parts=merged_parts), + content=types.Content(role='user', parts=merged_parts), actions=merged_actions, # Optionally merge actions if required ) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index ab2ad70041..3d46b8a76c 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -22,19 +22,19 @@ from pytest import Metafunc _ENV_VARS = { - "GOOGLE_API_KEY": "fake_google_api_key", - "GOOGLE_CLOUD_PROJECT": "fake_google_cloud_project", - "GOOGLE_CLOUD_LOCATION": "fake_google_cloud_location", - "ADK_ALLOW_WIP_FEATURES": "true", + 'GOOGLE_API_KEY': 'fake_google_api_key', + 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project', + 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location', + 'ADK_ALLOW_WIP_FEATURES': 'true', } ENV_SETUPS = { - "GOOGLE_AI": { - "GOOGLE_GENAI_USE_VERTEXAI": "0", + 'GOOGLE_AI': { + 'GOOGLE_GENAI_USE_VERTEXAI': '0', **_ENV_VARS, }, - "VERTEX": { - "GOOGLE_GENAI_USE_VERTEXAI": "1", + 'VERTEX': { + 'GOOGLE_GENAI_USE_VERTEXAI': '1', **_ENV_VARS, }, } @@ -98,9 +98,9 @@ def pytest_generate_tests(metafunc: Metafunc): def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool: - if hasattr(metafunc.function, "pytestmark"): + if hasattr(metafunc.function, 'pytestmark'): for mark in metafunc.function.pytestmark: - if mark.name == "parametrize" and mark.args[0] == mark_name: + if mark.name == 'parametrize' and mark.args[0] == mark_name: return True return False From 6e8c2facf324552443123887502086caaffd10d1 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 15:19:15 +0000 Subject: [PATCH 17/24] refactor(core): gate telemetry spans Keep only telemetry logic in google_llm and runners without quote/indent churn. --- src/google/adk/models/google_llm.py | 137 +++++++++++------------ src/google/adk/runners.py | 164 ++++++++++++++-------------- 2 files changed, 151 insertions(+), 150 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index bbb3b80dc2..4b090dc98e 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -44,10 +44,10 @@ from .llm_request import LlmRequest -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) -_NEW_LINE = "\n" -_EXCLUDED_PART_FIELD = {"inline_data": {"data"}} +_NEW_LINE = '\n' +_EXCLUDED_PART_FIELD = {'inline_data': {'data'}} _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ @@ -76,7 +76,7 @@ def __str__(self): # stringified (for either publishing the exception on console or to logs) # we put in the required details for the developer. base_message = super().__str__() - return f"{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}" + return f'{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}' class Gemini(BaseLlm): @@ -88,48 +88,49 @@ class Gemini(BaseLlm): invocation. """ - model: str = "gemini-2.5-flash" + model: str = 'gemini-2.5-flash' speech_config: Optional[types.SpeechConfig] = None use_interactions_api: bool = False """Whether to use the interactions API for model invocation. - When enabled, uses the interactions API (client.aio.interactions.create()) - instead of the traditional generate_content API. The interactions API - provides stateful conversation capabilities, allowing you to chain - interactions using previous_interaction_id instead of sending full history. - The response format will be converted to match the existing LlmResponse - structure for compatibility. - - Sample: - ```python - agent = Agent( - model=Gemini(use_interactions_api=True) - ) - ``` - """ + When enabled, uses the interactions API (client.aio.interactions.create()) + instead of the traditional generate_content API. The interactions API + provides stateful conversation capabilities, allowing you to chain + interactions using previous_interaction_id instead of sending full history. + The response format will be converted to match the existing LlmResponse + structure for compatibility. + + Sample: + ```python + agent = Agent( + model=Gemini(use_interactions_api=True) + ) + ``` + """ retry_options: Optional[types.HttpRetryOptions] = None """Allow Gemini to retry failed responses. - Sample: - ```python - from google.genai import types + Sample: + ```python + from google.genai import types - # ... + # ... - agent = Agent( - model=Gemini( - retry_options=types.HttpRetryOptions(initial_delay=1, attempts=2), - ) + agent = Agent( + model=Gemini( + retry_options=types.HttpRetryOptions(initial_delay=1, attempts=2), ) - ``` - """ + ) + ``` + """ disable_telemetry: bool = False - """A bool to flag whether or not telemetry should be being disabled for Gemini LLM interactions. - """ + """A bool to flag whether or not telemetry should be being disabled for + Gemini LLM interactions. + """ @classmethod @override @@ -141,13 +142,13 @@ def supported_models(cls) -> list[str]: """ return [ - r"gemini-.*", + r'gemini-.*', # model optimizer pattern - r"model-optimizer-.*", + r'model-optimizer-.*', # fine-tuned vertex endpoint pattern - r"projects\/.+\/locations\/.+\/endpoints\/.+", + r'projects\/.+\/locations\/.+\/endpoints\/.+', # vertex gemini long name - r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+", + r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+', ] async def generate_content_async( @@ -174,7 +175,7 @@ async def generate_content_async( if not self.disable_telemetry: from ..telemetry.tracing import tracer - with tracer.start_as_current_span("handle_context_caching") as span: + with tracer.start_as_current_span('handle_context_caching') as span: cache_manager = GeminiContextCacheManager( self.api_client, disable_telemetry=self.disable_telemetry ) @@ -183,10 +184,10 @@ async def generate_content_async( ) if cache_metadata: if cache_metadata.cache_name: - span.set_attribute("cache_action", "active_cache") - span.set_attribute("cache_name", cache_metadata.cache_name) + span.set_attribute('cache_action', 'active_cache') + span.set_attribute('cache_name', cache_metadata.cache_name) else: - span.set_attribute("cache_action", "fingerprint_only") + span.set_attribute('cache_action', 'fingerprint_only') else: cache_manager = GeminiContextCacheManager( self.api_client, disable_telemetry=self.disable_telemetry @@ -194,7 +195,7 @@ async def generate_content_async( cache_metadata = await cache_manager.handle_context_caching(llm_request) logger.info( - "Sending out request, model: %s, backend: %s, stream: %s", + 'Sending out request, model: %s, backend: %s, stream: %s', llm_request.model, self._api_backend, stream, @@ -258,7 +259,7 @@ async def generate_content_async( contents=llm_request.contents, config=llm_request.config, ) - logger.info("Response received from the model.") + logger.info('Response received from the model.') logger.debug(_build_response_log(response)) llm_response = LlmResponse.create(response) @@ -332,10 +333,10 @@ def _api_backend(self) -> GoogleLLMVariant: def _tracking_headers(self) -> dict[str, str]: labels = get_client_labels() - header_value = " ".join(labels) + header_value = ' '.join(labels) tracking_headers = { - "x-goog-api-client": header_value, - "user-agent": header_value, + 'x-goog-api-client': header_value, + 'user-agent': header_value, } return tracking_headers @@ -343,10 +344,10 @@ def _tracking_headers(self) -> dict[str, str]: def _live_api_version(self) -> str: if self._api_backend == GoogleLLMVariant.VERTEX_AI: # use beta version for vertex api - return "v1beta1" + return 'v1beta1' else: # use v1alpha for using API KEY from Google AI Studio - return "v1alpha" + return 'v1alpha' @cached_property def _live_api_client(self) -> Client: @@ -388,7 +389,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: llm_request.live_connect_config.speech_config = self.speech_config llm_request.live_connect_config.system_instruction = types.Content( - role="system", + role='system', parts=[ types.Part.from_text(text=llm_request.config.system_instruction) ], @@ -398,22 +399,22 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: and llm_request.live_connect_config.session_resumption.transparent ): logger.debug( - "session resumption config: %s", + 'session resumption config: %s', llm_request.live_connect_config.session_resumption, ) logger.debug( - "self._api_backend: %s", + 'self._api_backend: %s', self._api_backend, ) if self._api_backend == GoogleLLMVariant.GEMINI_API: raise ValueError( - "Transparent session resumption is only supported for Vertex AI" - " backend. Please use Vertex AI backend." + 'Transparent session resumption is only supported for Vertex AI' + ' backend. Please use Vertex AI backend.' ) llm_request.live_connect_config.tools = llm_request.config.tools - logger.info("Connecting to live for model: %s", llm_request.model) - logger.debug("Connecting to live with llm_request:%s", llm_request) - logger.debug("Live connect config: %s", llm_request.live_connect_config) + logger.info('Connecting to live for model: %s', llm_request.model) + logger.debug('Connecting to live with llm_request:%s', llm_request) + logger.debug('Live connect config: %s', llm_request.live_connect_config) async with self._live_api_client.aio.live.connect( model=llm_request.model, config=llm_request.live_connect_config ) as live_session: @@ -431,7 +432,7 @@ async def wait_5_seconds(): return wait_5_seconds await ComputerUseToolset.adapt_computer_use_tool( - "wait", convert_wait_to_wait_5_seconds, llm_request + 'wait', convert_wait_to_wait_5_seconds, llm_request ) async def _preprocess_request(self, llm_request: LlmRequest) -> None: @@ -472,18 +473,18 @@ def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: continue # Merge tracking headers with existing headers and avoid duplicates. - value_parts = tracking_header_value.split(" ") - for custom_value_part in custom_value.split(" "): + value_parts = tracking_header_value.split(' ') + for custom_value_part in custom_value.split(' '): if custom_value_part not in value_parts: value_parts.append(custom_value_part) - headers[key] = " ".join(value_parts) + headers[key] = ' '.join(value_parts) return headers def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: - param_str = "{}" + param_str = '{}' if func_decl.parameters and func_decl.parameters.properties: param_str = str({ k: v.model_dump(exclude_none=True) @@ -492,13 +493,13 @@ def _build_function_declaration_log( elif func_decl.parameters_json_schema: param_str = str(func_decl.parameters_json_schema) - return_str = "" + return_str = '' if func_decl.response: - return_str = "-> " + str(func_decl.response.model_dump(exclude_none=True)) + return_str = '-> ' + str(func_decl.response.model_dump(exclude_none=True)) elif func_decl.response_json_schema: - return_str = "-> " + str(func_decl.response_json_schema) + return_str = '-> ' + str(func_decl.response_json_schema) - return f"{func_decl.name}: {param_str} {return_str}" + return f'{func_decl.name}: {param_str} {return_str}' def _build_request_log(req: LlmRequest) -> str: @@ -527,7 +528,7 @@ def _build_request_log(req: LlmRequest) -> str: content.model_dump_json( exclude_none=True, exclude={ - "parts": { + 'parts': { i: _EXCLUDED_PART_FIELD for i in range(len(content.parts)) } }, @@ -537,7 +538,7 @@ def _build_request_log(req: LlmRequest) -> str: # Build exclusion dict for config logging tools_exclusion = ( - {function_decl_tool_index: {"function_declarations"}} + {function_decl_tool_index: {'function_declarations'}} if function_decl_tool_index is not None else True ) @@ -547,8 +548,8 @@ def _build_request_log(req: LlmRequest) -> str: req.config.model_dump( exclude_none=True, exclude={ - "system_instruction": True, - "tools": tools_exclusion if req.config.tools else True, + 'system_instruction': True, + 'tools': tools_exclusion if req.config.tools else True, }, ) ) @@ -578,7 +579,7 @@ def _build_response_log(resp: types.GenerateContentResponse) -> str: if function_calls := resp.function_calls: for func_call in function_calls: function_calls_text.append( - f"name: {func_call.name}, args: {func_call.args}" + f'name: {func_call.name}, args: {func_call.args}' ) return f""" LLM Response: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 3dd455cfc9..b0897eb358 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -65,7 +65,7 @@ from .utils.context_utils import Aclosing from .utils.telemetry_utils import is_telemetry_enabled -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) def _is_tool_call_or_response(event: Event) -> bool: @@ -212,22 +212,22 @@ def _validate_runner_params( """ if plugins is not None: warnings.warn( - "The `plugins` argument is deprecated. Please use the `app` argument" - " to provide plugins instead.", + 'The `plugins` argument is deprecated. Please use the `app` argument' + ' to provide plugins instead.', DeprecationWarning, ) if app: if app_name: raise ValueError( - "When app is provided, app_name should not be provided." + 'When app is provided, app_name should not be provided.' ) if agent: - raise ValueError("When app is provided, agent should not be provided.") + raise ValueError('When app is provided, agent should not be provided.') if plugins: raise ValueError( - "When app is provided, plugins should not be provided and should be" - " provided in the app instead." + 'When app is provided, plugins should not be provided and should be' + ' provided in the app instead.' ) app_name = app.name agent = app.root_agent @@ -236,7 +236,7 @@ def _validate_runner_params( resumability_config = app.resumability_config elif not app_name or not agent: raise ValueError( - "Either app or both app_name and agent must be provided." + 'Either app or both app_name and agent must be provided.' ) else: context_cache_config = None @@ -263,8 +263,8 @@ def _infer_agent_origin( """ # First, check for metadata set by AgentLoader (most reliable source). # AgentLoader sets these attributes when loading agents. - origin_app_name = getattr(agent, "_adk_origin_app_name", None) - origin_path = getattr(agent, "_adk_origin_path", None) + origin_app_name = getattr(agent, '_adk_origin_app_name', None) + origin_path = getattr(agent, '_adk_origin_path', None) if origin_app_name is not None and origin_path is not None: return origin_app_name, origin_path @@ -276,10 +276,10 @@ def _infer_agent_origin( # Skip ADK internal modules. When users instantiate LlmAgent directly # (not subclassed), inspect.getmodule() returns the ADK module. This # could falsely match 'agents' in 'google/adk/agents/' path. - if module.__name__.startswith("google.adk."): + if module.__name__.startswith('google.adk.'): return None, None - module_file = getattr(module, "__file__", None) + module_file = getattr(module, '__file__', None) if not module_file: return None, None module_path = Path(module_file).resolve() @@ -289,17 +289,17 @@ def _infer_agent_origin( except ValueError: return None, module_path.parent origin_dir = module_path.parent - if "agents" not in relative_path.parts: + if 'agents' not in relative_path.parts: return None, origin_dir origin_name = origin_dir.name - if origin_name.startswith("."): + if origin_name.startswith('.'): return None, origin_dir return origin_name, origin_dir def _enforce_app_name_alignment(self) -> None: origin_name = self._agent_origin_app_name origin_dir = self._agent_origin_dir - if not origin_name or origin_name.startswith("__"): + if not origin_name or origin_name.startswith('__'): self._app_name_alignment_hint = None return if origin_name == self.app_name: @@ -307,24 +307,24 @@ def _enforce_app_name_alignment(self) -> None: return origin_location = str(origin_dir) if origin_dir else origin_name mismatch_details = ( - "The runner is configured with app name " + 'The runner is configured with app name ' f'"{self.app_name}", but the root agent was loaded from ' f'"{origin_location}", which implies app name "{origin_name}".' ) resolution = ( - "Ensure the runner app_name matches that directory or pass app_name " - "explicitly when constructing the runner." + 'Ensure the runner app_name matches that directory or pass app_name ' + 'explicitly when constructing the runner.' ) - self._app_name_alignment_hint = f"{mismatch_details} {resolution}" - logger.warning("App name mismatch detected. %s", mismatch_details) + self._app_name_alignment_hint = f'{mismatch_details} {resolution}' + logger.warning('App name mismatch detected. %s', mismatch_details) def _format_session_not_found_message(self, session_id: str) -> str: - message = f"Session not found: {session_id}" + message = f'Session not found: {session_id}' if not self._app_name_alignment_hint: return message return ( - f"{message}. {self._app_name_alignment_hint} " - "The mismatch prevents the runner from locating the session." + f'{message}. {self._app_name_alignment_hint} ' + 'The mismatch prevents the runner from locating the session.' ) def run( @@ -429,7 +429,7 @@ async def run_async( run_config = run_config or RunConfig() if new_message and not new_message.role: - new_message.role = "user" + new_message.role = 'user' async def _run_body( new_message: Optional[types.Content] = None, @@ -443,9 +443,9 @@ async def _run_body( raise ValueError(message) if not invocation_id and not new_message: raise ValueError( - "Running an agent requires either a new_message or an " - "invocation_id to resume a previous invocation. " - f"Session: {session_id}, User: {user_id}" + 'Running an agent requires either a new_message or an ' + 'invocation_id to resume a previous invocation. ' + f'Session: {session_id}, User: {user_id}' ) if invocation_id: @@ -454,8 +454,8 @@ async def _run_body( or not self.resumability_config.is_resumable ): raise ValueError( - f"invocation_id: {invocation_id} is provided but the app is not" - " resumable." + f'invocation_id: {invocation_id} is provided but the app is not' + ' resumable.' ) invocation_context = await self._setup_context_for_resumed_invocation( session=session, @@ -495,7 +495,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: # (We don't compact in the middle of an invocation, we only compact at # the end of an invocation.) if self.app and self.app.events_compaction_config: - logger.debug("Running event compactor.") + logger.debug('Running event compactor.') await _run_compaction_for_sliding_window( self.app, session, self.session_service ) @@ -506,7 +506,7 @@ async def _run_with_optional_trace( invocation_id: Optional[str] = None, ) -> AsyncGenerator[Event, None]: if is_telemetry_enabled(agent): - with tracer.start_as_current_span("invocation"): + with tracer.start_as_current_span('invocation'): async with Aclosing( _run_body(new_message=new_message, invocation_id=invocation_id) ) as agen: @@ -537,7 +537,7 @@ async def rewind_async( app_name=self.app_name, user_id=user_id, session_id=session_id ) if not session: - raise ValueError(f"Session not found: {session_id}") + raise ValueError(f'Session not found: {session_id}') rewind_event_index = -1 for i, event in enumerate(session.events): @@ -547,7 +547,7 @@ async def rewind_async( if rewind_event_index == -1: raise ValueError( - f"Invocation ID not found: {rewind_before_invocation_id}" + f'Invocation ID not found: {rewind_before_invocation_id}' ) # Compute state delta to reverse changes @@ -563,7 +563,7 @@ async def rewind_async( # Create rewind event rewind_event = Event( invocation_id=new_invocation_context_id(), - author="user", + author='user', actions=EventActions( rewind_before_invocation_id=rewind_before_invocation_id, state_delta=state_delta, @@ -571,7 +571,7 @@ async def rewind_async( ), ) - logger.info("Rewinding session to invocation: %s", rewind_event) + logger.info('Rewinding session to invocation: %s', rewind_event) await self.session_service.append_event(session=session, event=rewind_event) @@ -583,7 +583,7 @@ async def _compute_state_delta_for_rewind( for i in range(rewind_event_index): if session.events[i].actions.state_delta: for k, v in session.events[i].actions.state_delta.items(): - if k.startswith("app:") or k.startswith("user:"): + if k.startswith('app:') or k.startswith('user:'): continue if v is None: state_at_rewind_point.pop(k, None) @@ -602,7 +602,7 @@ async def _compute_state_delta_for_rewind( # but not in state_at_rewind_point. These keys were added after the # rewind point and need to be removed. for key in current_state: - if key.startswith("app:") or key.startswith("user:"): + if key.startswith('app:') or key.startswith('user:'): continue if key not in state_at_rewind_point: rewind_state_delta[key] = None @@ -629,7 +629,7 @@ async def _compute_artifact_delta_for_rewind( rewind_artifact_delta = {} for filename, vn in current_versions.items(): - if filename.startswith("user:"): + if filename.startswith('user:'): # User artifacts are not restored on rewind. continue vt = versions_at_rewind_point.get(filename) @@ -641,7 +641,7 @@ async def _compute_artifact_delta_for_rewind( # Artifact did not exist at rewind point. Mark it as inaccessible. artifact = types.Part( inline_data=types.Blob( - mime_type="application/octet-stream", data=b"" + mime_type='application/octet-stream', data=b'' ) ) else: @@ -711,7 +711,7 @@ async def _exec_with_plugin( if isinstance(early_exit_result, types.Content): early_exit_event = Event( invocation_id=invocation_context.invocation_id, - author="model", + author='model', content=early_exit_result, ) if self._should_append_event(early_exit_event, is_live_call): @@ -763,7 +763,7 @@ async def _exec_with_plugin( # transcription end signal, append buffered events is_transcribing = False logger.debug( - "Appending transcription finished event: %s", event + 'Appending transcription finished event: %s', event ) if self._should_append_event(event, is_live_call): await self.session_service.append_event( @@ -771,7 +771,7 @@ async def _exec_with_plugin( ) for buffered_event in buffered_events: - logger.debug("Appending buffered event: %s", buffered_event) + logger.debug('Appending buffered event: %s', buffered_event) await self.session_service.append_event( session=session, event=buffered_event ) @@ -780,7 +780,7 @@ async def _exec_with_plugin( # non-transcription event or empty transcription event, for # example, event that stores blob reference, should be appended. if self._should_append_event(event, is_live_call): - logger.debug("Appending non-buffered event: %s", event) + logger.debug('Appending non-buffered event: %s', event) await self.session_service.append_event( session=session, event=event ) @@ -822,15 +822,15 @@ async def _append_new_message_to_session( state_delta: Optional state changes to apply to the session. """ if not new_message.parts: - raise ValueError("No parts in the new_message.") + raise ValueError('No parts in the new_message.') if self.artifact_service and save_input_blobs_as_artifacts: # Issue deprecation warning warnings.warn( "The 'save_input_blobs_as_artifacts' parameter is deprecated. Use" - " SaveFilesAsArtifactsPlugin instead for better control and" - " flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for" - " migration guidance.", + ' SaveFilesAsArtifactsPlugin instead for better control and' + ' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for' + ' migration guidance.', DeprecationWarning, stacklevel=3, ) @@ -840,7 +840,7 @@ async def _append_new_message_to_session( for i, part in enumerate(new_message.parts): if part.inline_data is None: continue - file_name = f"artifact_{invocation_context.invocation_id}_{i}" + file_name = f'artifact_{invocation_context.invocation_id}_{i}' await self.artifact_service.save_artifact( app_name=self.app_name, user_id=session.user_id, @@ -849,20 +849,20 @@ async def _append_new_message_to_session( artifact=part, ) new_message.parts[i] = types.Part( - text=f"Uploaded file: {file_name}. It is saved into artifacts" + text=f'Uploaded file: {file_name}. It is saved into artifacts' ) # Appends only. We do not yield the event because it's not from the model. if state_delta: event = Event( invocation_id=invocation_context.invocation_id, - author="user", + author='user', actions=EventActions(state_delta=state_delta), content=new_message, ) else: event = Event( invocation_id=invocation_context.invocation_id, - author="user", + author='user', content=new_message, ) # If new_message is a function response, find the matching function call @@ -938,15 +938,15 @@ async def run_live( # Some native audio models requires the modality to be set. So we set it to # AUDIO by default. if run_config.response_modalities is None: - run_config.response_modalities = ["AUDIO"] + run_config.response_modalities = ['AUDIO'] if session is None and (user_id is None or session_id is None): raise ValueError( - "Either session or user_id and session_id must be provided." + 'Either session or user_id and session_id must be provided.' ) if session is not None: warnings.warn( - "The `session` parameter is deprecated. Please use `user_id` and" - " `session_id` instead.", + 'The `session` parameter is deprecated. Please use `user_id` and' + ' `session_id` instead.', DeprecationWarning, stacklevel=2, ) @@ -955,7 +955,7 @@ async def run_live( app_name=self.app_name, user_id=user_id, session_id=session_id ) if not session: - raise ValueError(f"Session not found: {session_id}") + raise ValueError(f'Session not found: {session_id}') invocation_context = self._new_invocation_context_for_live( session, live_request_queue=live_request_queue, @@ -970,7 +970,7 @@ async def run_live( invocation_context.active_streaming_tools = {} # TODO(hangfei): switch to use canonical_tools. # for shell agents, there is no tools associated with it so we should skip. - if hasattr(invocation_context.agent, "tools"): + if hasattr(invocation_context.agent, 'tools'): import inspect for tool in invocation_context.agent.tools: @@ -989,7 +989,7 @@ async def run_live( # annotation object as it was defined on the function. This allows us to # perform a direct and reliable identity check (`param.annotation is LiveRequestQueue`) # without risking a `NameError`. - callable_to_inspect = tool.func if hasattr(tool, "func") else tool + callable_to_inspect = tool.func if hasattr(tool, 'func') else tool # Ensure the target is actually callable before inspecting to avoid errors. if not callable(callable_to_inspect): continue @@ -1051,7 +1051,7 @@ def _find_agent_to_run( def _event_filter(event: Event) -> bool: """Filters out user-authored events and agent state change events.""" - if event.author == "user": + if event.author == 'user': return False if event.actions.agent_state is not None or event.actions.end_of_agent: return False @@ -1064,7 +1064,7 @@ def _event_filter(event: Event) -> bool: if not (agent := root_agent.find_sub_agent(event.author)): # Agent not found, continue looking. logger.warning( - "Event from an unknown agent: %s, event id: %s", + 'Event from an unknown agent: %s, event id: %s', event.author, event.id, ) @@ -1100,8 +1100,8 @@ async def run_debug( self, user_messages: str | list[str], *, - user_id: str = "debug_user_id", - session_id: str = "debug_session_id", + user_id: str = 'debug_user_id', + session_id: str = 'debug_session_id', run_config: RunConfig | None = None, quiet: bool = False, verbose: bool = False, @@ -1172,9 +1172,9 @@ async def run_debug( app_name=self.app_name, user_id=user_id, session_id=session_id ) if not quiet: - print(f"\n ### Created new session: {session_id}") + print(f'\n ### Created new session: {session_id}') elif not quiet: - print(f"\n ### Continue session: {session_id}") + print(f'\n ### Continue session: {session_id}') collected_events: list[Event] = [] @@ -1183,7 +1183,7 @@ async def run_debug( for message in user_messages: if not quiet: - print(f"\nUser > {message}") + print(f'\nUser > {message}') async for event in self.run_async( user_id=user_id, @@ -1262,7 +1262,7 @@ async def _setup_context_for_resumed_invocation( available for resuming the invocation; Or if the app is not resumable. """ if not session.events: - raise ValueError(f"Session {session.id} has no events to resume.") + raise ValueError(f'Session {session.id} has no events to resume.') # Step 1: Maybe retrieve a previous user message for the invocation. user_message = new_message or self._find_user_message_for_invocation( @@ -1270,7 +1270,7 @@ async def _setup_context_for_resumed_invocation( ) if not user_message: raise ValueError( - f"No user message available for resuming invocation: {invocation_id}" + f'No user message available for resuming invocation: {invocation_id}' ) # Step 2: Create invocation context. invocation_context = self._new_invocation_context( @@ -1306,7 +1306,7 @@ def _find_user_message_for_invocation( for event in events: if ( event.invocation_id == invocation_id - and event.author == "user" + and event.author == 'user' and event.content and event.content.parts and event.content.parts[0].text @@ -1340,10 +1340,10 @@ def _new_invocation_context( if run_config.support_cfc and isinstance(self.agent, LlmAgent): model_name = self.agent.canonical_model.model - if not model_name.startswith("gemini-2"): + if not model_name.startswith('gemini-2'): raise ValueError( - f"CFC is not supported for model: {model_name} in agent:" - f" {self.agent.name}" + f'CFC is not supported for model: {model_name} in agent:' + f' {self.agent.name}' ) if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): self.agent.code_executor = BuiltInCodeExecutor() @@ -1379,12 +1379,12 @@ def _new_invocation_context_for_live( if self.agent.sub_agents and live_request_queue: if not run_config.response_modalities: # default - run_config.response_modalities = ["AUDIO"] + run_config.response_modalities = ['AUDIO'] if not run_config.output_audio_transcription: run_config.output_audio_transcription = ( types.AudioTranscriptionConfig() ) - elif "TEXT" not in run_config.response_modalities: + elif 'TEXT' not in run_config.response_modalities: if not run_config.output_audio_transcription: run_config.output_audio_transcription = ( types.AudioTranscriptionConfig() @@ -1453,12 +1453,12 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): # This maintains the same task context throughout cleanup for toolset in toolsets_to_close: try: - logger.info("Closing toolset: %s", type(toolset).__name__) + logger.info('Closing toolset: %s', type(toolset).__name__) # Use asyncio.wait_for to add timeout protection await asyncio.wait_for(toolset.close(), timeout=10.0) - logger.info("Successfully closed toolset: %s", type(toolset).__name__) + logger.info('Successfully closed toolset: %s', type(toolset).__name__) except asyncio.TimeoutError: - logger.warning("Toolset %s cleanup timed out", type(toolset).__name__) + logger.warning('Toolset %s cleanup timed out', type(toolset).__name__) except asyncio.CancelledError as e: # Handle cancel scope issues in Python 3.10 and 3.11 with anyio # @@ -1470,14 +1470,14 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): # improved context propagation across task boundaries, and better cancellation # handling prevent the cross-task cancel scope violation. logger.warning( - "Toolset %s cleanup cancelled: %s", type(toolset).__name__, e + 'Toolset %s cleanup cancelled: %s', type(toolset).__name__, e ) except Exception as e: - logger.error("Error closing toolset %s: %s", type(toolset).__name__, e) + logger.error('Error closing toolset %s: %s', type(toolset).__name__, e) async def close(self): """Closes the runner.""" - logger.info("Closing runner...") + logger.info('Closing runner...') # Close Toolsets await self._cleanup_toolsets(self._collect_toolset(self.agent)) @@ -1485,10 +1485,10 @@ async def close(self): if self.plugin_manager: await self.plugin_manager.close() - logger.info("Runner closed.") + logger.info('Runner closed.') if sys.version_info < (3, 11): - Self = "Runner" # pylint: disable=invalid-name + Self = 'Runner' # pylint: disable=invalid-name else: from typing import Self # pylint: disable=g-import-not-at-top @@ -1535,7 +1535,7 @@ def __init__( plugin_close_timeout: The timeout in seconds for plugin close methods. """ if app is None and app_name is None: - app_name = "InMemoryRunner" + app_name = 'InMemoryRunner' super().__init__( app_name=app_name, agent=agent, From a097fcfadcc7ef545b6d2ce45a0d58119ecfa4c1 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 15:25:16 +0000 Subject: [PATCH 18/24] refactor(models): rename cache helper Use _create_gemini_cache name consistently and update tests. Revert tracing.py quote-only churn. --- .../models/gemini_context_cache_manager.py | 4 +- src/google/adk/telemetry/tracing.py | 124 +++++++++--------- .../test_gemini_context_cache_manager.py | 2 +- 3 files changed, 65 insertions(+), 65 deletions(-) diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index 81da84a022..7ac6795c7c 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -309,7 +309,7 @@ async def _create_new_cache_with_contents( try: # Create cache using Gemini API directly - return await self._create_gemini_cache_with_optional_tracing( + return await self._create_gemini_cache( llm_request, cache_contents_count ) except Exception as e: @@ -349,7 +349,7 @@ def _estimate_request_tokens(self, llm_request: LlmRequest) -> int: # Rough estimate: 4 characters per token return total_chars // 4 - async def _create_gemini_cache_with_optional_tracing( + async def _create_gemini_cache( self, llm_request: LlmRequest, cache_contents_count: int ) -> CacheMetadata: """Create cache using Gemini API. diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 9e67ac6c0e..f03cdc8010 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -37,16 +37,16 @@ # By default some ADK spans include attributes with potential PII data. # This env, when set to false, allows to disable populating those attributes. -ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = "ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS" +ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = 'ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS' # TODO: Replace with constant from opentelemetry.semconv when it reaches version 1.37 in g3. -GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description" -GEN_AI_AGENT_NAME = "gen_ai.agent.name" -GEN_AI_CONVERSATION_ID = "gen_ai.conversation.id" -GEN_AI_OPERATION_NAME = "gen_ai.operation.name" -GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id" -GEN_AI_TOOL_DESCRIPTION = "gen_ai.tool.description" -GEN_AI_TOOL_NAME = "gen_ai.tool.name" -GEN_AI_TOOL_TYPE = "gen_ai.tool.type" +GEN_AI_AGENT_DESCRIPTION = 'gen_ai.agent.description' +GEN_AI_AGENT_NAME = 'gen_ai.agent.name' +GEN_AI_CONVERSATION_ID = 'gen_ai.conversation.id' +GEN_AI_OPERATION_NAME = 'gen_ai.operation.name' +GEN_AI_TOOL_CALL_ID = 'gen_ai.tool.call.id' +GEN_AI_TOOL_DESCRIPTION = 'gen_ai.tool.description' +GEN_AI_TOOL_NAME = 'gen_ai.tool.name' +GEN_AI_TOOL_TYPE = 'gen_ai.tool.type' # Needed to avoid circular imports if TYPE_CHECKING: @@ -57,10 +57,10 @@ from ..tools.base_tool import BaseTool tracer = trace.get_tracer( - instrumenting_module_name="gcp.vertex.agent", + instrumenting_module_name='gcp.vertex.agent', instrumenting_library_version=version.__version__, # TODO: Replace with constant from opentelemetry.semconv when it reaches version 1.37 in g3. - schema_url="https://opentelemetry.io/schemas/1.37.0", + schema_url='https://opentelemetry.io/schemas/1.37.0', ) @@ -77,10 +77,10 @@ def _safe_json_serialize(obj) -> str: try: # Try direct JSON serialization first return json.dumps( - obj, ensure_ascii=False, default=lambda o: "" + obj, ensure_ascii=False, default=lambda o: '' ) except (TypeError, OverflowError): - return "" + return '' def trace_agent_invocation( @@ -107,7 +107,7 @@ def trace_agent_invocation( """ # Required - span.set_attribute(GEN_AI_OPERATION_NAME, "invoke_agent") + span.set_attribute(GEN_AI_OPERATION_NAME, 'invoke_agent') # Conditionally Required span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) @@ -130,7 +130,7 @@ def trace_tool_call( """ span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") + span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) span.set_attribute(GEN_AI_TOOL_NAME, tool.name) @@ -140,20 +140,20 @@ def trace_tool_call( # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. - span.set_attribute("gcp.vertex.agent.llm_request", "{}") - span.set_attribute("gcp.vertex.agent.llm_response", "{}") + span.set_attribute('gcp.vertex.agent.llm_request', '{}') + span.set_attribute('gcp.vertex.agent.llm_response', '{}') if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.tool_call_args", + 'gcp.vertex.agent.tool_call_args', _safe_json_serialize(args), ) else: - span.set_attribute("gcp.vertex.agent.tool_call_args", {}) + span.set_attribute('gcp.vertex.agent.tool_call_args', {}) # Tracing tool response - tool_call_id = "" - tool_response = "" + tool_call_id = '' + tool_response = '' if ( function_response_event is not None and function_response_event.content is not None @@ -170,16 +170,16 @@ def trace_tool_call( span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) if not isinstance(tool_response, dict): - tool_response = {"result": tool_response} + tool_response = {'result': tool_response} if function_response_event is not None: - span.set_attribute("gcp.vertex.agent.event_id", function_response_event.id) + span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id) if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.tool_response", + 'gcp.vertex.agent.tool_response', _safe_json_serialize(tool_response), ) else: - span.set_attribute("gcp.vertex.agent.tool_response", {}) + span.set_attribute('gcp.vertex.agent.tool_response', {}) def trace_merged_tool_calls( @@ -198,34 +198,34 @@ def trace_merged_tool_calls( span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") - span.set_attribute(GEN_AI_TOOL_NAME, "(merged tools)") - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)") + span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') + span.set_attribute(GEN_AI_TOOL_NAME, '(merged tools)') + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, '(merged tools)') span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) # TODO(b/441461932): See if these are still necessary - span.set_attribute("gcp.vertex.agent.tool_call_args", "N/A") - span.set_attribute("gcp.vertex.agent.event_id", response_event_id) + span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A') + span.set_attribute('gcp.vertex.agent.event_id', response_event_id) try: function_response_event_json = function_response_event.model_dumps_json( exclude_none=True ) except Exception: # pylint: disable=broad-exception-caught - function_response_event_json = "" + function_response_event_json = '' if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.tool_response", + 'gcp.vertex.agent.tool_response', function_response_event_json, ) else: - span.set_attribute("gcp.vertex.agent.tool_response", {}) + span.set_attribute('gcp.vertex.agent.tool_response', {}) # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. - span.set_attribute("gcp.vertex.agent.llm_request", "{}") + span.set_attribute('gcp.vertex.agent.llm_request', '{}') span.set_attribute( - "gcp.vertex.agent.llm_response", - "{}", + 'gcp.vertex.agent.llm_response', + '{}', ) @@ -249,57 +249,57 @@ def trace_call_llm( span = trace.get_current_span() # Special standard Open Telemetry GenaI attributes that indicate # that this is a span related to a Generative AI system. - span.set_attribute("gen_ai.system", "gcp.vertex.agent") - span.set_attribute("gen_ai.request.model", llm_request.model) + span.set_attribute('gen_ai.system', 'gcp.vertex.agent') + span.set_attribute('gen_ai.request.model', llm_request.model) span.set_attribute( - "gcp.vertex.agent.invocation_id", invocation_context.invocation_id + 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id ) span.set_attribute( - "gcp.vertex.agent.session_id", invocation_context.session.id + 'gcp.vertex.agent.session_id', invocation_context.session.id ) - span.set_attribute("gcp.vertex.agent.event_id", event_id) + span.set_attribute('gcp.vertex.agent.event_id', event_id) # Consider removing once GenAI SDK provides a way to record this info. if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.llm_request", + 'gcp.vertex.agent.llm_request', _safe_json_serialize(_build_llm_request_for_trace(llm_request)), ) else: - span.set_attribute("gcp.vertex.agent.llm_request", {}) + span.set_attribute('gcp.vertex.agent.llm_request', {}) # Consider removing once GenAI SDK provides a way to record this info. if llm_request.config: if llm_request.config.top_p: span.set_attribute( - "gen_ai.request.top_p", + 'gen_ai.request.top_p', llm_request.config.top_p, ) if llm_request.config.max_output_tokens: span.set_attribute( - "gen_ai.request.max_tokens", + 'gen_ai.request.max_tokens', llm_request.config.max_output_tokens, ) try: llm_response_json = llm_response.model_dump_json(exclude_none=True) except Exception: # pylint: disable=broad-exception-caught - llm_response_json = "" + llm_response_json = '' if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.llm_response", + 'gcp.vertex.agent.llm_response', llm_response_json, ) else: - span.set_attribute("gcp.vertex.agent.llm_response", {}) + span.set_attribute('gcp.vertex.agent.llm_response', {}) if llm_response.usage_metadata is not None: span.set_attribute( - "gen_ai.usage.input_tokens", + 'gen_ai.usage.input_tokens', llm_response.usage_metadata.prompt_token_count, ) if llm_response.usage_metadata.candidates_token_count is not None: span.set_attribute( - "gen_ai.usage.output_tokens", + 'gen_ai.usage.output_tokens', llm_response.usage_metadata.candidates_token_count, ) if llm_response.finish_reason: @@ -308,7 +308,7 @@ def trace_call_llm( except AttributeError: finish_reason_str = str(llm_response.finish_reason).lower() span.set_attribute( - "gen_ai.response.finish_reasons", + 'gen_ai.response.finish_reasons', [finish_reason_str], ) @@ -330,14 +330,14 @@ def trace_send_data( """ span = trace.get_current_span() span.set_attribute( - "gcp.vertex.agent.invocation_id", invocation_context.invocation_id + 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id ) - span.set_attribute("gcp.vertex.agent.event_id", event_id) + span.set_attribute('gcp.vertex.agent.event_id', event_id) # Once instrumentation is added to the GenAI SDK, consider whether this # information still needs to be recorded by the Agent Development Kit. if _should_add_request_response_to_spans(): span.set_attribute( - "gcp.vertex.agent.data", + 'gcp.vertex.agent.data', _safe_json_serialize([ types.Content(role=content.role, parts=content.parts).model_dump( exclude_none=True @@ -346,7 +346,7 @@ def trace_send_data( ]), ) else: - span.set_attribute("gcp.vertex.agent.data", {}) + span.set_attribute('gcp.vertex.agent.data', {}) def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: @@ -364,16 +364,16 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: """ # Some fields in LlmRequest are function pointers and cannot be serialized. result = { - "model": llm_request.model, - "config": llm_request.config.model_dump( - exclude_none=True, exclude="response_schema" + 'model': llm_request.model, + 'config': llm_request.config.model_dump( + exclude_none=True, exclude='response_schema' ), - "contents": [], + 'contents': [], } # We do not want to send bytes data to the trace. for content in llm_request.contents: parts = [part for part in content.parts if not part.inline_data] - result["contents"].append( + result['contents'].append( types.Content(role=content.role, parts=parts).model_dump( exclude_none=True ) @@ -387,6 +387,6 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: # to false. def _should_add_request_response_to_spans() -> bool: disabled_via_env_var = os.getenv( - ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, "true" - ).lower() in ("false", "0") + ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'true' + ).lower() in ('false', '0') return not disabled_via_env_var diff --git a/tests/unittests/agents/test_gemini_context_cache_manager.py b/tests/unittests/agents/test_gemini_context_cache_manager.py index 0575d4eaff..08f45ab921 100644 --- a/tests/unittests/agents/test_gemini_context_cache_manager.py +++ b/tests/unittests/agents/test_gemini_context_cache_manager.py @@ -479,7 +479,7 @@ async def test_create_new_cache_with_proper_ttl(self): with patch.object( self.manager, "_generate_cache_fingerprint", return_value="test_fp" ): - await self.manager._create_gemini_cache_with_optional_tracing( + await self.manager._create_gemini_cache( llm_request, cache_contents_count ) From de5ac88c65d7f06fac2e7f5a39faa7798e74f78a Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 15:48:39 +0000 Subject: [PATCH 19/24] removed unnecessary formatting change for easier code review --- src/google/adk/models/gemini_context_cache_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index 7ac6795c7c..516b7c0838 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -309,9 +309,7 @@ async def _create_new_cache_with_contents( try: # Create cache using Gemini API directly - return await self._create_gemini_cache( - llm_request, cache_contents_count - ) + return await self._create_gemini_cache(llm_request, cache_contents_count) except Exception as e: logger.warning("Failed to create cache: %s", e) return None From a4026e7acda824981e4028383cf5735b335c11f6 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 15:56:48 +0000 Subject: [PATCH 20/24] removed unnecessary formatting change for easier code review --- tests/unittests/agents/test_gemini_context_cache_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unittests/agents/test_gemini_context_cache_manager.py b/tests/unittests/agents/test_gemini_context_cache_manager.py index 08f45ab921..0443843ae1 100644 --- a/tests/unittests/agents/test_gemini_context_cache_manager.py +++ b/tests/unittests/agents/test_gemini_context_cache_manager.py @@ -479,9 +479,7 @@ async def test_create_new_cache_with_proper_ttl(self): with patch.object( self.manager, "_generate_cache_fingerprint", return_value="test_fp" ): - await self.manager._create_gemini_cache( - llm_request, cache_contents_count - ) + await self.manager._create_gemini_cache(llm_request, cache_contents_count) # Verify cache creation call includes TTL create_call = self.manager.genai_client.aio.caches.create.call_args From cafe5e1f9146f40a1e5fab3aa69ebd8df8d683a4 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 16:01:14 +0000 Subject: [PATCH 21/24] reordered the lines on changes to base_agent to minimise diffs for reviewing purposes --- src/google/adk/agents/base_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 23240bf6f7..3d6cf2f107 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -315,12 +315,12 @@ async def run_live( Event: the events generated by the agent. """ - ctx = self._create_invocation_context(parent_context) span_context = contextlib.nullcontext() if is_telemetry_enabled(self): span_context = tracer.start_as_current_span(f'invoke_agent {self.name}') with span_context as span: + ctx = self._create_invocation_context(parent_context) if span: tracing.trace_agent_invocation(span, self, ctx) async for event in self._run_callbacks_and_impl(ctx, mode='live'): From 7905c89b62e01dd071a70a3fcc1e2439ac600c5b Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 16:08:27 +0000 Subject: [PATCH 22/24] reordered the lines on changes to base_agent to minimise diffs for reviewing purposes --- src/google/adk/agents/base_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 3d6cf2f107..0eef85ad84 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -286,12 +286,12 @@ async def run_async( Event: the events generated by the agent. """ - ctx = self._create_invocation_context(parent_context) span_context = contextlib.nullcontext() if is_telemetry_enabled(self): span_context = tracer.start_as_current_span(f'invoke_agent {self.name}') with span_context as span: + ctx = self._create_invocation_context(parent_context) if span: tracing.trace_agent_invocation(span, self, ctx) async with Aclosing( From eee9b47395f5d4bb376e75edf9c7cd2542e8a79f Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 16:15:28 +0000 Subject: [PATCH 23/24] reordered lines to make review easier - now closer to original code --- src/google/adk/flows/llm_flows/base_llm_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 65bed9d55a..9730adc078 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -131,11 +131,11 @@ async def run_live( async with llm.connect(llm_request) as llm_connection: if llm_request.contents: # Sends the conversation history to the model. - logger.debug('Sending history to model: %s', llm_request.contents) span_context = contextlib.nullcontext() if is_telemetry_enabled(invocation_context.agent): span_context = tracer.start_as_current_span('send_data') with span_context as span: + logger.debug('Sending history to model: %s', llm_request.contents) await llm_connection.send_history(llm_request.contents) if span: trace_send_data( From 9fe906d595a360c9caddc8ceea8d74bc98596499 Mon Sep 17 00:00:00 2001 From: mportdata Date: Sun, 28 Dec 2025 16:23:10 +0000 Subject: [PATCH 24/24] removed abstraction _create_gemini_cache_body as this is only used once and was not needed --- .../models/gemini_context_cache_manager.py | 124 +++++++----------- 1 file changed, 51 insertions(+), 73 deletions(-) diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index 516b7c0838..c03a99f87c 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -366,87 +366,65 @@ async def _create_gemini_cache( span_context = tracer.start_as_current_span("create_cache") with span_context as span: - return await self._create_gemini_cache_body( - llm_request=llm_request, - cache_contents_count=cache_contents_count, - span=span, + # Prepare cache contents (first N contents + system instruction + tools) + cache_contents = llm_request.contents[:cache_contents_count] + + cache_config = types.CreateCachedContentConfig( + contents=cache_contents, + ttl=llm_request.cache_config.ttl_string, + display_name=( + f"adk-cache-{int(time.time())}-{cache_contents_count}contents" + ), ) - async def _create_gemini_cache_body( - self, - llm_request: LlmRequest, - cache_contents_count: int, - span: Optional[Span] = None, - ) -> CacheMetadata: - """Create cache using Gemini API. - - Args: - llm_request: Request to create cache for - cache_contents_count: Number of contents to cache + # Add system instruction if present + if llm_request.config and llm_request.config.system_instruction: + cache_config.system_instruction = llm_request.config.system_instruction + logger.debug( + "Added system instruction to cache config (length=%d)", + len(llm_request.config.system_instruction), + ) - Returns: - Cache metadata with precise creation timestamp - """ + # Add tools if present + if llm_request.config and llm_request.config.tools: + cache_config.tools = llm_request.config.tools - # Prepare cache contents (first N contents + system instruction + tools) - cache_contents = llm_request.contents[:cache_contents_count] + # Add tool config if present + if llm_request.config and llm_request.config.tool_config: + cache_config.tool_config = llm_request.config.tool_config - cache_config = types.CreateCachedContentConfig( - contents=cache_contents, - ttl=llm_request.cache_config.ttl_string, - display_name=( - f"adk-cache-{int(time.time())}-{cache_contents_count}contents" - ), - ) + if span is not None: + span.set_attribute("cache_contents_count", cache_contents_count) + span.set_attribute("model", llm_request.model) + span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) - # Add system instruction if present - if llm_request.config and llm_request.config.system_instruction: - cache_config.system_instruction = llm_request.config.system_instruction logger.debug( - "Added system instruction to cache config (length=%d)", - len(llm_request.config.system_instruction), + "Creating cache with model %s and config: %s", + llm_request.model, + cache_config, + ) + cached_content = await self.genai_client.aio.caches.create( + model=llm_request.model, + config=cache_config, + ) + # Set precise creation timestamp right after cache creation + created_at = time.time() + logger.info("Cache created successfully: %s", cached_content.name) + + if span is not None: + span.set_attribute("cache_name", cached_content.name) + + # Return complete cache metadata with precise timing + return CacheMetadata( + cache_name=cached_content.name, + expire_time=created_at + llm_request.cache_config.ttl_seconds, + fingerprint=self._generate_cache_fingerprint( + llm_request, cache_contents_count + ), + invocations_used=1, + contents_count=cache_contents_count, + created_at=created_at, ) - - # Add tools if present - if llm_request.config and llm_request.config.tools: - cache_config.tools = llm_request.config.tools - - # Add tool config if present - if llm_request.config and llm_request.config.tool_config: - cache_config.tool_config = llm_request.config.tool_config - - if span is not None: - span.set_attribute("cache_contents_count", cache_contents_count) - span.set_attribute("model", llm_request.model) - span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) - - logger.debug( - "Creating cache with model %s and config: %s", - llm_request.model, - cache_config, - ) - cached_content = await self.genai_client.aio.caches.create( - model=llm_request.model, - config=cache_config, - ) - # Set precise creation timestamp right after cache creation - created_at = time.time() - logger.info("Cache created successfully: %s", cached_content.name) - - if span is not None: - span.set_attribute("cache_name", cached_content.name) - - # Return complete cache metadata with precise timing - return CacheMetadata( - cache_name=cached_content.name, - expire_time=created_at + llm_request.cache_config.ttl_seconds, - fingerprint=self._generate_cache_fingerprint( - llm_request, cache_contents_count - ), - invocations_used=1, - contents_count=cache_contents_count, - created_at=created_at, - ) async def cleanup_cache(self, cache_name: str) -> None: """Clean up cache by deleting it.