diff --git a/posthog/ai/gemini/__init__.py b/posthog/ai/gemini/__init__.py index eb17989d..fa62e073 100644 --- a/posthog/ai/gemini/__init__.py +++ b/posthog/ai/gemini/__init__.py @@ -1,4 +1,5 @@ from .gemini import Client +from .gemini_async import AsyncClient from .gemini_converter import ( format_gemini_input, format_gemini_response, @@ -15,6 +16,7 @@ class _GenAI: __all__ = [ "Client", + "AsyncClient", "genai", "format_gemini_input", "format_gemini_response", diff --git a/posthog/ai/gemini/gemini_async.py b/posthog/ai/gemini/gemini_async.py new file mode 100644 index 00000000..783f7dca --- /dev/null +++ b/posthog/ai/gemini/gemini_async.py @@ -0,0 +1,491 @@ +import os +import time +import uuid +from typing import Any, Dict, Optional + +from posthog.ai.types import TokenUsage + +try: + from google import genai +except ImportError: + raise ModuleNotFoundError( + "Please install the Google Gemini SDK to use this feature: 'pip install google-genai'" + ) + +from posthog import setup +from posthog.ai.utils import ( + call_llm_and_track_usage_async, + capture_streaming_event, + merge_usage_stats, +) +from posthog.ai.gemini.gemini_converter import ( + format_gemini_input, + extract_gemini_usage_from_chunk, + extract_gemini_content_from_chunk, + format_gemini_streaming_output, +) +from posthog.ai.sanitization import sanitize_gemini +from posthog.client import Client as PostHogClient + + +class AsyncClient: + """ + An async drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog. + + Usage: + client = AsyncClient( + api_key="your_api_key", + posthog_client=posthog_client, + posthog_distinct_id="default_user", # Optional defaults + posthog_properties={"team": "ai"} # Optional defaults + ) + response = await client.aio.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello world"], + posthog_distinct_id="specific_user" # Override default + ) + """ + + _ph_client: PostHogClient + + def __init__( + self, + api_key: Optional[str] = None, + vertexai: Optional[bool] = None, + credentials: Optional[Any] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[Any] = None, + http_options: Optional[Any] = None, + posthog_client: Optional[PostHogClient] = None, + posthog_distinct_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: bool = False, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Args: + api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI) + vertexai: Whether to use Vertex AI authentication + credentials: Vertex AI credentials object + project: GCP project ID for Vertex AI + location: GCP location for Vertex AI + debug_config: Debug configuration for the client + http_options: HTTP options for the client + posthog_client: PostHog client for tracking usage + posthog_distinct_id: Default distinct ID for all calls (can be overridden per call) + posthog_properties: Default properties for all calls (can be overridden per call) + posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call) + posthog_groups: Default groups for all calls (can be overridden per call) + **kwargs: Additional arguments (for future compatibility) + """ + + self._ph_client = posthog_client or setup() + + if self._ph_client is None: + raise ValueError("posthog_client is required for PostHog tracking") + + self.aio = AsyncAio( + api_key=api_key, + vertexai=vertexai, + credentials=credentials, + project=project, + location=location, + debug_config=debug_config, + http_options=http_options, + posthog_client=self._ph_client, + posthog_distinct_id=posthog_distinct_id, + posthog_properties=posthog_properties, + posthog_privacy_mode=posthog_privacy_mode, + posthog_groups=posthog_groups, + **kwargs, + ) + + +class AsyncAio: + """ + Async interface that mimics genai.Client().aio with PostHog tracking. + """ + + _ph_client: PostHogClient # Not None after __init__ validation + + def __init__( + self, + api_key: Optional[str] = None, + vertexai: Optional[bool] = None, + credentials: Optional[Any] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[Any] = None, + http_options: Optional[Any] = None, + posthog_client: Optional[PostHogClient] = None, + posthog_distinct_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: bool = False, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Args: + api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI) + vertexai: Whether to use Vertex AI authentication + credentials: Vertex AI credentials object + project: GCP project ID for Vertex AI + location: GCP location for Vertex AI + debug_config: Debug configuration for the client + http_options: HTTP options for the client + posthog_client: PostHog client for tracking usage + posthog_distinct_id: Default distinct ID for all calls + posthog_properties: Default properties for all calls + posthog_privacy_mode: Default privacy mode for all calls + posthog_groups: Default groups for all calls + **kwargs: Additional arguments (for future compatibility) + """ + + self._ph_client = posthog_client or setup() + + if self._ph_client is None: + raise ValueError("posthog_client is required for PostHog tracking") + + self.models = AsyncModels( + api_key=api_key, + vertexai=vertexai, + credentials=credentials, + project=project, + location=location, + debug_config=debug_config, + http_options=http_options, + posthog_client=self._ph_client, + posthog_distinct_id=posthog_distinct_id, + posthog_properties=posthog_properties, + posthog_privacy_mode=posthog_privacy_mode, + posthog_groups=posthog_groups, + **kwargs, + ) + + +class AsyncModels: + """ + Async Models interface that mimics genai.Client().aio.models with PostHog tracking. + """ + + _ph_client: PostHogClient # Not None after __init__ validation + + def __init__( + self, + api_key: Optional[str] = None, + vertexai: Optional[bool] = None, + credentials: Optional[Any] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[Any] = None, + http_options: Optional[Any] = None, + posthog_client: Optional[PostHogClient] = None, + posthog_distinct_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: bool = False, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Args: + api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI) + vertexai: Whether to use Vertex AI authentication + credentials: Vertex AI credentials object + project: GCP project ID for Vertex AI + location: GCP location for Vertex AI + debug_config: Debug configuration for the client + http_options: HTTP options for the client + posthog_client: PostHog client for tracking usage + posthog_distinct_id: Default distinct ID for all calls + posthog_properties: Default properties for all calls + posthog_privacy_mode: Default privacy mode for all calls + posthog_groups: Default groups for all calls + **kwargs: Additional arguments (for future compatibility) + """ + + self._ph_client = posthog_client or setup() + + if self._ph_client is None: + raise ValueError("posthog_client is required for PostHog tracking") + + # Store default PostHog settings + self._default_distinct_id = posthog_distinct_id + self._default_properties = posthog_properties or {} + self._default_privacy_mode = posthog_privacy_mode + self._default_groups = posthog_groups + + # Build genai.Client arguments + client_args: Dict[str, Any] = {} + + # Add Vertex AI parameters if provided + if vertexai is not None: + client_args["vertexai"] = vertexai + + if credentials is not None: + client_args["credentials"] = credentials + + if project is not None: + client_args["project"] = project + + if location is not None: + client_args["location"] = location + + if debug_config is not None: + client_args["debug_config"] = debug_config + + if http_options is not None: + client_args["http_options"] = http_options + + # Handle API key authentication + if vertexai: + # For Vertex AI, api_key is optional + if api_key is not None: + client_args["api_key"] = api_key + else: + # For non-Vertex AI mode, api_key is required (backwards compatibility) + if api_key is None: + api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY") + + if api_key is None: + raise ValueError( + "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable" + ) + + client_args["api_key"] = api_key + + self._client = genai.Client(**client_args) + self._base_url = "https://generativelanguage.googleapis.com" + + def _merge_posthog_params( + self, + call_distinct_id: Optional[str], + call_trace_id: Optional[str], + call_properties: Optional[Dict[str, Any]], + call_privacy_mode: Optional[bool], + call_groups: Optional[Dict[str, Any]], + ): + """Merge call-level PostHog parameters with client defaults.""" + + # Use call-level values if provided, otherwise fall back to defaults + distinct_id = ( + call_distinct_id + if call_distinct_id is not None + else self._default_distinct_id + ) + privacy_mode = ( + call_privacy_mode + if call_privacy_mode is not None + else self._default_privacy_mode + ) + groups = call_groups if call_groups is not None else self._default_groups + + # Merge properties: default properties + call properties (call properties override) + properties = dict(self._default_properties) + + if call_properties: + properties.update(call_properties) + + if call_trace_id is None: + call_trace_id = str(uuid.uuid4()) + + return distinct_id, call_trace_id, properties, privacy_mode, groups + + async def generate_content( + self, + model: str, + contents, + posthog_distinct_id: Optional[str] = None, + posthog_trace_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: Optional[bool] = None, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + """ + Generate content using Gemini's async API while tracking usage in PostHog. + + This method signature exactly matches genai.Client().aio.models.generate_content() + with additional PostHog tracking parameters. + + Args: + model: The model to use (e.g., 'gemini-2.0-flash') + contents: The input content for generation + posthog_distinct_id: ID to associate with the usage event (overrides client default) + posthog_trace_id: Trace UUID for linking events (auto-generated if not provided) + posthog_properties: Extra properties to include in the event (merged with client defaults) + posthog_privacy_mode: Whether to redact sensitive information (overrides client default) + posthog_groups: Group analytics properties (overrides client default) + **kwargs: Arguments passed to Gemini's generate_content + """ + + # Merge PostHog parameters + distinct_id, trace_id, properties, privacy_mode, groups = ( + self._merge_posthog_params( + posthog_distinct_id, + posthog_trace_id, + posthog_properties, + posthog_privacy_mode, + posthog_groups, + ) + ) + + kwargs_with_contents = {"model": model, "contents": contents, **kwargs} + + return await call_llm_and_track_usage_async( + distinct_id, + self._ph_client, + "gemini", + trace_id, + properties, + privacy_mode, + groups, + self._base_url, + self._client.aio.models.generate_content, + **kwargs_with_contents, + ) + + def _generate_content_streaming( + self, + model: str, + contents, + distinct_id: Optional[str], + trace_id: Optional[str], + properties: Optional[Dict[str, Any]], + privacy_mode: bool, + groups: Optional[Dict[str, Any]], + **kwargs: Any, + ): + """ + Factory function that returns an async generator for streaming content. + + Note: This method is intentionally NOT async - it returns an async generator + that callers can iterate with `async for`. + """ + start_time = time.time() + usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0) + accumulated_content = [] + + async def async_generator(): + nonlocal usage_stats + nonlocal accumulated_content # noqa: F824 + + kwargs_without_stream = {"model": model, "contents": contents, **kwargs} + response = await self._client.aio.models.generate_content_stream( + **kwargs_without_stream + ) + + try: + async for chunk in response: + # Extract usage stats from chunk + chunk_usage = extract_gemini_usage_from_chunk(chunk) + + if chunk_usage: + # Gemini reports cumulative totals, not incremental values + merge_usage_stats(usage_stats, chunk_usage, mode="cumulative") + + # Extract content from chunk (now returns content blocks) + content_block = extract_gemini_content_from_chunk(chunk) + + if content_block is not None: + accumulated_content.append(content_block) + + yield chunk + + finally: + end_time = time.time() + latency = end_time - start_time + + await self._capture_streaming_event( + model, + contents, + distinct_id, + trace_id, + properties, + privacy_mode, + groups, + kwargs, + usage_stats, + latency, + accumulated_content, + ) + + return async_generator() + + async def _capture_streaming_event( + self, + model: str, + contents, + distinct_id: Optional[str], + trace_id: Optional[str], + properties: Optional[Dict[str, Any]], + privacy_mode: bool, + groups: Optional[Dict[str, Any]], + kwargs: Dict[str, Any], + usage_stats: TokenUsage, + latency: float, + output: Any, + ): + from posthog.ai.types import StreamingEventData + + # Prepare standardized event data + formatted_input = self._format_input(contents) + sanitized_input = sanitize_gemini(formatted_input) + + event_data = StreamingEventData( + provider="gemini", + model=model, + base_url=self._base_url, + kwargs=kwargs, + formatted_input=sanitized_input, + formatted_output=format_gemini_streaming_output(output), + usage_stats=usage_stats, + latency=latency, + distinct_id=distinct_id, + trace_id=trace_id, + properties=properties, + privacy_mode=privacy_mode, + groups=groups, + ) + + # Use the common capture function + capture_streaming_event(self._ph_client, event_data) + + def _format_input(self, contents): + """Format input contents for PostHog tracking""" + + return format_gemini_input(contents) + + def generate_content_stream( + self, + model: str, + contents, + posthog_distinct_id: Optional[str] = None, + posthog_trace_id: Optional[str] = None, + posthog_properties: Optional[Dict[str, Any]] = None, + posthog_privacy_mode: Optional[bool] = None, + posthog_groups: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + # Merge PostHog parameters + distinct_id, trace_id, properties, privacy_mode, groups = ( + self._merge_posthog_params( + posthog_distinct_id, + posthog_trace_id, + posthog_properties, + posthog_privacy_mode, + posthog_groups, + ) + ) + + return self._generate_content_streaming( + model, + contents, + distinct_id, + trace_id, + properties, + privacy_mode, + groups, + **kwargs, + ) diff --git a/posthog/test/ai/gemini/test_gemini_async.py b/posthog/test/ai/gemini/test_gemini_async.py new file mode 100644 index 00000000..3e652cb1 --- /dev/null +++ b/posthog/test/ai/gemini/test_gemini_async.py @@ -0,0 +1,383 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from posthog.ai.gemini.gemini_async import AsyncClient, AsyncAio +from posthog.client import Client as PostHogClient + + +@pytest.fixture +def mock_posthog_client(): + """Mock PostHog client for testing.""" + client = MagicMock(spec=PostHogClient) + client.capture = MagicMock() + client.privacy_mode = False # Add privacy_mode attribute + return client + + +@pytest.fixture +def mock_genai_client(): + """Mock the underlying genai.Client.""" + with patch("posthog.ai.gemini.gemini_async.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock the aio.models interface + mock_client.aio = MagicMock() + mock_client.aio.models = MagicMock() + mock_client.aio.models.generate_content = AsyncMock() + mock_client.aio.models.generate_content_stream = AsyncMock() + + yield mock_client + + +@pytest.fixture +def async_client(mock_posthog_client, mock_genai_client): + """Create an AsyncClient instance for testing.""" + return AsyncClient( + api_key="test-api-key", + posthog_client=mock_posthog_client, + posthog_distinct_id="test-user", + posthog_properties={"test": "property"}, + ) + + +@pytest.fixture +def mock_gemini_functions(): + """Mock all Gemini-related functions for streaming tests.""" + with ( + patch("posthog.ai.gemini.gemini_async.capture_streaming_event") as mock_capture, + patch("posthog.ai.gemini.gemini_async.sanitize_gemini") as mock_sanitize, + patch( + "posthog.ai.gemini.gemini_async.format_gemini_input" + ) as mock_format_input, + patch( + "posthog.ai.gemini.gemini_async.format_gemini_streaming_output" + ) as mock_format_output, + ): + mock_format_input.return_value = "formatted input" + mock_sanitize.return_value = "sanitized input" + mock_format_output.return_value = "formatted output" + + yield { + "capture": mock_capture, + "sanitize": mock_sanitize, + "format_input": mock_format_input, + "format_output": mock_format_output, + } + + +class TestAsyncClient: + """Test the AsyncClient class.""" + + def test_init_with_api_key(self, mock_posthog_client): + """Test AsyncClient initialization with API key.""" + with patch("posthog.ai.gemini.gemini_async.genai.Client") as mock_client_class: + client = AsyncClient(api_key="test-key", posthog_client=mock_posthog_client) + + assert client._ph_client == mock_posthog_client + assert isinstance(client.aio, AsyncAio) + mock_client_class.assert_called_once() + + def test_init_without_api_key_raises_error(self, mock_posthog_client): + """Test that AsyncClient raises error when no API key is provided.""" + with patch.dict("os.environ", {}, clear=True): + with pytest.raises(ValueError, match="API key must be provided"): + AsyncClient(posthog_client=mock_posthog_client) + + def test_init_with_vertex_ai(self, mock_posthog_client): + """Test AsyncClient initialization with Vertex AI.""" + with patch("posthog.ai.gemini.gemini_async.genai.Client") as mock_client_class: + client = AsyncClient( + vertexai=True, + project="test-project", + location="us-central1", + posthog_client=mock_posthog_client, + ) + + assert client._ph_client == mock_posthog_client + mock_client_class.assert_called_once_with( + vertexai=True, project="test-project", location="us-central1" + ) + + def test_init_without_posthog_client_raises_error(self): + """Test that AsyncClient raises error when PostHog client is None.""" + with patch("posthog.ai.gemini.gemini_async.setup", return_value=None): + with pytest.raises(ValueError, match="posthog_client is required"): + AsyncClient(api_key="test-key") + + +class TestAsyncModels: + """Test the AsyncModels class.""" + + @pytest.mark.asyncio + async def test_generate_content_basic(self, async_client, mock_genai_client): + """Test basic async content generation.""" + # Mock response + mock_response = MagicMock() + mock_response.text = "Generated content" + mock_genai_client.aio.models.generate_content.return_value = mock_response + + # Mock the async tracking function + with patch( + "posthog.ai.gemini.gemini_async.call_llm_and_track_usage_async" + ) as mock_track: + mock_track.return_value = mock_response + + response = await async_client.aio.models.generate_content( + model="gemini-2.0-flash", contents=["Hello world"] + ) + + assert response == mock_response + mock_track.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_content_with_posthog_params( + self, async_client, mock_genai_client + ): + """Test async content generation with PostHog parameters.""" + mock_response = MagicMock() + mock_response.text = "Generated content" + + with patch( + "posthog.ai.gemini.gemini_async.call_llm_and_track_usage_async" + ) as mock_track: + mock_track.return_value = mock_response + + response = await async_client.aio.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello world"], + posthog_distinct_id="custom-user", + posthog_trace_id="custom-trace", + posthog_properties={"custom": "property"}, + posthog_privacy_mode=True, + posthog_groups={"team": "ai"}, + ) + + assert response == mock_response + mock_track.assert_called_once() + + # Verify the call arguments + call_args = mock_track.call_args + assert call_args[0][0] == "custom-user" # distinct_id + assert call_args[0][2] == "gemini" # provider + assert call_args[0][3] == "custom-trace" # trace_id + assert call_args[0][4]["custom"] == "property" # properties + assert call_args[0][5] is True # privacy_mode + assert call_args[0][6] == {"team": "ai"} # groups + + @pytest.mark.asyncio + async def test_generate_content_stream_basic(self, async_client, mock_genai_client): + """Test basic async streaming content generation.""" + + # Create a proper async generator mock + class AsyncGeneratorMock: + def __init__(self, items): + self.items = items + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.items): + raise StopAsyncIteration + item = self.items[self.index] + self.index += 1 + return item + + # Mock streaming response chunks with proper usage_metadata + chunk1 = MagicMock(text="chunk1") + chunk1.usage_metadata = MagicMock() + chunk1.usage_metadata.prompt_token_count = 5 + chunk1.usage_metadata.candidates_token_count = 10 + chunk1.usage_metadata.cached_content_token_count = 0 + chunk1.usage_metadata.thoughts_token_count = 0 + + chunk2 = MagicMock(text="chunk2") + chunk2.usage_metadata = MagicMock() + chunk2.usage_metadata.prompt_token_count = 5 + chunk2.usage_metadata.candidates_token_count = 15 + chunk2.usage_metadata.cached_content_token_count = 0 + chunk2.usage_metadata.thoughts_token_count = 0 + + mock_stream = AsyncGeneratorMock([chunk1, chunk2]) + + # Mock the underlying streaming call to return the async generator + mock_genai_client.aio.models.generate_content_stream.return_value = mock_stream + + # Mock the capture_streaming_event to avoid the privacy_mode issue + with patch("posthog.ai.gemini.gemini_async.capture_streaming_event"): + response_generator = async_client.aio.models.generate_content_stream( + model="gemini-2.0-flash", contents=["Hello world"] + ) + + # Collect all chunks + chunks = [] + async for chunk in response_generator: + chunks.append(chunk) + + assert len(chunks) == 2 + assert chunks[0].text == "chunk1" + assert chunks[1].text == "chunk2" + + @pytest.mark.asyncio + async def test_merge_posthog_params(self, async_client): + """Test parameter merging logic.""" + models = async_client.aio.models + + # Test with call-level parameters overriding defaults + distinct_id, trace_id, properties, privacy_mode, groups = ( + models._merge_posthog_params( + call_distinct_id="call-user", + call_trace_id="call-trace", + call_properties={"call": "prop"}, + call_privacy_mode=True, + call_groups={"call": "group"}, + ) + ) + + assert distinct_id == "call-user" + assert trace_id == "call-trace" + assert properties["test"] == "property" # Default property + assert properties["call"] == "prop" # Call property + assert privacy_mode is True + assert groups == {"call": "group"} + + @pytest.mark.asyncio + async def test_merge_posthog_params_defaults(self, async_client): + """Test parameter merging with defaults.""" + models = async_client.aio.models + + # Test with None values falling back to defaults + distinct_id, trace_id, properties, privacy_mode, groups = ( + models._merge_posthog_params( + call_distinct_id=None, + call_trace_id=None, + call_properties=None, + call_privacy_mode=None, + call_groups=None, + ) + ) + + assert distinct_id == "test-user" # Default from client + assert trace_id is not None # Auto-generated UUID + assert properties == {"test": "property"} # Default properties + assert privacy_mode is False # Default privacy mode + assert groups is None # Default groups + + def test_format_input(self, async_client): + """Test input formatting.""" + models = async_client.aio.models + + with patch("posthog.ai.gemini.gemini_async.format_gemini_input") as mock_format: + mock_format.return_value = "formatted input" + + result = models._format_input(["test content"]) + + assert result == "formatted input" + mock_format.assert_called_once_with(["test content"]) + + @pytest.mark.asyncio + async def test_capture_streaming_event( + self, async_client, mock_posthog_client, mock_gemini_functions + ): + """Test streaming event capture.""" + models = async_client.aio.models + + from posthog.ai.types import TokenUsage + + usage_stats = TokenUsage(input_tokens=10, output_tokens=20) + + await models._capture_streaming_event( + model="gemini-2.0-flash", + contents=["test"], + distinct_id="test-user", + trace_id="test-trace", + properties={"test": "prop"}, + privacy_mode=False, + groups={"team": "ai"}, + kwargs={"temperature": 0.7}, + usage_stats=usage_stats, + latency=1.5, + output=["output"], + ) + + # Verify the capture function was called + mock_gemini_functions["capture"].assert_called_once() + # Verify the call was made with the PostHog client and event data + call_args = mock_gemini_functions["capture"].call_args + assert call_args[0][0] == mock_posthog_client # First arg is the client + # The second arg is the event data object - just verify it exists + assert call_args[0][1] is not None + + +class TestAsyncIntegration: + """Integration tests for the async Gemini client.""" + + @pytest.mark.asyncio + async def test_full_async_workflow(self, mock_posthog_client): + """Test a complete async workflow.""" + with patch("posthog.ai.gemini.gemini_async.genai.Client") as mock_client_class: + # Setup mock client + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock async response + mock_response = MagicMock() + mock_response.text = "Hello! How can I help you?" + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response + ) + + # Mock tracking function + with patch( + "posthog.ai.gemini.gemini_async.call_llm_and_track_usage_async" + ) as mock_track: + mock_track.return_value = mock_response + + # Create client and make request + client = AsyncClient( + api_key="test-key", + posthog_client=mock_posthog_client, + posthog_distinct_id="integration-test-user", + ) + + response = await client.aio.models.generate_content( + model="gemini-2.0-flash", + contents=["Hello world"], + posthog_properties={"test_type": "integration"}, + ) + + # Verify response + assert response.text == "Hello! How can I help you?" + + # Verify tracking was called + mock_track.assert_called_once() + call_args = mock_track.call_args + assert call_args[0][0] == "integration-test-user" # distinct_id + assert call_args[0][2] == "gemini" # provider + + @pytest.mark.asyncio + async def test_error_handling(self, mock_posthog_client): + """Test error handling in async operations.""" + with patch("posthog.ai.gemini.gemini_async.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Mock an exception + mock_client.aio.models.generate_content = AsyncMock( + side_effect=Exception("API Error") + ) + + with patch( + "posthog.ai.gemini.gemini_async.call_llm_and_track_usage_async" + ) as mock_track: + mock_track.side_effect = Exception("API Error") + + client = AsyncClient( + api_key="test-key", posthog_client=mock_posthog_client + ) + + with pytest.raises(Exception, match="API Error"): + await client.aio.models.generate_content( + model="gemini-2.0-flash", contents=["Hello world"] + )