Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ packages = ["src/copra"]

[project]
name = "copra-theorem-prover"
version = "1.5.0"
version = "1.6.0"
authors = [
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
]
Expand Down
5 changes: 0 additions & 5 deletions src/copra/agent/dfs_policy_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from copra.agent.rate_limiter import InvalidActionException
from copra.agent.simple_policy_prompter import SimplePolicyPrompter
from copra.agent.gpt_guided_tree_search_policy import PromptSummary, ProofQInfo, TreeSearchAction, TreeSearchActionType
from copra.gpts.llama_access import ServiceDownError
from copra.retrieval.coq_bm25_reranker import CoqBM25TrainingDataRetriever
from copra.prompt_generator.gpt_request_grammar import CoqGPTRequestGrammar, CoqGptRequest, CoqGptRequestActions
from copra.prompt_generator.dfs_agent_grammar import DfsAgentGrammar
Expand Down Expand Up @@ -327,10 +326,6 @@ def run_prompt(self, request: CoqGptResponse) -> list:
# don't change temperature for now

self._num_api_calls += 1
except ServiceDownError as e:
self.logger.info("Got a service down error. Will giveup until the docker container is restarted.")
self.logger.exception(e)
raise
except Exception as e:
self.logger.info("Got an unknown exception. Retrying.")
self.logger.exception(e)
Expand Down
11 changes: 2 additions & 9 deletions src/copra/agent/simple_policy_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import logging
from copra.agent.rate_limiter import RateLimiter
from copra.gpts.gpt_access import GptAccess
from copra.gpts.llama_access import LlamaAccess, ServiceDownError
from copra.prompt_generator.prompter import PolicyPrompter
from copra.tools.misc import model_supports_openai_api, is_vllm_model

