-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Feature/eas llm setup #9268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
alex3267006
wants to merge
4
commits into
Azure:main
Choose a base branch
from
alex3267006:feature/eas-llm-setup
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+961
−36
Open
Feature/eas llm setup #9268
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
|
||
import os | ||
from typing import List, Dict, Optional | ||
import yaml | ||
|
||
from azure.cli.core.api import get_config_dir | ||
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME | ||
|
||
|
||
class LLMConfigManager: | ||
"""Manages loading and saving LLM configuration from/to a YAML file.""" | ||
|
||
def __init__(self, config_path=None): | ||
if config_path is None: | ||
config_path = os.path.join( | ||
get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) | ||
self.config_path = os.path.expanduser(config_path) | ||
|
||
def save(self, provider_name: str, params: dict): | ||
configs = self.load() | ||
if not isinstance(configs, Dict): | ||
configs = {} | ||
|
||
models = configs.get("llms", []) | ||
model_name = params.get("MODEL_NAME") | ||
if not model_name: | ||
raise ValueError("MODEL_NAME is required to save configuration.") | ||
|
||
# Check if model already exists, update it and move it to the last; | ||
# otherwise, append new | ||
models = [ | ||
cfg for cfg in models if not ( | ||
cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name)] | ||
models.append({"provider": provider_name, **params}) | ||
|
||
configs["llms"] = models | ||
|
||
with open(self.config_path, "w") as f: | ||
yaml.safe_dump(configs, f, sort_keys=False) | ||
|
||
def load(self): | ||
"""Load configurations from the YAML file.""" | ||
if not os.path.exists(self.config_path): | ||
return {} | ||
with open(self.config_path, "r") as f: | ||
configs = yaml.safe_load(f) | ||
return configs if isinstance(configs, Dict) else {} | ||
|
||
def get_list(self) -> List[Dict]: | ||
"""Get the list of all model configurations""" | ||
return self.load()["llms"] if self.load( | ||
) and "llms" in self.load() else [] | ||
|
||
def get_latest(self) -> Optional[Dict]: | ||
"""Get the last model configuration""" | ||
model_configs = self.get_list() | ||
if model_configs: | ||
return model_configs[-1] | ||
raise ValueError( | ||
"No configurations found. Please run `az aks agent-init`") | ||
|
||
def get_specific( | ||
self, | ||
provider_name: str, | ||
model_name: str) -> Optional[Dict]: | ||
""" | ||
Get specific model configuration by provider and model name during Q&A with --model provider/model | ||
""" | ||
model_configs = self.get_list() | ||
for cfg in model_configs: | ||
if cfg.get("provider") == provider_name and cfg.get( | ||
"MODEL_NAME") == model_name: | ||
return cfg | ||
raise ValueError( | ||
f"No configuration found for provider '{provider_name}' with model '{model_name}'. " | ||
f"Please run `az aks agent-init`") | ||
|
||
def is_config_complete(self, config, provider_schema): | ||
""" | ||
Check if the given config has all required keys and valid values as per the provider schema. | ||
""" | ||
for key, meta in provider_schema.items(): | ||
if meta.get("validator") and not meta["validator"]( | ||
config.get(key)): | ||
return False | ||
return True |
77 changes: 77 additions & 0 deletions
77
src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# -------------------------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for license information. | ||
# -------------------------------------------------------------------------------------------- | ||
|
||
from typing import List, Tuple | ||
from .base import LLMProvider | ||
from .azure_provider import AzureProvider | ||
from .openai_provider import OpenAIProvider | ||
from .anthropic_provider import AnthropicProvider | ||
from .gemini_provider import GeminiProvider | ||
from .openai_compatiable_provider import OpenAICompatiableProvider | ||
|
||
|
||
_PROVIDER_CLASSES: List[LLMProvider] = [ | ||
AzureProvider, | ||
OpenAIProvider, | ||
AnthropicProvider, | ||
GeminiProvider, | ||
OpenAICompatiableProvider, | ||
# Add new providers here | ||
] | ||
|
||
PROVIDER_REGISTRY = {} | ||
for cls in _PROVIDER_CLASSES: | ||
key = cls.name.lower() | ||
if key not in PROVIDER_REGISTRY: | ||
PROVIDER_REGISTRY[key] = cls | ||
|
||
|
||
def _available_providers() -> List[str]: | ||
"""Return a list of registered provider names (lowercase): ["azure", "openai", ...]""" | ||
return list(PROVIDER_REGISTRY.keys()) | ||
|
||
|
||
def _provider_choices_numbered() -> List[Tuple[int, str]]: | ||
"""Return numbered choices: [(1, "azure"), (2, "openai"), ...].""" | ||
return [(i + 1, name) for i, name in enumerate(_available_providers())] | ||
|
||
|
||
def _get_provider_by_index(idx: int) -> LLMProvider: | ||
""" | ||
Return provider instance by numeric index (1-based). | ||
Raises ValueError if index is out of range. | ||
""" | ||
if 1 <= idx <= len(_PROVIDER_CLASSES): | ||
print("You selected provider:", _PROVIDER_CLASSES[idx - 1].name) | ||
return _PROVIDER_CLASSES[idx - 1]() | ||
raise ValueError(f"Invalid provider index: {idx}") | ||
|
||
|
||
def prompt_provider_choice() -> LLMProvider: | ||
""" | ||
Show a numbered menu and return the chosen provider instance. | ||
Keeps prompting until a valid selection is made. | ||
""" | ||
choices = _provider_choices_numbered() | ||
if not choices: | ||
raise ValueError("No providers are registered.") | ||
while True: | ||
for idx, name in choices: | ||
print(f" {idx}. {name}") | ||
sel_idx = input("Enter the number of your choice: ").strip().lower() | ||
|
||
if sel_idx == "/exit": | ||
raise SystemExit(0) | ||
try: | ||
return _get_provider_by_index(int(sel_idx)) | ||
except ValueError as e: | ||
print( | ||
f"Invalid input: {e}. Please enter a valid number, or type /exit to quit.") | ||
|
||
|
||
__all__ = [ | ||
"PROVIDER_REGISTRY", | ||
"prompt_provider_choice", | ||
] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Corrected spelling of 'compatiable' to 'compatible'.
Copilot uses AI. Check for mistakes.