1 change: 1 addition & 0 deletions pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "typer>=0.15.2",
     "litellm==1.74.1",
     "weave>=0.51.51",
+    "azure-identity>=1.24.0",
 ]

 [project.optional-dependencies]
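The new dependency exists to supply `DefaultAzureCredential` and `get_bearer_token_provider`, which the files below use for Azure managed-identity auth. As a reviewer aid, here is a minimal sketch of that pattern; the scope string is the standard Cognitive Services scope, and nothing else in this snippet comes from the PR itself:

```python
# Sketch of the azure-identity pattern this PR adopts; requires
# azure-identity>=1.24.0 and a signed-in Azure context (env vars,
# managed identity, or `az login`).
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

# DefaultAzureCredential tries a chain of credential sources in order.
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://cognitiveservices.azure.com/.default",  # Azure OpenAI scope
)

# The provider is a zero-argument callable that returns a bearer token string.
print(token_provider()[:16], "...")
```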
128 changes: 72 additions & 56 deletions src/art/mcp/generate_scenarios.py
@@ -4,11 +4,14 @@
 import time
 from typing import Any, Dict, List, Optional

-import openai
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+import litellm

 from art.mcp.types import GeneratedScenarioCollection, MCPResource, MCPTool
-from art.utils.logging import _C, dim, err, info, ok, step
+from art.utils.logging import _C, dim, err, info, ok, step, warn

+# Enable automatic fallback to DefaultAzureCredential
+litellm.enable_azure_ad_token_refresh = True
+
 def preview_scenarios(scenarios: List[Dict[str, Any]], n: int = 5):
     """Preview generated scenarios."""
@@ -32,7 +35,7 @@ async def generate_scenarios(
     custom_instructions: Optional[str] = None,
     generator_model: str = "openai/gpt-4.1-mini",
     generator_api_key: Optional[str] = None,
-    generator_base_url: str = "https://openrouter.ai/api/v1",
+    generator_base_url: str = "",
 ) -> GeneratedScenarioCollection:
     """
     Generate scenarios for MCP tools.
Expand All @@ -43,9 +46,9 @@ async def generate_scenarios(
num_scenarios: Number of scenarios to generate (default: 24)
show_preview: Whether to show a preview of generated scenarios (default: True)
custom_instructions: Optional custom instructions for scenario generation
generator_model: Model to use for generation (default: "openai/gpt-4.1-mini")
generator_api_key: API key for the generator model. If None, will use OPENROUTER_API_KEY env var
generator_base_url: Base URL for the API (default: OpenRouter)
generator_model: LiteLLM model to use for generation (default: "openai/gpt-4.1-mini")
generator_api_key: API key for the generator model. If None, will use environment variable associated with the LiteLLM model.
generator_base_url: Base URL for the API

Returns:
GeneratedScenarioCollection containing the generated scenarios
@@ -54,13 +57,16 @@

     t0 = time.perf_counter()

-    # Handle API key
-    if generator_api_key is None:
-        generator_api_key = os.getenv("OPENROUTER_API_KEY")
-        if not generator_api_key:
-            raise ValueError(
-                "generator_api_key is required or OPENROUTER_API_KEY env var must be set"
-            )
+    # Validate the API key if one was provided
+    if generator_api_key:
+        if litellm.utils.check_valid_key(model=generator_model, api_key=generator_api_key):
+            litellm.api_key = generator_api_key
+        else:
+            raise ValueError("Invalid API key provided.")
+    else:
+        warn("generator_api_key is not set; falling back to the environment variable associated with the LiteLLM model.")

     # Validate that we have at least tools or resources
     if not tools and not resources:
@@ -160,50 +166,60 @@
     }

     step(f"Calling model: {_C.BOLD}{generator_model}{_C.RESET} …")
-    client_openai = openai.OpenAI(
-        api_key=generator_api_key,
-        base_url=generator_base_url,
-    )
-
-    t1 = time.perf_counter()
-    response = client_openai.chat.completions.create(
-        model=generator_model,
-        messages=[{"role": "user", "content": prompt}],
-        max_completion_tokens=8000,
-        response_format={
-            "type": "json_schema",
-            "json_schema": {"name": "scenario_list", "schema": response_schema},
-        },
-    )
-    dt = time.perf_counter() - t1
-    ok(f"Model responded in {dt:.2f}s.")
-
-    content = response.choices[0].message.content
-    if content is None:
-        err("Model response content is None.")
-        raise ValueError("Model response content is None")
-    info(f"Raw content length: {len(content)} chars.")
-
-    # Parse JSON
-    try:
-        result = json.loads(content)
-    except Exception as e:
-        err("Failed to parse JSON from model response.")
-        dim(f" Exception: {e}")
-        dim(" First 500 chars of response content:")
-        dim(content[:500] if content else "No content")
-        raise
-
-    # Extract scenarios
-    if "scenarios" in result:
-        scenarios = result["scenarios"]
-    else:
-        scenarios = result if isinstance(result, list) else list(result.values())[0]
-
-    # Validate count
-    if len(scenarios) != num_scenarios:
-        err(f"Expected {num_scenarios} scenarios, got {len(scenarios)}.")
-        raise ValueError(f"Expected {num_scenarios} scenarios, got {len(scenarios)}")
+    # If using Azure OpenAI, support managed identity authentication via DefaultAzureCredential
+    token_provider = None
+    if os.getenv("AZURE_API_BASE") and not os.getenv("AZURE_API_KEY"):
+        warn("AZURE_API_KEY environment variable not set for Azure OpenAI. Falling back to DefaultAzureCredential().")
+        try:
+            token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+        except Exception:
+            err("Failed to get bearer token provider for Azure AD authentication.")
+            raise
+
+    scenarios = []
+    while len(scenarios) < num_scenarios:
+        if len(scenarios) > 0:
+            warn(f"Expected {num_scenarios} scenarios, got {len(scenarios)}. Retrying to generate the remainder.")
+
+        t1 = time.perf_counter()
+        response = await litellm.acompletion(
+            model=generator_model,
+            messages=[{"role": "user", "content": prompt}],
+            azure_ad_token_provider=token_provider,
+            max_completion_tokens=8000,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {"name": "scenario_list", "schema": response_schema},
+            },
+        )
+        dt = time.perf_counter() - t1
+        ok(f"Model responded in {dt:.2f}s.")
+
+        content = response.choices[0].message.content
+        if content is None:
+            err("Model response content is None.")
+            raise ValueError("Model response content is None")
+        info(f"Raw content length: {len(content)} chars.")
+
+        # Parse JSON
+        try:
+            result = json.loads(content)
+        except Exception as e:
+            err("Failed to parse JSON from model response.")
+            dim(f" Exception: {e}")
+            dim(" First 500 chars of response content:")
+            dim(content[:500] if content else "No content")
+            raise
+
+        # Extract scenarios
+        if "scenarios" in result:
+            scenarios_list = result["scenarios"]
+        else:
+            scenarios_list = result if isinstance(result, list) else list(result.values())[0]
+        scenarios.extend(scenarios_list)
+
+    scenarios = scenarios[:num_scenarios]  # Trim to the exact count if the model returned too many

     ok(f"Parsed {len(scenarios)} scenario(s) successfully.")

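To make the new code path concrete, a hedged sketch of a call that exercises the managed-identity fallback. The endpoint and deployment names are placeholders, and `tools` is whatever `MCPTool` list you already have; only `generate_scenarios` itself comes from this PR:

```python
# Sketch, not part of this PR. Endpoint/deployment names are placeholders.
import asyncio
import os

from art.mcp.generate_scenarios import generate_scenarios

os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"
# AZURE_API_KEY deliberately unset: generate_scenarios() warns, then builds
# a DefaultAzureCredential token provider and hands it to litellm.acompletion.

async def main(tools):
    collection = await generate_scenarios(
        tools=tools,                           # MCPTool list from your server
        num_scenarios=24,
        generator_model="azure/gpt-4.1-mini",  # LiteLLM's Azure route
    )
    return collection

# asyncio.run(main(tools_from_my_server))
```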
42 changes: 40 additions & 2 deletions src/art/model.py
@@ -1,13 +1,16 @@
 import os
 from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload

+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 import httpx
-from openai import AsyncOpenAI, DefaultAsyncHttpxClient
+from openai import AsyncOpenAI, AsyncAzureOpenAI, DefaultAsyncHttpxClient
 from pydantic import BaseModel
 from typing_extensions import Never

 from . import dev
 from .trajectories import Trajectory, TrajectoryGroup
 from .types import TrainConfig
+from .utils.logging import warn

 if TYPE_CHECKING:
     from art.backend import Backend
@@ -71,7 +74,7 @@ class Model(
     _backend: Optional["Backend"] = None
     _s3_bucket: str | None = None
     _s3_prefix: str | None = None
-    _openai_client: AsyncOpenAI | None = None
+    _openai_client: AsyncOpenAI | AsyncAzureOpenAI | None = None

     def __init__(
         self,
@@ -187,6 +190,41 @@ def openai_client(
         )
         return self._openai_client

+    def azure_openai_client(
+        self,
+    ) -> AsyncAzureOpenAI:
+        if self._openai_client is not None:
+            return self._openai_client  # type: ignore
+
+        if self.inference_base_url is None:
+            if self.trainable:
+                raise ValueError(
+                    "AzureOpenAI client not yet available on this trainable model. You must call `model.register()` first."
+                )
+            else:
+                raise ValueError(
+                    "In order to create an AzureOpenAI client you must provide an `inference_api_key` (optional) and `inference_base_url`."
+                )
+        if not os.getenv("AZURE_API_BASE"):
+            raise ValueError("AZURE_API_BASE environment variable must be set for Azure OpenAI")
+        token_provider = None
+        if not self.inference_api_key:
+            warn("Creating AzureOpenAI client without an inference_api_key. Will fall back to DefaultAzureCredential().")
+            token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+        self._openai_client = AsyncAzureOpenAI(
+            base_url=os.environ["AZURE_API_BASE"],
+            api_key=self.inference_api_key,
+            azure_ad_token_provider=token_provider,
+            api_version=os.getenv("AZURE_API_VERSION", "2024-12-01-preview"),
+            http_client=DefaultAsyncHttpxClient(
+                timeout=httpx.Timeout(timeout=1200, connect=5.0),
+                limits=httpx.Limits(
+                    max_connections=100_000, max_keepalive_connections=100_000
+                ),
+            ),
+        )
+        return self._openai_client
+
     def litellm_completion_params(self) -> dict:
         """Return the parameters that should be sent to litellm.completion."""
         model_name = self.inference_model_name
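A similar sketch for the new accessor, assuming a `Model` constructed the usual way; the resource URL and names below are placeholders, not part of the PR:

```python
# Sketch, not part of this PR. Resource/deployment names are placeholders.
import os

import art

os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"

model = art.Model(
    name="my-model",
    project="my-project",
    inference_base_url="https://my-resource.openai.azure.com/openai/v1",
)
# With no inference_api_key set, azure_openai_client() warns, then wires a
# DefaultAzureCredential token provider into AsyncAzureOpenAI.
client = model.azure_openai_client()
```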
18 changes: 16 additions & 2 deletions src/art/rewards/ruler.py
@@ -10,10 +10,14 @@
 """

 import json
+import os
 from textwrap import dedent
 from typing import List

-from litellm import acompletion
+from art.utils.logging import err, warn
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+import litellm
 from litellm.types.utils import ModelResponse
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
 from pydantic import BaseModel, Field
@@ -172,9 +176,19 @@ async def ruler(
         {"role": "user", "content": user_text},
     ]

-    response = await acompletion(
+    token_provider = None
+    # If using Azure OpenAI, support managed identity authentication via DefaultAzureCredential
+    if os.getenv("AZURE_API_BASE") and not os.getenv("AZURE_API_KEY"):
+        warn("AZURE_API_KEY environment variable not set for Azure OpenAI. Falling back to DefaultAzureCredential().")
+        try:
+            token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
+        except Exception:
+            err("Failed to get bearer token provider for Azure AD authentication.")
+            raise
+    response = await litellm.acompletion(
         model=judge_model,
         messages=messages,
+        azure_ad_token_provider=token_provider,
         response_format=Response,
         caching=False,
         **extra_litellm_params if extra_litellm_params else {},
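And for RULER, a sketch of the judge call under managed identity, assuming `ruler()` keeps its existing `(message_lists, judge_model=...)` signature; the deployment name is a placeholder:

```python
# Sketch, not part of this PR. "azure/gpt-4o" is a placeholder deployment.
import asyncio
import os

from art.rewards.ruler import ruler

os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"
# No AZURE_API_KEY: ruler() warns, then authenticates the LiteLLM call
# through a DefaultAzureCredential token provider.

async def judge(message_lists):
    # message_lists: one chat-message list per trajectory to score
    return await ruler(message_lists, judge_model="azure/gpt-4o")

# asyncio.run(judge([...]))
```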
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default.