Skip to content

Commit 8500b8d

Browse files
tmihalacTheodor Mihalache
authored andcommitted
Java transitive search
Signed-off-by: Theodor Mihalache <[email protected]>
1 parent 2943f18 commit 8500b8d

22 files changed

+7889
-441
lines changed

.tekton/on-pull-request.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,23 @@ spec:
178178
print_banner "INSTALLING DEPENDENCIES"
179179
uv sync
180180
181+
# Install Java
182+
JAVA_ARCH="x64"
183+
JDK_URL="https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.2%2B9/OpenJDK22U-jdk_${JAVA_ARCH}_linux_hotspot_22.0.2_9.tar.gz"
184+
JDK_DIR="jdk-22.0.2+9"
185+
186+
echo ">> Downloading $JDK_URL"
187+
mkdir -p /tekton/home/jdk
188+
curl -fsSL -o /tekton/home/jdk/jdk.tgz "$JDK_URL"
189+
tar -C /tekton/home/jdk -xzf /tekton/home/jdk/jdk.tgz
190+
rm -f /tekton/home/jdk/jdk.tgz
191+
192+
export JAVA_HOME="/tekton/home/jdk/${JDK_DIR}"
193+
export PATH="$JAVA_HOME/bin:$PATH"
194+
195+
echo "Java version:"
196+
java -version || true
197+
181198
# Install Go
182199
print_banner "Installing Go"
183200

@@ -190,6 +207,24 @@ spec:
190207
echo "Go version:"
191208
go version
192209

210+
# Install Maven
211+
print_banner "Installing Maven"
212+
213+
MAVEN_VERSION="3.9.11"
214+
ARCHIVE="apache-maven-${MAVEN_VERSION}-bin.tar.gz"
215+
URL="https://archive.apache.org/dist/maven/maven-3/${MAVEN_VERSION}/binaries/${ARCHIVE}"
216+
217+
curl -s -L -o "${ARCHIVE}" "${URL}"
218+
mkdir -p "$HOME/maven-sdk"
219+
tar -C "$HOME/maven-sdk" -xzf "${ARCHIVE}"
220+
221+
export MAVEN_HOME="$HOME/maven-sdk/apache-maven-${MAVEN_VERSION}"
222+
export M2_HOME="$MAVEN_HOME"
223+
export PATH="$MAVEN_HOME/bin:$PATH"
224+
225+
echo "Maven version:"
226+
mvn -v
227+
193228
print_banner "RUNNING LINTER"
194229
# Add the current directory to git's safe directories to avoid ownership errors.
195230
git config --global --add safe.directory /workspace/source

Dockerfile

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,27 @@ RUN curl -L -X GET https://go.dev/dl/go1.24.1.linux-amd64.tar.gz -o /tmp/go1.24.
4444
&& tar -C /usr/local -xzf /tmp/go1.24.1.linux-amd64.tar.gz \
4545
&& rm /tmp/go1.24.1.linux-amd64.tar.gz
4646

47+
# --- Temurin JDK 22 (amd64/x86_64) ---
48+
ARG JDK_URL="https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.2%2B9/OpenJDK22U-jdk_x64_linux_hotspot_22.0.2_9.tar.gz"
49+
ARG JDK_DIR="jdk-22.0.2+9"
50+
RUN mkdir -p /opt/jdk \
51+
&& curl -fsSL -o /tmp/jdk.tgz "${JDK_URL}" \
52+
&& tar -C /opt/jdk -xzf /tmp/jdk.tgz \
53+
&& rm -f /tmp/jdk.tgz
54+
ENV JAVA_HOME=/opt/jdk/${JDK_DIR}
55+
ENV PATH="${JAVA_HOME}/bin:${PATH}"
56+
57+
# --- Maven 3.9.11 (optional) ---
58+
ARG MVN_VER=3.9.11
59+
RUN curl -fsSL -o /tmp/maven.tgz \
60+
"https://archive.apache.org/dist/maven/maven-3/${MVN_VER}/binaries/apache-maven-${MVN_VER}-bin.tar.gz" \
61+
&& tar -C /opt -xzf /tmp/maven.tgz \
62+
&& rm -f /tmp/maven.tgz
63+
ENV PATH="/opt/apache-maven-${MVN_VER}/bin:${PATH}"
64+
65+
# Verify
66+
RUN java -version && mvn -v
67+
4768
# Set SSL environment variables
4869
ENV REQUESTS_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt
4970
ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt

