diff --git a/pyproject.toml b/pyproject.toml
index 3428f333..8c18bcd6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "typer>=0.15.2",
     "litellm==1.74.1",
     "weave>=0.51.51",
+    "azure-identity>=1.24.0",
 ]
 
 [project.optional-dependencies]
diff --git a/src/art/mcp/generate_scenarios.py b/src/art/mcp/generate_scenarios.py
index df92ea3c..3cc3711b 100644
--- a/src/art/mcp/generate_scenarios.py
+++ b/src/art/mcp/generate_scenarios.py
@@ -4,11 +4,14 @@
 import time
 from typing import Any, Dict, List, Optional
 
-import openai
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+import litellm
 
 from art.mcp.types import GeneratedScenarioCollection, MCPResource, MCPTool
-from art.utils.logging import _C, dim, err, info, ok, step
+from art.utils.logging import _C, dim, err, info, ok, step, warn
 
+# Enable automatic fallback to DefaultAzureCredential
+litellm.enable_azure_ad_token_refresh = True
 
 def preview_scenarios(scenarios: List[Dict[str, Any]], n: int = 5):
     """Preview generated scenarios."""
@@ -32,7 +35,7 @@ async def generate_scenarios(
     custom_instructions: Optional[str] = None,
     generator_model: str = "openai/gpt-4.1-mini",
     generator_api_key: Optional[str] = None,
-    generator_base_url: str = "https://openrouter.ai/api/v1",
+    generator_base_url: str = "",
 ) -> GeneratedScenarioCollection:
     """
     Generate scenarios for MCP tools.
@@ -43,9 +46,9 @@ async def generate_scenarios(
         num_scenarios: Number of scenarios to generate (default: 24)
         show_preview: Whether to show a preview of generated scenarios (default: True)
         custom_instructions: Optional custom instructions for scenario generation
-        generator_model: Model to use for generation (default: "openai/gpt-4.1-mini")
-        generator_api_key: API key for the generator model. If None, will use OPENROUTER_API_KEY env var
-        generator_base_url: Base URL for the API (default: OpenRouter)
+        generator_model: LiteLLM model to use for generation (default: "openai/gpt-4.1-mini")
+        generator_api_key: API key for the generator model. If None, the environment variable associated with the LiteLLM model will be used.
+        generator_base_url: Base URL for the API
 
     Returns:
         GeneratedScenarioCollection containing the generated scenarios
@@ -54,13 +57,16 @@ async def generate_scenarios(
 
     t0 = time.perf_counter()
 
-    # Handle API key
-    if generator_api_key is None:
-        generator_api_key = os.getenv("OPENROUTER_API_KEY")
-        if not generator_api_key:
-            raise ValueError(
-                "generator_api_key is required or OPENROUTER_API_KEY env var must be set"
-            )
+    # Validate the API key if one was provided
+    if generator_api_key:
+        result = litellm.utils.check_valid_key(model=generator_model, api_key=generator_api_key)
+        if result:
+            litellm.api_key = generator_api_key
+        else:
+            raise ValueError("Invalid API key provided.")
+
+    if not generator_api_key:
+        warn("generator_api_key is not set. Will use the environment variable associated with the LiteLLM model.")
 
     # Validate that we have at least tools or resources
     if not tools and not resources:
@@ -160,50 +166,60 @@ async def generate_scenarios(
     }
 
     step(f"Calling model: {_C.BOLD}{generator_model}{_C.RESET} …")
-    client_openai = openai.OpenAI(
-        api_key=generator_api_key,
-        base_url=generator_base_url,
-    )
-    t1 = time.perf_counter()
-    response = client_openai.chat.completions.create(
-        model=generator_model,
-        messages=[{"role": "user", "content": prompt}],
-        max_completion_tokens=8000,
-        response_format={
-            "type": "json_schema",
-            "json_schema": {"name": "scenario_list", "schema": response_schema},
-        },
-    )
-    dt = time.perf_counter() - t1
-    ok(f"Model responded in {dt:.2f}s.")
-
-    content = response.choices[0].message.content
-    if content is None:
-        err("Model response content is None.")
-        raise ValueError("Model response content is None")
-    info(f"Raw content length: {len(content)} chars.")
-
-    # Parse JSON
-    try:
-        result = json.loads(content)
-    except Exception as e:
-        err("Failed to parse JSON from model response.")
-        dim(f"   Exception: {e}")
-        dim("   First 500 chars of response content:")
-        dim(content[:500] if content else "No content")
-        raise
-
-    # Extract scenarios
-    if "scenarios" in result:
-        scenarios = result["scenarios"]
-    else:
-        scenarios = result if isinstance(result, list) else list(result.values())[0]
-
-    # Validate count
-    if len(scenarios) != num_scenarios:
-        err(f"Expected {num_scenarios} scenarios, got {len(scenarios)}.")
-        raise ValueError(f"Expected {num_scenarios} scenarios, got {len(scenarios)}")
+    # If using Azure OpenAI, support managed identity authentication via DefaultAzureCredential
+    token_provider = None
+    if os.getenv("AZURE_API_BASE") and not os.getenv("AZURE_API_KEY"):
+        warn("AZURE_API_KEY environment variable not set for Azure OpenAI. Will fall back to DefaultAzureCredential().")
+        try:
+            token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+        except Exception:
+            err("Failed to get bearer token provider for Azure AD authentication.")
+            raise
+
+    scenarios = []
+    while len(scenarios) < num_scenarios:
+        if len(scenarios) > 0:
+            warn(f"Expected {num_scenarios} scenarios, got {len(scenarios)}. Retrying to generate the remaining scenarios.")
+
+        t1 = time.perf_counter()
+        response = await litellm.acompletion(
+            model=generator_model,
+            messages=[{"role": "user", "content": prompt}],
+            azure_ad_token_provider=token_provider,
+            max_completion_tokens=8000,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {"name": "scenario_list", "schema": response_schema},
+            },
+        )
+        dt = time.perf_counter() - t1
+        ok(f"Model responded in {dt:.2f}s.")
+
+        content = response.choices[0].message.content
+        if content is None:
+            err("Model response content is None.")
+            raise ValueError("Model response content is None")
+        info(f"Raw content length: {len(content)} chars.")
+
+        # Parse JSON
+        try:
+            result = json.loads(content)
+        except Exception as e:
+            err("Failed to parse JSON from model response.")
+            dim(f"   Exception: {e}")
+            dim("   First 500 chars of response content:")
+            dim(content[:500] if content else "No content")
+            raise
+
+        # Extract scenarios
+        if "scenarios" in result:
+            scenarios_list = result["scenarios"]
+        else:
+            scenarios_list = result if isinstance(result, list) else list(result.values())[0]
+        scenarios.extend(scenarios_list)
+
+    scenarios = scenarios[:num_scenarios]  # Trim to the exact count if we got too many scenarios
 
     ok(f"Parsed {len(scenarios)} scenario(s) successfully.")
diff --git a/src/art/model.py b/src/art/model.py
index 43c519b2..8cffc303 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -1,13 +1,16 @@
+import os
 from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
 
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 import httpx
-from openai import AsyncOpenAI, DefaultAsyncHttpxClient
+from openai import AsyncOpenAI, AsyncAzureOpenAI, DefaultAsyncHttpxClient
 from pydantic import BaseModel
 from typing_extensions import Never
 
 from . import dev
 from .trajectories import Trajectory, TrajectoryGroup
 from .types import TrainConfig
+from .utils.logging import warn
 
 if TYPE_CHECKING:
     from art.backend import Backend
@@ -71,7 +74,7 @@ class Model(
     _backend: Optional["Backend"] = None
     _s3_bucket: str | None = None
     _s3_prefix: str | None = None
-    _openai_client: AsyncOpenAI | None = None
+    _openai_client: AsyncOpenAI | AsyncAzureOpenAI | None = None
 
     def __init__(
         self,
@@ -187,6 +190,41 @@ def openai_client(
         )
         return self._openai_client
 
+    def azure_openai_client(
+        self,
+    ) -> AsyncAzureOpenAI:
+        if self._openai_client is not None:
+            return self._openai_client  # type: ignore
+
+        if self.inference_base_url is None:
+            if self.trainable:
+                raise ValueError(
+                    "AzureOpenAI client not yet available on this trainable model. You must call `model.register()` first."
+                )
+            else:
+                raise ValueError(
+                    "In order to create an AzureOpenAI client you must provide an `inference_api_key` (optional) and `inference_base_url`."
+                )
+        if not os.getenv("AZURE_API_BASE"):
+            raise ValueError("AZURE_API_BASE environment variable must be set for Azure OpenAI")
+        token_provider = None
+        if not self.inference_api_key:
+            warn("Creating AzureOpenAI client without an inference_api_key. Will fall back to DefaultAzureCredential().")
+            token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+        self._openai_client = AsyncAzureOpenAI(
+            base_url=os.environ["AZURE_API_BASE"],
+            api_key=self.inference_api_key,
+            azure_ad_token_provider=token_provider,
+            api_version=os.getenv("AZURE_API_VERSION", "2024-12-01-preview"),
+            http_client=DefaultAsyncHttpxClient(
+                timeout=httpx.Timeout(timeout=1200, connect=5.0),
+                limits=httpx.Limits(
+                    max_connections=100_000, max_keepalive_connections=100_000
+                ),
+            ),
+        )
+        return self._openai_client
+
     def litellm_completion_params(self) -> dict:
         """Return the parameters that should be sent to litellm.completion."""
         model_name = self.inference_model_name
diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py
index 2ea33312..cd7303f1 100644
--- a/src/art/rewards/ruler.py
+++ b/src/art/rewards/ruler.py
@@ -10,10 +10,14 @@
 """
 
 import json
+import os
 from textwrap import dedent
 from typing import List
 
-from litellm import acompletion
+from art.utils.logging import err, warn
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+import litellm
 from litellm.types.utils import ModelResponse
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
 from pydantic import BaseModel, Field
@@ -172,9 +176,19 @@ async def ruler(
         {"role": "user", "content": user_text},
     ]
 
-    response = await acompletion(
+    token_provider = None
+    # If using Azure OpenAI, support managed identity authentication via DefaultAzureCredential
+    if os.getenv("AZURE_API_BASE") and not os.getenv("AZURE_API_KEY"):
+        warn("AZURE_API_KEY environment variable not set for Azure OpenAI. Will fall back to DefaultAzureCredential().")
+        try:
+            token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+        except Exception:
+            err("Failed to get bearer token provider for Azure AD authentication.")
+            raise
+    response = await litellm.acompletion(
         model=judge_model,
         messages=messages,
+        azure_ad_token_provider=token_provider,
         response_format=Response,
         caching=False,
         **extra_litellm_params if extra_litellm_params else {},
diff --git a/uv.lock b/uv.lock
index d2ba5a80..4b8a65b3 100644
--- a/uv.lock
+++ b/uv.lock
@@ -4124,6 +4124,7 @@ name = "openpipe-art"
 version = "0.5.2"
 source = { editable = "." }
 dependencies = [
+    { name = "azure-identity" },
     { name = "litellm" },
     { name = "openai" },
     { name = "typer" },
@@ -4187,6 +4188,7 @@ dev = [
 requires-dist = [
     { name = "accelerate", marker = "extra == 'backend'", specifier = "==1.7.0" },
     { name = "awscli", marker = "extra == 'backend'", specifier = ">=1.38.1" },
+    { name = "azure-identity", specifier = ">=1.24.0" },
     { name = "bitsandbytes", marker = "extra == 'backend'", specifier = ">=0.45.2" },
     { name = "gql", marker = "extra == 'backend'", specifier = "<4" },
     { name = "hf-xet", marker = "extra == 'backend'", specifier = ">=1.1.0" },
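
Usage sketch (not part of the diff): how the new DefaultAzureCredential fallback in generate_scenarios() would be exercised end to end. The endpoint, deployment name, and the my_tools/my_resources placeholders below are assumptions for illustration, not values from this change. With AZURE_API_BASE set and AZURE_API_KEY unset, generate_scenarios() warns, builds a bearer-token provider from DefaultAzureCredential, and hands it to litellm.acompletion.

import asyncio
import os

from art.mcp.generate_scenarios import generate_scenarios

# Hypothetical endpoint; AZURE_API_KEY is deliberately left unset so the code
# path above falls back to DefaultAzureCredential (az login, managed identity, ...).
os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"
os.environ.pop("AZURE_API_KEY", None)

collection = asyncio.run(
    generate_scenarios(
        tools=my_tools,          # placeholder: list of MCPTool built elsewhere
        resources=my_resources,  # placeholder: list of MCPResource built elsewhere
        num_scenarios=8,
        generator_model="azure/gpt-4.1-mini",  # LiteLLM Azure route; deployment name is assumed
    )
)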
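
Similarly, a minimal sketch of the new Model.azure_openai_client() path, assuming Model's pydantic fields (name, project, inference_model_name, inference_base_url) can be set at construction; the names below are illustrative, not from the diff. With no inference_api_key, the client authenticates through the Azure AD token provider rather than a key.

import os

import art

os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"  # hypothetical endpoint

model = art.Model(
    name="scenario-judge",  # illustrative values
    project="mcp-demo",
    inference_model_name="gpt-4.1-mini",  # assumed Azure deployment name
    inference_base_url=os.environ["AZURE_API_BASE"],
)
# Without an inference_api_key, azure_openai_client() warns and wires up
# get_bearer_token_provider(DefaultAzureCredential(), ...) instead of a key.
client = model.azure_openai_client()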