-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfetch_models.py
More file actions
executable file
·177 lines (131 loc) · 5.98 KB
/
fetch_models.py
File metadata and controls
executable file
·177 lines (131 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
"""
Fetch available models from various API providers and display comprehensive information.
Currently supports Cirrascale API with ability to extend to other providers.
"""
import json
import os
import re
from datetime import datetime

import httpx
from dotenv import load_dotenv
from rich.console import Console
from rich.json import JSON
from rich.table import Table
# Populate os.environ from a local .env file (already-set variables are kept),
# so the os.getenv() lookup in fetch_models can find CIRRASCALE_API_KEY.
load_dotenv(override=False)
# Single rich console shared by every function below for formatted output.
console: Console = Console()
async def fetch_models(
    endpoint="https://ai2endpoints.cirrascale.ai/api/models",
    api_key_env="CIRRASCALE_API_KEY",
    timeout=10.0,
):
    """Fetch the model list from the provider API and echo the raw response.

    Args:
        endpoint: URL of the models listing endpoint. Defaults to the
            Cirrascale AI2 endpoint this script was written for.
        api_key_env: Name of the environment variable holding the bearer token.
        timeout: Request timeout in seconds.

    Returns:
        The decoded JSON response body, or ``None`` if the API key is missing
        or the request / JSON decoding fails. Errors are printed, not raised.
    """
    api_key = os.getenv(api_key_env)
    if not api_key:
        console.print(f"[red]Error: {api_key_env} not found in .env file[/red]")
        return None

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    console.print(f"[cyan]Fetching models from: {endpoint}[/cyan]\n")

    async with httpx.AsyncClient() as client:
        # Keep the try body to the calls that can actually fail on the network / decode.
        try:
            response = await client.get(endpoint, headers=headers, timeout=timeout)
            response.raise_for_status()
            data = response.json()
        except httpx.HTTPStatusError as e:
            # markup=False: the server body (HTML, or JSON with "[...]") must be
            # printed verbatim, not parsed as rich markup (which can raise MarkupError).
            console.print(
                f"HTTP Error {e.response.status_code}: {e.response.text}",
                style="red",
                markup=False,
            )
            return None
        except Exception as e:
            # Deliberate best-effort boundary for this CLI: report and return None.
            console.print(f"Error fetching models: {e}", style="red", markup=False)
            return None

    # Display raw response
    console.print("[bold]Raw API Response:[/bold]")
    console.print(JSON(json.dumps(data, indent=2)))
    return data
async def display_model_info(models_data):
"""Display model information in a formatted table."""
if not models_data:
return
console.print("\n[bold cyan]Available Models Summary:[/bold cyan]\n")
# Create table for models
table = Table(title="Cirrascale AI2 Models")
table.add_column("Model ID", style="cyan")
table.add_column("Created", style="yellow")
table.add_column("Owner", style="green")
table.add_column("Permissions", style="magenta")
models = models_data.get("data", [])
for model in models:
model_id = model.get("id", "Unknown")
created = datetime.fromtimestamp(model.get("created", 0)).strftime("%Y-%m-%d %H:%M:%S")
owned_by = model.get("owned_by", "Unknown")
# Format permissions
perms = model.get("permission", [])
if perms:
perm_str = f"{len(perms)} permissions"
else:
perm_str = "No permissions"
table.add_row(model_id, created, owned_by, perm_str)
console.print(table)
# Count models
console.print(f"\n[green]Total models available: {len(models)}[/green]")
# List model IDs for easy copying
console.print("\n[bold]Model IDs for reference:[/bold]")
for model in models:
console.print(f" • {model.get('id')}")
return models
async def compare_with_routing():
    """Print the Cirrascale entry of the local routing configuration.

    Reads ``FASTAPI_PROVIDERS["cirrascale"]`` from the project's
    ``provider_routing`` module and prints its endpoint, API-key env var,
    model prefix, rate limit, and any hardcoded model list.
    """
    # Imported inside the function, as before — presumably so the rest of the
    # script does not need provider_routing at import time.
    from provider_routing import FASTAPI_PROVIDERS

    console.print("\n[bold cyan]Current Routing Configuration:[/bold cyan]\n")

    provider_cfg = FASTAPI_PROVIDERS.get("cirrascale", {})
    console.print("[bold]Cirrascale Provider Configuration:[/bold]")

    # (label, config key, trailing text) for each summary line.
    summary_fields = (
        ("Endpoint", "endpoint", ""),
        ("API Key Env", "api_key_env", ""),
        ("Model Prefix", "model_prefix", ""),
        ("Rate Limit", "rate_limit_rpm", " requests/min"),
    )
    for label, cfg_key, suffix in summary_fields:
        console.print(f" {label}: {provider_cfg.get(cfg_key)}{suffix}")

    if "models" not in provider_cfg:
        console.print("\n[green]No hardcoded models - using dynamic routing[/green]")
        return

    console.print("\n[yellow]Hardcoded models in configuration:[/yellow]")
    for model_name in provider_cfg["models"]:
        console.print(f" • {model_name}")
# Parameter-count token such as "1B", "7b", "13B" or "1.5B". The look-behind makes
# the match start at the beginning of a number; the look-ahead rejects tokens glued
# to further letters/digits (e.g. "1bit", "a1b2").
_MODEL_SIZE_RE = re.compile(r"(?<![\d.])(\d+(?:\.\d+)?)[bB](?![A-Za-z0-9])")

# Explicit parameter counts that get their own bucket; anything else is "other".
_SIZE_BUCKETS = {1.0: "1B", 7.0: "7B", 13.0: "13B"}


def _classify_model_size(model_id):
    """Return the size bucket ("1B", "7B", "13B" or "other") for one model ID.

    The first standalone size token decides the bucket, so IDs such as "27B" or
    "11B" are no longer mistaken for 7B / 1B by plain substring matching.
    Non-string IDs (e.g. a missing ``id``) go to "other" instead of raising.
    """
    if not isinstance(model_id, str):
        return "other"
    match = _MODEL_SIZE_RE.search(model_id)
    if match:
        return _SIZE_BUCKETS.get(float(match.group(1)), "other")
    # NOTE(review): preserved from the original heuristic, which filed IDs
    # containing "0425" under 1B; now applied only when no explicit size token
    # is present. Confirm this still matches the provider's naming.
    if "0425" in model_id:
        return "1B"
    return "other"


async def main():
    """Main execution function.

    Fetches the model list, prints a summary table, shows the routing
    configuration, groups model IDs by size and prints next-step
    recommendations, then always prints a timestamp.
    """
    console.print("[bold cyan]🔍 API Model Discovery[/bold cyan]\n")

    models_data = await fetch_models()

    if models_data:
        models = await display_model_info(models_data)
        await compare_with_routing()

        console.print("\n[bold cyan]📊 Analysis & Recommendations:[/bold cyan]\n")

        if models:
            model_ids = [m.get("id") for m in models]

            # Group IDs by size bucket; dict order fixes the print order.
            sizes = {"1B": [], "7B": [], "13B": [], "other": []}
            for model_id in model_ids:
                sizes[_classify_model_size(model_id)].append(model_id)

            console.print("[bold]Model Size Distribution:[/bold]")
            for size, models_list in sizes.items():
                if models_list:
                    console.print(f" {size}: {len(models_list)} model(s)")
                    for model in models_list:
                        # markup=False: print API-supplied IDs verbatim.
                        console.print(f" • {model}", markup=False)

            # NOTE(review): the figures below are hardcoded text, not computed
            # from the data above.
            console.print("\n[bold]Recommended Next Steps:[/bold]")
            console.print("1. Based on 12.2% accuracy with 1B model, test the 7B model next")
            console.print("2. Expected 2-3x better accuracy with 7B (targeting 30-40%)")
            console.print("3. Use 13B model only if 7B doesn't meet minimum requirements")
            console.print(
                "4. Consider adjusting temperature and top_p for better instruction following"
            )

    # Build one markup string: rich parses each positional argument separately,
    # so passing "[/dim]" on its own raised MarkupError (no matching open tag).
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    console.print(f"\n[dim]Timestamp: {timestamp}[/dim]")
if __name__ == "__main__":
    # Run the async entry point on a fresh event loop when executed as a script.
    from asyncio import run as run_coroutine

    run_coroutine(main())