Expand Down Expand Up @@ -74,10 +73,7 @@ def __init__(

# Initialize LLM access (GptAccess or LlamaAccess)
# Note: vLLM models (with "vllm:" prefix) are handled by GptAccess
if not model_supports_openai_api(gpt_model_name):
self._gpt_access = LlamaAccess(gpt_model_name)
else:
self._gpt_access = GptAccess(secret_filepath=secret_filepath, model_name=gpt_model_name)
self._gpt_access = GptAccess(secret_filepath=secret_filepath, model_name=gpt_model_name)

# Get model configuration
# For vLLM models, use the generic "vllm" key in model_info
Expand All @@ -104,14 +100,11 @@ def __init__(

def __enter__(self):
"""Context manager entry - initialize LLM service if needed."""
if isinstance(self._gpt_access, LlamaAccess):
self._gpt_access.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
"""Context manager exit - cleanup LLM service if needed."""
if isinstance(self._gpt_access, LlamaAccess):
self._gpt_access.__exit__(exc_type, exc_value, traceback)
pass

def add_to_history(self, message: typing.Dict[str, str]):
"""
Expand Down
13 changes: 3 additions & 10 deletions src/copra/baselines/gpt4/few_shot_policy_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
from copra.agent.rate_limiter import RateLimiter, InvalidActionException
from copra.agent.gpt_guided_tree_search_policy import TreeSearchAction
from copra.gpts.gpt_access import GptAccess
from copra.gpts.llama_access import LlamaAccess
from itp_interface.rl.proof_action import ProofAction
from copra.prompt_generator.prompter import PolicyPrompter
from copra.prompt_generator.dfs_agent_grammar import DfsAgentGrammar
from copra.baselines.gpt4.few_shot_grammar import FewShotGptRequest, FewShotGptRequestGrammar, FewShotGptResponse, FewShotGptResponseGrammar
from copra.tools.misc import model_supports_openai_api

class FewShotGptPolicyPrompter(PolicyPrompter):
_cache: typing.Dict[str, typing.Any] = {}
Expand Down Expand Up @@ -43,10 +41,7 @@ def __init__(self,
conv_messages = self.agent_grammar.get_openai_conv_messages(example_conv_prompt_path, "system")
main_message = self.agent_grammar.get_openai_main_message(main_sys_prompt_path, "system")
self.system_messages = [main_message] + conv_messages
if not model_supports_openai_api(gpt_model_name):
self._gpt_access = LlamaAccess(gpt_model_name)
else:
self._gpt_access = GptAccess(secret_filepath=secret_filepath, model_name=gpt_model_name)
self._gpt_access = GptAccess(secret_filepath=secret_filepath, model_name=gpt_model_name)

# For vLLM models, use the generic "vllm" key for model info
model_info_key = "vllm" if gpt_model_name.startswith("vllm:") else gpt_model_name
Expand Down Expand Up @@ -83,12 +78,10 @@ def __init__(self,
pass

def __enter__(self):
if isinstance(self._gpt_access, LlamaAccess):
self._gpt_access.__enter__()
pass

def __exit__(self, exc_type, exc_value, traceback):
if isinstance(self._gpt_access, LlamaAccess):
self._gpt_access.__exit__(exc_type, exc_value, traceback)
pass

def _init_retriever(self):
if FewShotGptPolicyPrompter._cache.get(self._training_data_path, None) is not None:
Expand Down
3 changes: 0 additions & 3 deletions src/copra/baselines/gpt4/hammer_policy_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from copra.retrieval.coq_bm25_reranker import CoqBM25TrainingDataRetriever
from copra.prompt_generator.agent_grammar import CoqGPTResponseGrammar
from copra.prompt_generator.gpt_request_grammar import CoqGPTRequestGrammar, CoqGptRequestActions
from copra.agent.rate_limiter import RateLimiter
from copra.agent.gpt_guided_tree_search_policy import TreeSearchAction, TreeSearchActionType
from copra.gpts.gpt_access import GptAccess
from copra.gpts.llama_access import LlamaAccess
from itp_interface.rl.proof_action import ProofAction
from copra.prompt_generator.prompter import PolicyPrompter
from copra.prompt_generator.dfs_agent_grammar import DfsAgentGrammar
Expand Down
12 changes: 3 additions & 9 deletions src/copra/baselines/gpt4/informal_few_shot_policy_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from copra.agent.rate_limiter import RateLimiter, InvalidActionException
from copra.agent.gpt_guided_tree_search_policy import TreeSearchAction
from copra.gpts.gpt_access import GptAccess
from copra.gpts.llama_access import LlamaAccess
from itp_interface.rl.proof_action import ProofAction
from copra.prompt_generator.prompter import PolicyPrompter
from copra.prompt_generator.dfs_agent_grammar import DfsAgentGrammar
Expand Down Expand Up @@ -44,10 +43,7 @@ def __init__(self,
conv_messages = self.agent_grammar.get_openai_conv_messages(example_conv_prompt_path, "system")
main_message = self.agent_grammar.get_openai_main_message(main_sys_prompt_path, "system")
self.system_messages = [main_message] + conv_messages
if not model_supports_openai_api(gpt_model_name):
self._gpt_access = LlamaAccess(gpt_model_name)
else:
self._gpt_access = GptAccess(secret_filepath=secret_filepath, model_name=gpt_model_name)
self._gpt_access = GptAccess(secret_filepath=secret_filepath, model_name=gpt_model_name)
self._token_limit_per_min = GptAccess.gpt_model_info[gpt_model_name]["token_limit_per_min"]
self._request_limit_per_min = GptAccess.gpt_model_info[gpt_model_name]["request_limit_per_min"]
self._max_token_per_prompt = GptAccess.gpt_model_info[gpt_model_name]["max_token_per_prompt"]
Expand Down Expand Up @@ -81,12 +77,10 @@ def __init__(self,
pass

def __enter__(self):
if isinstance(self._gpt_access, LlamaAccess):
self._gpt_access.__enter__()
pass

def __exit__(self, exc_type, exc_value, traceback):
if isinstance(self._gpt_access, LlamaAccess):
self._gpt_access.__exit__(exc_type, exc_value, traceback)
pass

def _init_retriever(self):
if InformalFewShotGptPolicyPrompter._cache.get(self._training_data_path, None) is not None:
Expand Down
227 changes: 0 additions & 227 deletions src/copra/gpts/llama2_chat_format.py

This file was deleted.

Loading