1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import os
1716
1817from vuln_analysis .runtime_context import ctx_state
1918from vuln_analysis .utils .transitive_code_searcher_tool import TransitiveCodeSearcher
2625from langchain .docstore .document import Document
2726
2827from vuln_analysis .data_models .state import AgentMorpheusEngineState
29- from ..utils .chain_of_calls_retriever import ChainOfCallsRetriever
30- from vuln_analysis .utils .dep_tree import Ecosystem
3128from vuln_analysis .utils .document_embedding import DocumentEmbedding
29+ from ..data_models .input import SourceDocumentsInfo
30+ from ..utils .chain_of_calls_retriever_base import ChainOfCallsRetrieverBase
31+ from ..utils .chain_of_calls_retriever_factory import get_chain_of_calls_retriever
3232from ..utils .function_name_extractor import FunctionNameExtractor
3333from ..utils .function_name_locator import FunctionNameLocator
3434
3535from vuln_analysis .logging .loggers_factory import LoggingFactory
36+ from ..utils .java_chain_of_calls_retriever import JavaChainOfCallsRetriever
37+
3638logger = LoggingFactory .get_agent_logger (__name__ )
3739
3840
@@ -53,29 +55,34 @@ class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name="package_and_
5355 Package and function locator tool used to validate package names and find function names using fuzzy matching.
5456 """
5557
56-
57- def get_call_of_chains_retriever (documents_embedder , si ):
58+ def get_call_of_chains_retriever (documents_embedder , si , query : str ):
5859 documents : list [Document ]
5960 git_repo = None
61+ code_source_info : SourceDocumentsInfo
6062 for source_info in si :
6163 if source_info .type == "code" :
64+ code_source_info = source_info
6265 git_repo = documents_embedder .get_repo_path (source_info )
6366 documents = documents_embedder .collect_documents (source_info )
6467 if git_repo is None :
6568 raise ValueError ("No code source info found" )
66- with open (os .path .join (git_repo , 'ecosystem_data.txt' ), 'r' , encoding = 'utf-8' ) as file :
67- ecosystem = file .read ()
68- ecosystem = Ecosystem [ecosystem ]
69- coc_retriever = ChainOfCallsRetriever (documents = documents , ecosystem = ecosystem , manifest_path = git_repo )
69+ ecosystem = TransitiveCodeSearcher .get_ecosystem (git_repo )
70+ coc_retriever = get_chain_of_calls_retriever (ecosystem ,
71+ documents ,
72+ git_repo ,
73+ query ,
74+ code_source_info )
7075 return coc_retriever
7176
72-
73- def get_transitive_code_searcher ():
77+ def get_transitive_code_searcher (query : str ):
7478 state : AgentMorpheusEngineState = ctx_state .get ()
75- if state .transitive_code_searcher is None :
79+ if state .transitive_code_searcher is None or isinstance ( state . transitive_code_searcher . chain_of_calls_retriever , JavaChainOfCallsRetriever ) :
7680 si = state .original_input .input .image .source_info
7781 documents_embedder = DocumentEmbedding (embedding = None )
78- coc_retriever = get_call_of_chains_retriever (documents_embedder , si )
82+ if state .transitive_code_searcher is not None and isinstance (state .transitive_code_searcher .chain_of_calls_retriever , JavaChainOfCallsRetriever ):
83+ coc_retriever = get_call_of_chains_retriever (documents_embedder , si , query )
84+ else :
85+ coc_retriever = get_call_of_chains_retriever (documents_embedder , si , query )
7986 transitive_code_searcher = TransitiveCodeSearcher (chain_of_calls_retriever = coc_retriever )
8087 state .transitive_code_searcher = transitive_code_searcher
8188 return state .transitive_code_searcher
@@ -87,7 +94,7 @@ async def transitive_search(config: TransitiveCodeSearchToolConfig,
8794
8895 async def _arun (query : str ) -> tuple :
8996 transitive_code_searcher : TransitiveCodeSearcher
90- transitive_code_searcher = get_transitive_code_searcher ()
97+ transitive_code_searcher = get_transitive_code_searcher (query )
9198 result = transitive_code_searcher .search (query )
9299 return result
93100
@@ -127,9 +134,9 @@ async def functions_usage_search(config: CallingFunctionNameExtractorToolConfig,
127134 builder : Builder ): # pylint: disable=unused-argument
128135
129136 async def _arun (query : str ) -> list :
130- coc_retriever : ChainOfCallsRetriever
137+ coc_retriever : ChainOfCallsRetrieverBase
131138 transitive_code_searcher : TransitiveCodeSearcher
132- transitive_code_searcher = get_transitive_code_searcher ()
139+ transitive_code_searcher = get_transitive_code_searcher (query )
133140 coc_retriever = transitive_code_searcher .chain_of_calls_retriever
134141 function_name_extractor = FunctionNameExtractor (coc_retriever )
135142 result = function_name_extractor .fetch_list (query )
@@ -175,20 +182,20 @@ async def package_and_function_locator(config: PackageAndFunctionLocatorToolConf
175182 builder : Builder ): # pylint: disable=unused-argument
176183
177184 async def _arun (query : str ) -> dict :
178- coc_retriever : ChainOfCallsRetriever
185+ coc_retriever : ChainOfCallsRetrieverBase
179186 transitive_code_searcher : TransitiveCodeSearcher
180- transitive_code_searcher = get_transitive_code_searcher ()
187+ transitive_code_searcher = get_transitive_code_searcher (query )
181188 coc_retriever = transitive_code_searcher .chain_of_calls_retriever
182189 locator = FunctionNameLocator (coc_retriever )
183190 result = await locator .locate_functions (query )
184191 pkg_msg = "Package is valid."
185- if not locator .is_package_valid and not locator .is_std_package :
186- pkg_msg = "Package is not valid."
187-
188-
192+ if not locator .is_package_valid and not locator .is_std_package :
193+ pkg_msg = "Package is not valid."
194+
195+
189196 return {
190197 "ecosystem" : coc_retriever .ecosystem .name ,
191- "package_msg" : pkg_msg ,
198+ "package_msg" : pkg_msg ,
192199 "result" : result
193200 }
194201
@@ -199,11 +206,12 @@ async def _arun(query: str) -> dict:
199206 FIRST STEP in code analysis. Validates packages, locates functions via fuzzy matching, provides ecosystem type.
200207 </tool_identity>
201208 <requires>
202- Input format: 'package_name,function_name' or 'package_name,class_name.method_name'
209+ Input format: 'package_name,function_name' or 'package_name,class_name.method_name' or maven_gav(groupId:artifactId:version),fully_qualified_class_name.method_name
203210 <examples>
204211 <example>libxml2,xmlParseDocument</example>
205212 <example>requests,Session.get</example>
206213 <example>numpy,array.reshape</example>
214+ <example>commons-beanutils:commons-beanutils:1.0.0,org.apache.commons.beanutils.PropertyUtilsBean.setSimpleProperty</example>
207215 </examples>
208216 </requires>
209217 <guide_lines>
@@ -218,6 +226,9 @@ async def _arun(query: str) -> dict:
218226 <input>requests,Session.get</input>
219227 <output>["Session.get", "Session.post", "Session.put", "Session.delete"]</output>
220228 <conclusion>Found class method "Session.get"</conclusion>
229+ <input>commons-beanutils:commons-beanutils:1.0.0,org.apache.commons.beanutils.PropertyUtilsBean.setSimpleProperty</input>
230+ <output>["PropertyUtilsBean.setSimpleProperty"]</output>
231+ <conclusion>Found class method "PropertyUtilsBean.setSimpleProperty"</conclusion>
221232 </examples>
222233 </guide_lines>
223234 <output>
0 commit comments