diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index e15f9af981..0eef85ad84 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import contextlib import inspect import logging from typing import Any @@ -44,6 +45,7 @@ from ..telemetry.tracing import tracer from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental +from ..utils.telemetry_utils import is_telemetry_enabled from .base_agent_config import BaseAgentConfig from .callback_context import CallbackContext @@ -132,6 +134,8 @@ class MyAgent(BaseAgent): """ sub_agents: list[BaseAgent] = Field(default_factory=list) """The sub-agents of this agent.""" + disable_telemetry: bool = False + """Whether to disable telemetry for this agent.""" before_agent_callback: Optional[BeforeAgentCallback] = None """Callback or list of callbacks to be invoked before the agent run. @@ -282,24 +286,20 @@ async def run_async( Event: the events generated by the agent. """ - with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: - ctx = self._create_invocation_context(parent_context) - tracing.trace_agent_invocation(span, self, ctx) - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return + span_context = contextlib.nullcontext() + if is_telemetry_enabled(self): + span_context = tracer.start_as_current_span(f'invoke_agent {self.name}') - async with Aclosing(self._run_async_impl(ctx)) as agen: + with span_context as span: + ctx = self._create_invocation_context(parent_context) + if span: + tracing.trace_agent_invocation(span, self, ctx) + async with Aclosing( + self._run_callbacks_and_impl(ctx, mode='async') + ) as agen: async for event in agen: yield event - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event - @final async def run_live( self, @@ -315,19 +315,15 @@ async def run_live( Event: the events generated by the agent. """ - with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: - ctx = self._create_invocation_context(parent_context) - tracing.trace_agent_invocation(span, self, ctx) - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - yield event + span_context = contextlib.nullcontext() + if is_telemetry_enabled(self): + span_context = tracer.start_as_current_span(f'invoke_agent {self.name}') - if event := await self._handle_after_agent_callback(ctx): + with span_context as span: + ctx = self._create_invocation_context(parent_context) + if span: + tracing.trace_agent_invocation(span, self, ctx) + async for event in self._run_callbacks_and_impl(ctx, mode='live'): yield event async def _run_async_impl( @@ -362,6 +358,39 @@ async def _run_live_impl( ) yield # AsyncGenerator requires having at least one yield statement + async def _run_callbacks_and_impl( + self, ctx: InvocationContext, mode: str = 'async' + ) -> AsyncGenerator[Event, None]: + """Runs the before and after agent callbacks around the core agent logic. + Args: + ctx: InvocationContext, the invocation context for this agent. + mode: str, either 'async' or 'live', indicating which core agent logic to run. + Yields: + Event: the events generated by the agent. + """ + if event := await self._handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return + if mode.lower() == 'async': + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event + elif mode.lower() == 'live': + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + else: + raise ValueError( + f"Invalid mode: {mode}. Must be either 'async' or 'live'." + ) + + if ctx.end_invocation: + return + + if event := await self._handle_after_agent_callback(ctx): + yield event + @property def root_agent(self) -> BaseAgent: """Gets the root agent of this agent.""" diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5abaef589f..ac3156fd19 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -55,6 +55,7 @@ from ..tools.tool_context import ToolContext from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental +from ..utils.telemetry_utils import is_telemetry_enabled from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -814,6 +815,10 @@ def __maybe_save_output_to_state(self, event: Event): @model_validator(mode='after') def __model_validator_after(self) -> LlmAgent: + root_agent = getattr(self, 'root_agent', None) or self + disable_telemetry: bool = not is_telemetry_enabled(root_agent) + if hasattr(self.model, 'disable_telemetry'): + self.model.disable_telemetry = disable_telemetry return self @field_validator('generate_content_config', mode='after') diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 824cd26be1..9730adc078 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -16,6 +16,7 @@ from abc import ABC import asyncio +import contextlib import datetime import inspect import logging @@ -50,6 +51,7 @@ from ...tools.google_search_tool import google_search from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing +from ...utils.telemetry_utils import is_telemetry_enabled from .audio_cache_manager import AudioCacheManager if TYPE_CHECKING: @@ -129,13 +131,16 @@ async def run_live( async with llm.connect(llm_request) as llm_connection: if llm_request.contents: # Sends the conversation history to the model. - with tracer.start_as_current_span('send_data'): - # Combine regular contents with audio/transcription from session + span_context = contextlib.nullcontext() + if is_telemetry_enabled(invocation_context.agent): + span_context = tracer.start_as_current_span('send_data') + with span_context as span: logger.debug('Sending history to model: %s', llm_request.contents) await llm_connection.send_history(llm_request.contents) - trace_send_data( - invocation_context, event_id, llm_request.contents - ) + if span: + trace_send_data( + invocation_context, event_id, llm_request.contents + ) send_task = asyncio.create_task( self._send_to_model(llm_connection, invocation_context) @@ -752,7 +757,10 @@ async def _call_llm_async( llm = self.__get_llm(invocation_context) async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: - with tracer.start_as_current_span('call_llm'): + span_context = contextlib.nullcontext() + if is_telemetry_enabled(invocation_context.agent): + span_context = tracer.start_as_current_span('call_llm') + with span_context: if invocation_context.run_config.support_cfc: invocation_context.live_request_queue = LiveRequestQueue() responses_generator = self.run_live(invocation_context) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ffe1657be1..597e5c2236 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -42,6 +42,7 @@ from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing +from ...utils.telemetry_utils import is_telemetry_enabled if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent @@ -255,7 +256,7 @@ async def handle_function_call_list_async( function_response_events ) - if len(function_response_events) > 1: + if len(function_response_events) > 1 and is_telemetry_enabled(agent): # this is needed for debug traces of parallel calls # individual response with tool.name is traced in __build_response_event # (we drop tool.name from span name here as this is merged event) @@ -425,6 +426,8 @@ async def _run_with_trace(): ) return function_response_event + if not is_telemetry_enabled(agent): + return await _run_with_trace() with tracer.start_as_current_span(f'execute_tool {tool.name}'): try: function_response_event = await _run_with_trace() @@ -486,7 +489,7 @@ async def handle_function_calls_live( merged_event = merge_parallel_function_response_events( function_response_events ) - if len(function_response_events) > 1: + if len(function_response_events) > 1 and is_telemetry_enabled(agent): # this is needed for debug traces of parallel calls # individual response with tool.name is traced in __build_response_event # (we drop tool.name from span name here as this is merged event) @@ -575,6 +578,8 @@ async def _run_with_trace(): ) return function_response_event + if not is_telemetry_enabled(agent): + return await _run_with_trace() with tracer.start_as_current_span(f'execute_tool {tool.name}'): try: function_response_event = await _run_with_trace() diff --git a/src/google/adk/models/gemini_context_cache_manager.py b/src/google/adk/models/gemini_context_cache_manager.py index cd842cf494..c03a99f87c 100644 --- a/src/google/adk/models/gemini_context_cache_manager.py +++ b/src/google/adk/models/gemini_context_cache_manager.py @@ -15,7 +15,7 @@ """Manages context cache lifecycle for Gemini models.""" from __future__ import annotations - +import contextlib import hashlib import json import logging @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING from google.genai import types +from opentelemetry.trace import Span from ..utils.feature_decorator import experimental from .cache_metadata import CacheMetadata @@ -45,13 +46,15 @@ class GeminiContextCacheManager: cache compatibility and implements efficient caching strategies. """ - def __init__(self, genai_client: Client): + def __init__(self, genai_client: Client, disable_telemetry: bool = False): """Initialize cache manager with shared client. Args: genai_client: The GenAI client to use for cache operations. + disable_telemetry: A bool to flag whether or not telemetry should be disabled. """ self.genai_client = genai_client + self.disable_telemetry = disable_telemetry async def handle_context_caching( self, llm_request: LlmRequest @@ -356,9 +359,13 @@ async def _create_gemini_cache( Returns: Cache metadata with precise creation timestamp """ - from ..telemetry.tracing import tracer - with tracer.start_as_current_span("create_cache") as span: + span_context = contextlib.nullcontext() + if not self.disable_telemetry: + from ..telemetry.tracing import tracer + span_context = tracer.start_as_current_span("create_cache") + + with span_context as span: # Prepare cache contents (first N contents + system instruction + tools) cache_contents = llm_request.contents[:cache_contents_count] @@ -386,9 +393,10 @@ async def _create_gemini_cache( if llm_request.config and llm_request.config.tool_config: cache_config.tool_config = llm_request.config.tool_config - span.set_attribute("cache_contents_count", cache_contents_count) - span.set_attribute("model", llm_request.model) - span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) + if span is not None: + span.set_attribute("cache_contents_count", cache_contents_count) + span.set_attribute("model", llm_request.model) + span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds) logger.debug( "Creating cache with model %s and config: %s", @@ -403,7 +411,8 @@ async def _create_gemini_cache( created_at = time.time() logger.info("Cache created successfully: %s", cached_content.name) - span.set_attribute("cache_name", cached_content.name) + if span is not None: + span.set_attribute("cache_name", cached_content.name) # Return complete cache metadata with precise timing return CacheMetadata( diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 9261fada39..4b090dc98e 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -127,6 +127,11 @@ class Gemini(BaseLlm): ``` """ + disable_telemetry: bool = False + """A bool to flag whether or not telemetry should be being disabled for + Gemini LLM interactions. + """ + @classmethod @override def supported_models(cls) -> list[str]: @@ -165,18 +170,29 @@ async def generate_content_async( cache_metadata = None cache_manager = None if llm_request.cache_config: - from ..telemetry.tracing import tracer from .gemini_context_cache_manager import GeminiContextCacheManager - with tracer.start_as_current_span('handle_context_caching') as span: - cache_manager = GeminiContextCacheManager(self.api_client) + if not self.disable_telemetry: + from ..telemetry.tracing import tracer + + with tracer.start_as_current_span('handle_context_caching') as span: + cache_manager = GeminiContextCacheManager( + self.api_client, disable_telemetry=self.disable_telemetry + ) + cache_metadata = await cache_manager.handle_context_caching( + llm_request + ) + if cache_metadata: + if cache_metadata.cache_name: + span.set_attribute('cache_action', 'active_cache') + span.set_attribute('cache_name', cache_metadata.cache_name) + else: + span.set_attribute('cache_action', 'fingerprint_only') + else: + cache_manager = GeminiContextCacheManager( + self.api_client, disable_telemetry=self.disable_telemetry + ) cache_metadata = await cache_manager.handle_context_caching(llm_request) - if cache_metadata: - if cache_metadata.cache_name: - span.set_attribute('cache_action', 'active_cache') - span.set_attribute('cache_name', cache_metadata.cache_name) - else: - span.set_attribute('cache_action', 'fingerprint_only') logger.info( 'Sending out request, model: %s, backend: %s, stream: %s', diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1773729719..b0897eb358 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -63,6 +63,7 @@ from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing +from .utils.telemetry_utils import is_telemetry_enabled logger = logging.getLogger('google_adk.' + __name__) @@ -430,79 +431,97 @@ async def run_async( if new_message and not new_message.role: new_message.role = 'user' - async def _run_with_trace( + async def _run_body( new_message: Optional[types.Content] = None, invocation_id: Optional[str] = None, ) -> AsyncGenerator[Event, None]: - with tracer.start_as_current_span('invocation'): - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + message = self._format_session_not_found_message(session_id) + raise ValueError(message) + if not invocation_id and not new_message: + raise ValueError( + 'Running an agent requires either a new_message or an ' + 'invocation_id to resume a previous invocation. ' + f'Session: {session_id}, User: {user_id}' ) - if not session: - message = self._format_session_not_found_message(session_id) - raise ValueError(message) - if not invocation_id and not new_message: + + if invocation_id: + if ( + not self.resumability_config + or not self.resumability_config.is_resumable + ): raise ValueError( - 'Running an agent requires either a new_message or an ' - 'invocation_id to resume a previous invocation. ' - f'Session: {session_id}, User: {user_id}' + f'invocation_id: {invocation_id} is provided but the app is not' + ' resumable.' ) + invocation_context = await self._setup_context_for_resumed_invocation( + session=session, + new_message=new_message, + invocation_id=invocation_id, + run_config=run_config, + state_delta=state_delta, + ) + if invocation_context.end_of_agents.get(invocation_context.agent.name): + # Directly return if the current agent in invocation context is + # already final. + return + else: + invocation_context = await self._setup_context_for_new_invocation( + session=session, + new_message=new_message, # new_message is not None. + run_config=run_config, + state_delta=state_delta, + ) - if invocation_id: - if ( - not self.resumability_config - or not self.resumability_config.is_resumable - ): - raise ValueError( - f'invocation_id: {invocation_id} is provided but the app is not' - ' resumable.' - ) - invocation_context = await self._setup_context_for_resumed_invocation( - session=session, - new_message=new_message, - invocation_id=invocation_id, - run_config=run_config, - state_delta=state_delta, - ) - if invocation_context.end_of_agents.get( - invocation_context.agent.name - ): - # Directly return if the current agent in invocation context is - # already final. - return - else: - invocation_context = await self._setup_context_for_new_invocation( + async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + async with Aclosing(ctx.agent.run_async(ctx)) as agen: + async for event in agen: + yield event + + async with Aclosing( + self._exec_with_plugin( + invocation_context=invocation_context, session=session, - new_message=new_message, # new_message is not None. - run_config=run_config, - state_delta=state_delta, + execute_fn=execute, + is_live_call=False, ) + ) as agen: + async for event in agen: + yield event + # Run compaction after all events are yielded from the agent. + # (We don't compact in the middle of an invocation, we only compact at + # the end of an invocation.) + if self.app and self.app.events_compaction_config: + logger.debug('Running event compactor.') + await _run_compaction_for_sliding_window( + self.app, session, self.session_service + ) - async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async with Aclosing(ctx.agent.run_async(ctx)) as agen: - async for event in agen: - yield event - + async def _run_with_optional_trace( + agent: BaseAgent, + new_message: Optional[types.Content] = None, + invocation_id: Optional[str] = None, + ) -> AsyncGenerator[Event, None]: + if is_telemetry_enabled(agent): + with tracer.start_as_current_span('invocation'): + async with Aclosing( + _run_body(new_message=new_message, invocation_id=invocation_id) + ) as agen: + async for e in agen: + yield e + else: async with Aclosing( - self._exec_with_plugin( - invocation_context=invocation_context, - session=session, - execute_fn=execute, - is_live_call=False, - ) + _run_body(new_message=new_message, invocation_id=invocation_id) ) as agen: - async for event in agen: - yield event - # Run compaction after all events are yielded from the agent. - # (We don't compact in the middle of an invocation, we only compact at - # the end of an invocation.) - if self.app and self.app.events_compaction_config: - logger.debug('Running event compactor.') - await _run_compaction_for_sliding_window( - self.app, session, self.session_service - ) + async for e in agen: + yield e - async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: + async with Aclosing( + _run_with_optional_trace(self.agent, new_message, invocation_id) + ) as agen: async for event in agen: yield event diff --git a/src/google/adk/utils/telemetry_utils.py b/src/google/adk/utils/telemetry_utils.py new file mode 100644 index 0000000000..fb5b3e8594 --- /dev/null +++ b/src/google/adk/utils/telemetry_utils.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for telemetry. + +This module is for ADK internal use only. +Please do not rely on the implementation details. +""" + +from typing import TYPE_CHECKING + +from .env_utils import is_env_enabled + +if TYPE_CHECKING: + from ..agents.base_agent import BaseAgent + + +def is_telemetry_enabled(agent: "BaseAgent") -> bool: + """Check if telemetry is enabled for the given agent. + + By default telemetry is enabled for an agent unless any of the variables to disable telemetry are set to true. + + Args: + agent: The agent to check if telemetry is enabled for. + + Returns: + False if any of the environment variables or attributes to disable telemetryare set to True, 'true' or 1, False otherwise. + + Examples: + >>> os.environ['OTEL_SDK_DISABLED'] = 'true' + >>> is_telemetry_enabled(my_agent) + True + + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 1 + >>> is_telemetry_enabled(my_agent) + True + + >>> my_agent.disable_telemetry = True + >>> is_telemetry_enabled(my_agent) + True + + >>> os.environ['OTEL_SDK_DISABLED'] = 1 + >>> os.environ['ADK_TELEMETRY_DISABLED'] = 'false' + >>> my_agent.disable_telemetry = False + >>> is_telemetry_enabled(my_agent) + False + """ + telemetry_disabled = ( + is_env_enabled("OTEL_SDK_DISABLED") + or is_env_enabled("ADK_TELEMETRY_DISABLED") + or getattr(agent, "disable_telemetry", False) + ) + return not telemetry_disabled diff --git a/tests/integration/telemetry/__init__.py b/tests/integration/telemetry/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/integration/telemetry/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/telemetry/test_telemetry_disable.py b/tests/integration/telemetry/test_telemetry_disable.py new file mode 100644 index 0000000000..d4ed5704dc --- /dev/null +++ b/tests/integration/telemetry/test_telemetry_disable.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.llm_agent import Agent +from google.adk.telemetry import tracing +from google.adk.utils.context_utils import Aclosing +from google.genai.types import Part +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +import pytest + +from tests.unittests.testing_utils import MockModel +from tests.unittests.testing_utils import TestInMemoryRunner + + +@pytest.fixture +def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter: + tracer_provider = TracerProvider() + exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = tracer_provider.get_tracer(__name__) + + monkeypatch.setattr( + tracing.tracer, + "start_as_current_span", + real_tracer.start_as_current_span, + ) + return exporter + + +@pytest.mark.asyncio +async def test_telemetry_enabled_records_spans(monkeypatch, span_exporter): + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert spans + + +@pytest.mark.asyncio +async def test_adk_telemetry_disabled_env_var_disables( + monkeypatch, span_exporter +): + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "true") + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans + + +@pytest.mark.asyncio +async def test_otel_sdk_env_var_disables_telemetry(monkeypatch, span_exporter): + monkeypatch.setenv("OTEL_SDK_DISABLED", "true") + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans + + +@pytest.mark.asyncio +async def test_agent_flag_disables_telemetry(monkeypatch, span_exporter): + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + + agent = Agent( + name="test_agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=True, + ) + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + spans = span_exporter.get_finished_spans() + assert not spans diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 59b66bd622..3d46b8a76c 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os +from unittest import mock from pytest import fixture from pytest import FixtureRequest @@ -101,3 +103,14 @@ def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool: if mark.name == 'parametrize' and mark.args[0] == mark_name: return True return False + + +@fixture +def context_manager_with_span(span=None): + span = span or mock.MagicMock() + + @contextlib.contextmanager + def _cm(): + yield span + + return _cm() diff --git a/tests/unittests/telemetry/test_telemetry_disable_agent.py b/tests/unittests/telemetry/test_telemetry_disable_agent.py new file mode 100644 index 0000000000..5e8e87b89f --- /dev/null +++ b/tests/unittests/telemetry/test_telemetry_disable_agent.py @@ -0,0 +1,119 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.agents.llm_agent import Agent +from google.adk.telemetry import tracing +from google.adk.utils.context_utils import Aclosing +from google.genai.types import Part +import pytest + +from ..testing_utils import MockModel +from ..testing_utils import TestInMemoryRunner + + +@pytest.mark.asyncio +async def test_disable_telemetry_prevents_span_creation(monkeypatch): + monkeypatch.delenv("OTEL_SDK_DISABLED", raising=False) + monkeypatch.delenv("ADK_TELEMETRY_DISABLED", raising=False) + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=True, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count == 0 + + +@pytest.mark.asyncio +async def test_enabled_telemetry_causes_span_creation(monkeypatch): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "env_var,env_value", + [ + ("OTEL_SDK_DISABLED", "true"), + ("OTEL_SDK_DISABLED", "1"), + ("ADK_TELEMETRY_DISABLED", "true"), + ("ADK_TELEMETRY_DISABLED", "1"), + ], +) +async def test_env_flag_disables_telemetry(monkeypatch, env_var, env_value): + monkeypatch.setenv(env_var, env_value) + monkeypatch.delenv( + "ADK_TELEMETRY_DISABLED" + if env_var == "OTEL_SDK_DISABLED" + else "OTEL_SDK_DISABLED", + raising=False, + ) + span = mock.MagicMock() + context_manager = mock.MagicMock() + context_manager.__enter__.return_value = span + context_manager.__exit__.return_value = False + + mock_start = mock.Mock(return_value=context_manager) + monkeypatch.setattr(tracing.tracer, "start_as_current_span", mock_start) + + agent = Agent( + name="agent", + model=MockModel.create(responses=[Part.from_text(text="ok")]), + disable_telemetry=False, + ) + + runner = TestInMemoryRunner(agent) + + async with Aclosing(runner.run_async_with_new_session_agen("")) as agen: + async for _ in agen: + pass + + assert mock_start.call_count == 0 diff --git a/tests/unittests/telemetry/test_telemetry_disable_google_llm.py b/tests/unittests/telemetry/test_telemetry_disable_google_llm.py new file mode 100644 index 0000000000..f63cbdaaf4 --- /dev/null +++ b/tests/unittests/telemetry/test_telemetry_disable_google_llm.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.models import gemini_context_cache_manager as cache_mod +from google.adk.models import llm_response as llm_response_mod +from google.adk.models.google_llm import Gemini +import pytest + + +@pytest.mark.asyncio +async def test_disable_google_llm_telemetry( + monkeypatch, context_manager_with_span +): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + start_span = mock.Mock(return_value=context_manager_with_span) + monkeypatch.setattr( + "google.adk.telemetry.tracing.tracer.start_as_current_span", + start_span, + ) + + gemini = Gemini(disable_telemetry=True) + + # Avoid real Client construction + fake_client = mock.MagicMock() + fake_client.vertexai = False + fake_client.aio.models.generate_content = mock.AsyncMock( + return_value=mock.MagicMock() + ) + gemini.__dict__["api_client"] = fake_client + + # Prevent cache validation code running (the bit that touches expire_time) + monkeypatch.setattr( + cache_mod.GeminiContextCacheManager, + "handle_context_caching", + mock.AsyncMock(return_value=None), + ) + + req = mock.MagicMock() + req.cache_config = object() # force the cache path + req.model = "gemini-2.5-flash" + req.contents = [] + req.config = mock.MagicMock() + req.config.tools = None + req.config.system_instruction = "" + req.config.model_dump = mock.Mock(return_value={}) + req.config.http_options = None + + monkeypatch.setattr( + llm_response_mod.LlmResponse, + "create", + mock.Mock(return_value=mock.MagicMock()), + ) + + async for _ in gemini.generate_content_async(req, stream=False): + break + + assert start_span.call_count == 0 + + +@pytest.mark.asyncio +async def test_enable_google_llm_telemetry( + monkeypatch, context_manager_with_span +): + monkeypatch.setenv("OTEL_SDK_DISABLED", "false") + monkeypatch.setenv("ADK_TELEMETRY_DISABLED", "false") + start_span = mock.Mock(return_value=context_manager_with_span) + monkeypatch.setattr( + "google.adk.telemetry.tracing.tracer.start_as_current_span", + start_span, + ) + + gemini = Gemini(disable_telemetry=False) + + # Avoid real Client construction + fake_client = mock.MagicMock() + fake_client.vertexai = False + fake_client.aio.models.generate_content = mock.AsyncMock( + return_value=mock.MagicMock() + ) + gemini.__dict__["api_client"] = fake_client + + # Prevent cache validation code running (the bit that touches expire_time) + monkeypatch.setattr( + cache_mod.GeminiContextCacheManager, + "handle_context_caching", + mock.AsyncMock(return_value=None), + ) + + req = mock.MagicMock() + req.cache_config = object() # force the cache path + req.model = "gemini-2.5-flash" + req.contents = [] + req.config = mock.MagicMock() + req.config.tools = None + req.config.system_instruction = "" + req.config.model_dump = mock.Mock(return_value={}) + req.config.http_options = None + + monkeypatch.setattr( + llm_response_mod.LlmResponse, + "create", + mock.Mock(return_value=mock.MagicMock()), + ) + + async for _ in gemini.generate_content_async(req, stream=False): + break + + assert start_span.call_count > 0