src/vuln_analysis/tools/tests/test_transitive_code_search.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from vuln_analysis.tools.tests.mock_documents import (python_script_example, python_init_function_example,
1414
python_full_document_example, python_parse_function_example,
1515
python_mock_function_in_use, python_mock_file)
16+
from vuln_analysis.utils.dep_tree import Ecosystem
1617

1718
from vuln_analysis.utils.source_rpm_downloader import RPMDependencyManager
1819

@@ -99,8 +100,8 @@ async def test_transitive_search_golang_6():
99100
result = await transitive_code_search_runner_coroutine("net/http,ListenAndServe")
100101
(path_found, list_path) = result
101102
print(result)
102-
assert path_found is False
103-
assert len(list_path) is 0
103+
assert path_found is True
104+
assert len(list_path) is 2
104105

105106

106107
def set_input_for_next_run(git_repository: str, git_ref: str, included_extensions: list[str],
@@ -152,9 +153,7 @@ def mock_file_open(*args, **kwargs):
152153
mock_file = MagicMock()
153154
mock_file.__enter__ = MagicMock(return_value=mock_file)
154155
mock_file.__exit__ = MagicMock(return_value=None)
155-
if 'ecosystem_data.txt' in str(file_path):
156-
mock_file.read.return_value = "PYTHON"
157-
elif 'requirements.txt' in str(file_path):
156+
if 'requirements.txt' in str(file_path):
158157
mock_file.__iter__ = MagicMock(return_value=iter([
159158
"flask==2.0.1\n",
160159
"werkzeug==2.0.1\n",
@@ -190,8 +189,9 @@ def mock_file_open(*args, **kwargs):
190189
}
191190
])
192191
@patch('vuln_analysis.utils.dep_tree.run_command', return_value=python_dependency_tree_mock_output)
192+
@patch('vuln_analysis.utils.transitive_code_searcher_tool.TransitiveCodeSearcher.get_ecosystem', return_value=Ecosystem.PYTHON)
193193
@patch('builtins.open', side_effect=mock_file_open)
194-
async def test_transitive_search_python_parameterized(mock_open, mock_run_command,test_case):
194+
async def test_transitive_search_python_parameterized(mock_open, mock_run_command, mock_get_ecosystem, test_case):
195195
"""Parameterized test that runs all existing test cases with their respective configurations."""
196196
transitive_code_search_runner_coroutine = await get_transitive_code_runner_function()
197197
logging.basicConfig(level=logging.DEBUG)
@@ -299,4 +299,53 @@ async def test_c_transitive_search_2():
299299
print(f"DEBUG: list_path = {list_path}")
300300
print(f"DEBUG: len(list_path) = {len(list_path)}")
301301
assert len(list_path) == 1
302-
assert path_found == False
302+
assert path_found == False
303+
304+
@pytest.mark.asyncio
305+
async def test_transitive_search_java_1():
306+
transitive_code_search_runner_coroutine = await get_transitive_code_runner_function()
307+
set_input_for_next_run(git_repository="https://github.com/cryostatio/cryostat",
308+
git_ref="8f753753379e9381429b476aacbf6890ef101438",
309+
included_extensions=["**/*.java"],
310+
excluded_extensions=["target/**/*",
311+
"build/**/*",
312+
"*.class",
313+
".gradle/**/*",
314+
".mvn/**/*",
315+
".gitignore",
316+
"test/**/*",
317+
"tests/**/*",
318+
"src/test/**/*",
319+
"pom.xml",
320+
"build.gradle"])
321+
result = await transitive_code_search_runner_coroutine("commons-beanutils:commons-beanutils:1.9.4,org.apache.commons.beanutils.PropertyUtilsBean.getProperty")
322+
(path_found, list_path) = result
323+
print(result)
324+
assert path_found is False
325+
assert len(list_path) is 1
326+
327+
@pytest.mark.asyncio
328+
async def test_transitive_search_java_2():
329+
transitive_code_search_runner_coroutine = await get_transitive_code_runner_function()
330+
set_input_for_next_run(git_repository="https://github.com/cryostatio/cryostat",
331+
git_ref="8f753753379e9381429b476aacbf6890ef101438",
332+
included_extensions=["**/*.java"],
333+
excluded_extensions=["target/**/*",
334+
"build/**/*",
335+
"*.class",
336+
".gradle/**/*",
337+
".mvn/**/*",
338+
".gitignore",
339+
"test/**/*",
340+
"tests/**/*",
341+
"src/test/**/*",
342+
"pom.xml",
343+
"build.gradle"])
344+
result = await transitive_code_search_runner_coroutine("org.apache.commons:commons-lang3:3.14.0,org.apache.commons.lang3.StringUtils.isBlank")
345+
(path_found, list_path) = result
346+
print(result)
347+
assert path_found is True
348+
assert len(list_path) is 2
349+
document = list_path[1]
350+
assert 'src/main/java/io/cryostat' in document.metadata['source']
351+
assert 'StringUtils.isBlank(' in document.page_content

src/vuln_analysis/tools/transitive_code_search.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import os
1716

1817
from vuln_analysis.runtime_context import ctx_state
1918
from vuln_analysis.utils.transitive_code_searcher_tool import TransitiveCodeSearcher
@@ -26,13 +25,16 @@
2625
from langchain.docstore.document import Document
2726

2827
from vuln_analysis.data_models.state import AgentMorpheusEngineState
29-
from ..utils.chain_of_calls_retriever import ChainOfCallsRetriever
30-
from vuln_analysis.utils.dep_tree import Ecosystem
3128
from vuln_analysis.utils.document_embedding import DocumentEmbedding
29+
from ..data_models.input import SourceDocumentsInfo
30+
from ..utils.chain_of_calls_retriever_base import ChainOfCallsRetrieverBase
31+
from ..utils.chain_of_calls_retriever_factory import get_chain_of_calls_retriever
3232
from ..utils.function_name_extractor import FunctionNameExtractor
3333
from ..utils.function_name_locator import FunctionNameLocator
3434

3535
from vuln_analysis.logging.loggers_factory import LoggingFactory
36+
from ..utils.java_chain_of_calls_retriever import JavaChainOfCallsRetriever
37+
3638
logger = LoggingFactory.get_agent_logger(__name__)
3739

3840

@@ -53,29 +55,34 @@ class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name="package_and_
5355
Package and function locator tool used to validate package names and find function names using fuzzy matching.
5456
"""
5557

56-
57-
def get_call_of_chains_retriever(documents_embedder, si):
58+
def get_call_of_chains_retriever(documents_embedder, si, query: str):
5859
documents: list[Document]
5960
git_repo = None
61+
code_source_info: SourceDocumentsInfo
6062
for source_info in si:
6163
if source_info.type == "code":
64+
code_source_info = source_info
6265
git_repo = documents_embedder.get_repo_path(source_info)
6366
documents = documents_embedder.collect_documents(source_info)
6467
if git_repo is None:
6568
raise ValueError("No code source info found")
66-
with open(os.path.join(git_repo, 'ecosystem_data.txt'), 'r', encoding='utf-8') as file:
67-
ecosystem = file.read()
68-
ecosystem = Ecosystem[ecosystem]
69-
coc_retriever = ChainOfCallsRetriever(documents=documents, ecosystem=ecosystem, manifest_path=git_repo)
69+
ecosystem = TransitiveCodeSearcher.get_ecosystem(git_repo)
70+
coc_retriever = get_chain_of_calls_retriever(ecosystem,
71+
documents,
72+
git_repo,
73+
query,
74+
code_source_info)
7075
return coc_retriever
7176

72-
73-
def get_transitive_code_searcher():
77+
def get_transitive_code_searcher(query: str):
7478
state: AgentMorpheusEngineState = ctx_state.get()
75-
if state.transitive_code_searcher is None:
79+
if state.transitive_code_searcher is None or isinstance(state.transitive_code_searcher.chain_of_calls_retriever, JavaChainOfCallsRetriever):
7680
si = state.original_input.input.image.source_info
7781
documents_embedder = DocumentEmbedding(embedding=None)
78-
coc_retriever = get_call_of_chains_retriever(documents_embedder, si)
82+
if state.transitive_code_searcher is not None and isinstance(state.transitive_code_searcher.chain_of_calls_retriever, JavaChainOfCallsRetriever):
83+
coc_retriever = get_call_of_chains_retriever(documents_embedder, si, query)
84+
else:
85+
coc_retriever = get_call_of_chains_retriever(documents_embedder, si, query)
7986
transitive_code_searcher = TransitiveCodeSearcher(chain_of_calls_retriever=coc_retriever)
8087
state.transitive_code_searcher = transitive_code_searcher
8188
return state.transitive_code_searcher
@@ -87,7 +94,7 @@ async def transitive_search(config: TransitiveCodeSearchToolConfig,
8794

8895
async def _arun(query: str) -> tuple:
8996
transitive_code_searcher: TransitiveCodeSearcher
90-
transitive_code_searcher = get_transitive_code_searcher()
97+
transitive_code_searcher = get_transitive_code_searcher(query)
9198
result = transitive_code_searcher.search(query)
9299
return result
93100

@@ -127,9 +134,9 @@ async def functions_usage_search(config: CallingFunctionNameExtractorToolConfig,
127134
builder: Builder): # pylint: disable=unused-argument
128135

129136
async def _arun(query: str) -> list:
130-
coc_retriever: ChainOfCallsRetriever
137+
coc_retriever: ChainOfCallsRetrieverBase
131138
transitive_code_searcher: TransitiveCodeSearcher
132-
transitive_code_searcher = get_transitive_code_searcher()
139+
transitive_code_searcher = get_transitive_code_searcher(query)
133140
coc_retriever = transitive_code_searcher.chain_of_calls_retriever
134141
function_name_extractor = FunctionNameExtractor(coc_retriever)
135142
result = function_name_extractor.fetch_list(query)
@@ -175,20 +182,20 @@ async def package_and_function_locator(config: PackageAndFunctionLocatorToolConf
175182
builder: Builder): # pylint: disable=unused-argument
176183

177184
async def _arun(query: str) -> dict:
178-
coc_retriever: ChainOfCallsRetriever
185+
coc_retriever: ChainOfCallsRetrieverBase
179186
transitive_code_searcher: TransitiveCodeSearcher
180-
transitive_code_searcher = get_transitive_code_searcher()
187+
transitive_code_searcher = get_transitive_code_searcher(query)
181188
coc_retriever = transitive_code_searcher.chain_of_calls_retriever
182189
locator = FunctionNameLocator(coc_retriever)
183190
result = await locator.locate_functions(query)
184191
pkg_msg = "Package is valid."
185-
if not locator.is_package_valid and not locator.is_std_package:
186-
pkg_msg = "Package is not valid."
187-
188-
192+
if not locator.is_package_valid and not locator.is_std_package:
193+
pkg_msg = "Package is not valid."
194+
195+
189196
return {
190197
"ecosystem": coc_retriever.ecosystem.name,
191-
"package_msg": pkg_msg,
198+
"package_msg": pkg_msg,
192199
"result": result
193200
}
194201

@@ -199,11 +206,12 @@ async def _arun(query: str) -> dict:
199206
FIRST STEP in code analysis. Validates packages, locates functions via fuzzy matching, provides ecosystem type.
200207
</tool_identity>
201208
<requires>
202-
Input format: 'package_name,function_name' or 'package_name,class_name.method_name'
209+
Input format: 'package_name,function_name' or 'package_name,class_name.method_name' or maven_gav(groupId:artifactId:version),fully_qualified_class_name.method_name
203210
<examples>
204211
<example>libxml2,xmlParseDocument</example>
205212
<example>requests,Session.get</example>
206213
<example>numpy,array.reshape</example>
214+
<example>commons-beanutils:commons-beanutils:1.0.0,org.apache.commons.beanutils.PropertyUtilsBean.setSimpleProperty</example>
207215
</examples>
208216
</requires>
209217
<guide_lines>
@@ -218,6 +226,9 @@ async def _arun(query: str) -> dict:
218226
<input>requests,Session.get</input>
219227
<output>["Session.get", "Session.post", "Session.put", "Session.delete"]</output>
220228
<conclusion>Found class method "Session.get"</conclusion>
229+
<input>commons-beanutils:commons-beanutils:1.0.0,org.apache.commons.beanutils.PropertyUtilsBean.setSimpleProperty</input>
230+
<output>["PropertyUtilsBean.setSimpleProperty"]</output>
231+
<conclusion>Found class method "PropertyUtilsBean.setSimpleProperty"</conclusion>
221232
</examples>
222233
</guide_lines>
223234
<output>

0 commit comments

Comments
 (0)