From 5a756be1238787e1a9bd625b3edfea99ccc745c0 Mon Sep 17 00:00:00 2001 From: Vladimir Belousov Date: Mon, 27 Oct 2025 14:46:44 +0200 Subject: [PATCH 1/3] refactor(go-segmenter): replace custom GoSegmenter with native Tree-Sitter implementation Signed-off-by: Vladimir Belousov --- .../tools/tests/test_go_segmenter.py | 72 ++++++++ .../tests/test_transitive_code_search.py | 4 +- .../utils/chain_of_calls_retriever.py | 7 +- src/vuln_analysis/utils/document_embedding.py | 4 +- .../golang_functions_parsers.py | 163 ++++++++++++------ .../utils/go_segmenter_extended.py | 31 ++++ 6 files changed, 218 insertions(+), 63 deletions(-) create mode 100644 src/vuln_analysis/tools/tests/test_go_segmenter.py create mode 100644 src/vuln_analysis/utils/go_segmenter_extended.py diff --git a/src/vuln_analysis/tools/tests/test_go_segmenter.py b/src/vuln_analysis/tools/tests/test_go_segmenter.py new file mode 100644 index 00000000..49656d2d --- /dev/null +++ b/src/vuln_analysis/tools/tests/test_go_segmenter.py @@ -0,0 +1,72 @@ +from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended + + +def _extract(code: str): + seg = GoSegmenterExtended(code) + return [s.strip() for s in seg.extract_functions_classes()] + + +def test_generic_method_basic(): + code = """ + type Box[T any] struct { value T } + func (b *Box[T]) Set(v T) { b.value = v } + """ + chunks = _extract(code) + assert any("Set" in c for c in chunks), "generic method not extracted" + + +def test_generic_multiple_type_params(): + code = """ + func MapKeys[K comparable, V any](m map[K]V) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys + } + """ + chunks = _extract(code) + assert any("MapKeys" in c for c in chunks), "multiple generics not parsed" + + +def test_function_returning_func(): + code = """ + func makeAdder(x int) func(int) int { + return func(y int) int { return x + y } + } + """ + chunks = _extract(code) + assert any("makeAdder" in c for c in chunks), "failed to parse func returning func" + + +def test_inline_anonymous_func(): + code = """ + func Worker() { + defer func() { cleanup() }() + go func() { runTask() }() + } + """ + chunks = _extract(code) + assert any("Worker" in c for c in chunks), "missed inline anonymous func" + + +def test_double_pointer_receiver(): + code = """ + type Conn struct{} + func (c **Conn) Reset() {} + """ + chunks = _extract(code) + assert any("Reset" in c for c in chunks), "failed to detect pointer receiver" + + +def test_multiline_generic_method(): + code = """ + func (r *Repo[ + T any, + E error, + ]) Save(v T) (E, error) { + return nil, nil + } + """ + chunks = _extract(code) + assert any("Save" in c for c in chunks), "multiline generic method not parsed" diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py index 5622ed20..55bd19cd 100644 --- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py +++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py @@ -86,7 +86,7 @@ async def test_transitive_search_golang_5(): (path_found, list_path) = result assert path_found is False - assert len(list_path) is 1 + assert len(list_path) == 1 # Test fix of https://issues.redhat.com/browse/APPENG-3435 @pytest.mark.asyncio @@ -100,7 +100,7 @@ async def test_transitive_search_golang_6(): (path_found, list_path) = result print(result) assert path_found is True - assert len(list_path) is 2 + assert len(list_path) == 2 def set_input_for_next_run(git_repository: str, git_ref: str, included_extensions: list[str], diff --git a/src/vuln_analysis/utils/chain_of_calls_retriever.py b/src/vuln_analysis/utils/chain_of_calls_retriever.py index a9957957..f700add1 100644 --- a/src/vuln_analysis/utils/chain_of_calls_retriever.py +++ b/src/vuln_analysis/utils/chain_of_calls_retriever.py @@ -191,14 +191,15 @@ def __init__(self, documents: List[Document], ecosystem: Ecosystem, manifest_pat logger.debug(f"no_macro_documents len : {len(no_macro_documents)}") if not self.language_parser.is_script_language(): - # filter out types and full code documents, retaining only functions/methods documents in this attribute. + # filter out types and full code documents, retaining only functions/methods documents in this attribute. self.documents = [doc for doc in no_macro_documents if self.language_parser.is_function(doc)] self.documents_of_functions = self.documents else: self.documents = filtered_documents self.documents_of_functions = [doc for doc in self.documents if doc.page_content.startswith(self.language_parser.get_function_reserved_word())] - + # sort documents to ensure deterministic behavior + self.documents.sort(key=lambda doc: doc.metadata.get('source', '')) logger.debug(f"self.documents len : {len(self.documents)}") logger.debug("Chain of Calls Retriever - retaining only types/classes docs " "documents_of_types len %d", len(self.documents_of_types)) @@ -249,6 +250,8 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa # Add same package itself to search path. # direct_parents.extend([function_package]) # gets list of documents to search in only from parents of function' package. + # Fixes non-deterministic behavior in chain-of-calls resolution where identical inputs produced different call-chain lengths + direct_parents.sort() function_name_to_search = self.language_parser.get_function_name(document_function) if function_name_to_search == self.language_parser.get_constructor_method_name(): function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function) diff --git a/src/vuln_analysis/utils/document_embedding.py b/src/vuln_analysis/utils/document_embedding.py index e01a600a..dae2e639 100644 --- a/src/vuln_analysis/utils/document_embedding.py +++ b/src/vuln_analysis/utils/document_embedding.py @@ -37,7 +37,7 @@ from langchain_core.document_loaders.blob_loaders import Blob from vuln_analysis.data_models.input import SourceDocumentsInfo -from vuln_analysis.utils.go_segmenters_with_methods import GoSegmenterWithMethods +from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended from vuln_analysis.utils.python_segmenters_with_classes_methods import PythonSegmenterWithClassesMethods from vuln_analysis.utils.js_extended_parser import ExtendedJavaScriptSegmenter from vuln_analysis.utils.source_code_git_loader import SourceCodeGitLoader @@ -144,7 +144,7 @@ class ExtendedLanguageParser(LanguageParser): "javascript": ExtendedJavaScriptSegmenter, "js": ExtendedJavaScriptSegmenter, } - additional_segmenters["go"] = GoSegmenterWithMethods + additional_segmenters["go"] = GoSegmenterExtended additional_segmenters["python"] = PythonSegmenterWithClassesMethods additional_segmenters["c"] = CSegmenterExtended diff --git a/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py b/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py index 94a677b0..f3981fdc 100644 --- a/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py +++ b/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py @@ -1,11 +1,10 @@ +import hashlib import os import re from langchain_core.documents import Document from .lang_functions_parsers import LanguageFunctionsParser -from ..dep_tree import Ecosystem -from ..standard_library_cache import StandardLibraryCache EMBEDDED_TYPE = "embedded_type" @@ -45,7 +44,7 @@ def check_types_from_callee_package(params: list[tuple], type_documents: list[Do # Only type without package if len(parts) == 1: for the_type in type_documents: - if the_type.page_content.startwith(f"type {parts[0]}"): + if the_type.page_content.startswith(f"type {parts[0]}"): code_with_type_file = code_documents.get(the_type.metadata['source']) type_file_package_name = get_package_name_file(code_with_type_file) if type_file_package_name == callee_function_file_package_name: @@ -53,7 +52,7 @@ def check_types_from_callee_package(params: list[tuple], type_documents: list[Do # type with package qualifier else: for the_type in type_documents: - if the_type.page_content.startwith(f"type {parts[1]}"): + if the_type.page_content.startswith(f"type {parts[1]}"): code_with_type_file = code_documents.get(the_type.metadata['source']) package_match = handle_imports(code_with_type_file, parts[0], callee_package) return package_match @@ -171,7 +170,9 @@ def __prepare_package_lookup(self, parts, variables_mappings, the_part: int): if var_properties is not None: resolved_type = var_properties.get("type") value = var_properties.get("value") - struct_initializer_expression = re.search(r"(&|\\*)?\w+\s*{", value) + struct_initializer_expression = None + if value: + struct_initializer_expression = re.search(r"(&|\\*)?\w+\s*{", value) resolved_type = str(resolved_type).replace("&", "").replace("*", "") return resolved_type, struct_initializer_expression, value, var_properties else: @@ -202,10 +203,15 @@ def __get_type_docs_matched_with_callee_package(self, callee_package, checked_ty (self.get_type_name(a_type) == checked_type or self.get_type_name(a_type) in checked_type)] def create_map_of_local_vars(self, functions_methods_documents: list[Document]) -> dict[str, dict]: + """ + Builds a mapping of function identifiers to their local variables and parameters. + Includes support for anonymous functions with deterministic names. + """ mappings = dict() for func_method in functions_methods_documents: - func_key = f"{self.get_function_name(func_method)}@{func_method.metadata['source']}" all_vars = dict() + func_name = self.get_function_name(func_method) + func_key = f"{func_name}@{func_method.metadata['source']}" for row in func_method.page_content.splitlines(): if not self.is_comment_line(row): # Extract arguments and receiver argument of type as parameters @@ -262,7 +268,7 @@ def create_map_of_local_vars(self, functions_methods_documents: list[Document]) right_side = func_method.page_content[index_of_start_if + 2: index_of_start_if + end_of_assignment - 1].strip() - all_vars[(left_side.strip())] = {" value": right_side.strip().replace + all_vars[(left_side.strip())] = {"value": right_side.strip().replace ("\n\t", "").replace("\t", ""), "type": LOCAL_IMPLICIT } @@ -284,7 +290,8 @@ def create_map_of_local_vars(self, functions_methods_documents: list[Document]) else: pass - mappings[func_key] = all_vars + if func_name != "": + mappings[func_key] = all_vars return mappings @@ -327,20 +334,26 @@ def parse_all_type_struct_class_to_fields(self, types: list[Document]) -> dict[t next_struct = current_line_stripped.find("struct") if next_eol > - 1 and (next_struct == -1 or next_struct > next_eol): if not self.is_comment_line(current_line_stripped[:next_eol + 1]): - declaration_parts = current_line_stripped[:next_eol + 1].split() + # remove inline comments + line = current_line_stripped[:next_eol + 1] + line = line.split("//", 1)[0].split("/*", 1)[0].strip() # If row inside block contains func, then it's a function type and need to parse it in a # special way. - if current_line_stripped[:next_eol + 1].__contains__("func"): - declaration_parts = current_line_stripped[:next_eol + 1].split("func") + if line.__contains__("func"): + declaration_parts = line.split("func") declaration_parts = [part.strip() for part in declaration_parts] if len(declaration_parts) == 2: declaration_parts[1] = f"func {declaration_parts[1]}" + else: + # For regular types, split by whitespace + declaration_parts = line.split() # ignore alias' "equals" notation - if len(declaration_parts) == 3: - [name, _, type_name] = declaration_parts - elif len(declaration_parts) == 2: - [name, type_name] = declaration_parts - if len(declaration_parts) == (2 or 3): + if len(declaration_parts) in [2, 3]: + if len(declaration_parts) == 3: + [name, _, type_name] = declaration_parts + elif len(declaration_parts) == 2: + [name, type_name] = declaration_parts + self.parse_one_type(Document(page_content=f"type {name} {type_name}", metadata={"source": the_type.metadata['source']}), types_mapping) @@ -428,38 +441,71 @@ def dir_name_for_3rd_party_packages(self) -> str: def is_exported_function(self, function: Document) -> bool: function_name = self.get_function_name(function) - return re.search("[A-Z][a-z0-9-]*", function_name) + return bool(re.search("[A-Z][a-z0-9-]*", function_name)) def get_function_name(self, function: Document) -> str: - try: - index_of_function_opening = function.page_content.index("{") - except ValueError as e: - function_line = function.page_content.find(os.linesep) - # print(f"function {function.page_content[:function_line]} => contains no body ") - return function.page_content[:function_line] - - function_header = function.page_content[:index_of_function_opening] - # function is a method of a type + """ + Extracts the function name from the Go function definition. + If the function is anonymous or the name cannot be determined, + returns a deterministic fallback name based on content hash. + """ + if not function or not getattr(function, "page_content", None): + return "" + content = function.page_content + index_of_function_opening = content.find("{") + # function without body is valid according to the Go specification + # https://go.dev/ref/spec#Function_declarations + if index_of_function_opening == -1: + # print("Function without body") + function_header = content.splitlines()[0] + else: + # print("Function WITH body") + function_header = content[:index_of_function_opening] + + # method with receiver if function_header.startswith("func ("): - index_of_first_right_bracket = function_header.index(")") - skip_receiver_arg = function_header[index_of_first_right_bracket + 1:] - index_of_first_left_bracket = skip_receiver_arg.index("(") + index_of_first_right_bracket = function_header.find(")") + if index_of_first_right_bracket == -1: + return "" + + skip_receiver_arg = function_header[index_of_first_right_bracket + 1:].strip() + index_of_first_left_bracket = skip_receiver_arg.find("(") + if index_of_first_left_bracket == -1: + parts = skip_receiver_arg.split() + return parts[0] if parts else "" return skip_receiver_arg[:index_of_first_left_bracket].strip() - # regular function not tied to a certain type + # regular or generic function else: + if "(" in function_header: + index_of_first_left_bracket = function_header.find("(") + else: + index_of_first_left_bracket = function_header.find("[") + + if index_of_first_left_bracket == -1: + name = "" + else: + func_with_name = function_header[:index_of_first_left_bracket].strip() + parts = func_with_name.split() + if len(parts) > 1: + name = parts[1] + elif len(parts) == 1 and parts[0] != "func": + name = parts[0] + else: + name = "" + + # Fallback for anonymous or malformed functions + if not name or name in ("", "unknown", "func"): try: - index_of_first_left_bracket = function_header.index("(") - # Go Generic function - except ValueError: - try: - index_of_first_left_bracket = function_header.index("[") - except ValueError: - raise ValueError(f"Invalid function header - {function_header}") - func_with_name = function_header[:index_of_first_left_bracket] - if len(func_with_name.split(" ")) > 1: - return func_with_name.split(" ")[1] - # TODO Try to extract anonymous function var - # else: + content_bytes = function.page_content.encode("utf-8") + short_hash = hashlib.sha256(content_bytes).hexdigest()[:8] + src = function.metadata.get("source", "unknown") + prefix = src.split("/")[-1].split(".")[0] + fallback_name = f"anon_{prefix}_{short_hash}" + return fallback_name + except Exception: + return "" + + return name def search_for_called_function(self, caller_function: Document, callee_function_name: str, callee_function: Document, @@ -508,7 +554,7 @@ def __check_identifier_resolved_to_callee_function_package(self, function: Docum try: callee_function = code_documents[callee_function_file_name] callee_function_package = get_package_name_file(callee_function).strip() - except KeyError as e: + except KeyError: # Standard library function , there is no function code, thus the source name is the package name in # this case callee_function_package = callee_function_file_name @@ -579,31 +625,34 @@ def __check_identifier_resolved_to_callee_function_package(self, function: Docum return False def get_package_names(self, function: Document) -> list[str]: - package_names = list() - full_doc_path = str(function.metadata['source']) - parts = full_doc_path.split("/") + package_names = [] + full_doc_path = str(function.metadata.get("source") or "").strip() + if not full_doc_path: + return [""] + + parts = [p for p in full_doc_path.split("/") if p] version = "" + if len(parts) > 4: match = re.search(r"[vV][0-9]{1,2}", parts[4]) - if match and match.group(0): + if match: version = f"/{match.group(0)}" - if parts[0].startswith(self.dir_name_for_3rd_party_packages()) and len(parts) > 3: + if parts and parts[0].startswith(self.dir_name_for_3rd_party_packages()) and len(parts) > 3: package_names.append(f"{parts[1]}/{parts[2]}{version}") package_names.append(f"{parts[1]}/{parts[2]}/{parts[3]}{version}") - else: - try: - package_names.append(f"{parts[0]}/{parts[1]}{version}") + elif len(parts) >= 2: + package_names.append(f"{parts[0]}/{parts[1]}{version}") + if len(parts) >= 3: package_names.append(f"{parts[0]}/{parts[1]}/{parts[2]}{version}") - # Standard library package - except IndexError as index_excp: - if len(parts) > 1: - package_names.append(f"{parts[0]}/{parts[1]}{version}") - else: - package_names.append(f"{parts[0]}{version}") + elif len(parts) == 1: + package_names.append(f"{parts[0]}{version}") + else: + package_names.append("") return package_names + def is_root_package(self, function: Document) -> bool: return not function.metadata['source'].startswith(self.dir_name_for_3rd_party_packages()) diff --git a/src/vuln_analysis/utils/go_segmenter_extended.py b/src/vuln_analysis/utils/go_segmenter_extended.py new file mode 100644 index 00000000..2a274ec0 --- /dev/null +++ b/src/vuln_analysis/utils/go_segmenter_extended.py @@ -0,0 +1,31 @@ +from langchain_community.document_loaders.parsers.language.go import GoSegmenter + +CHUNK_QUERY_EXT = """ +[ + (function_declaration) @function + (method_declaration) @method + (func_literal) @anon + (type_declaration) @type +] +""".strip() + + +class GoSegmenterExtended(GoSegmenter): + def get_chunk_query(self) -> str: + return CHUNK_QUERY_EXT + + def extract_functions_classes(self): + """Extract all functions, methods, anonymous functions, and types without filtering overlaps.""" + language = self.get_language() + query = language.query(self.get_chunk_query()) + + parser = self.get_parser() + tree = parser.parse(bytes(self.code, encoding="UTF-8")) + captures = query.captures(tree.root_node) + + chunks = [] + for node, _ in captures: + chunk_text = node.text.decode("UTF-8") + chunks.append(chunk_text) + + return chunks From f1f582dbb907d8242e70727aee50cd9a01757cc3 Mon Sep 17 00:00:00 2001 From: Vladimir Belousov Date: Mon, 3 Nov 2025 11:42:25 +0200 Subject: [PATCH 2/3] refactor(go-parser): decompose get_function_name into Strategy pattern Signed-off-by: Vladimir Belousov --- .../golang_functions_parsers.py | 127 +++++++++--------- .../utils/go_segmenter_extended.py | 29 ++-- 2 files changed, 84 insertions(+), 72 deletions(-) diff --git a/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py b/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py index f3981fdc..b17db9d2 100644 --- a/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py +++ b/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py @@ -1,6 +1,7 @@ import hashlib import os import re +from typing import Callable, final from langchain_core.documents import Document @@ -114,11 +115,56 @@ def handle_imports(code_content: str, identifier: str, callee_package: str) -> b return False +@final +class GoConstants: + UNKNOWN_FUNCTION = "" + ANON_FUNCTION_PREFIX = "anon" + FUNC_KEYWORD = "func" + class GoLanguageFunctionsParser(LanguageFunctionsParser): def get_dummy_function(self, function_name): return f"{self.get_function_reserved_word()} {function_name}() {{}}" + def _generate_fallback_name(self, document: Document) -> str: + """ + Generates a deterministic name for anonymous or unparsable functions. + """ + try: + source = document.metadata.get("source", "unknown_source") + prefix = source.split("/")[-1].split(".")[0] + content_bytes = document.page_content.encode("utf-8") + short_hash = hashlib.sha256(content_bytes).hexdigest()[:8] + return f"{GoConstants.ANON_FUNCTION_PREFIX}_{prefix}_{short_hash}" + except Exception: + return GoConstants.UNKNOWN_FUNCTION + def _try_parse_method(self, header: str) -> str | None: + """ + Tries to parse a function name assuming it's a method with a receiver. + Returns the name or None if it doesn't match the pattern. + """ + if not header.startswith("func ("): + return None + receiver_end_idx = header.find(")") + if receiver_end_idx == -1: + return None + after_receiver = header[receiver_end_idx + 1 :].lstrip() + + if "(" in after_receiver and ")" not in after_receiver: + return None + + name_match = re.match(r"([a-zA-Z0-9_]+)", after_receiver) + + return name_match.group(1) if name_match else None + def _try_parse_regular_function(self, header: str) -> str | None: + """ + Tries to parse a regular or generic function name. + Returns the name or None if it's an anonymous function or doesn't match. + """ + match = re.search(r"^func\s+([a-zA-Z0-9_]+)\s*[\[\(](?=.*[\]\)])", header) + if match: + return match.group(1) + return None def __trace_down_package(self, expression: str, code_documents: dict[str, Document], type_documents: list[Document], callee_package: str, fields_of_types: dict[tuple, list[tuple]], functions_local_variables_index: dict[str, dict], @@ -445,67 +491,28 @@ def is_exported_function(self, function: Document) -> bool: def get_function_name(self, function: Document) -> str: """ - Extracts the function name from the Go function definition. - If the function is anonymous or the name cannot be determined, - returns a deterministic fallback name based on content hash. + Extracts the function name from its Go definition. """ - if not function or not getattr(function, "page_content", None): - return "" - content = function.page_content - index_of_function_opening = content.find("{") - # function without body is valid according to the Go specification - # https://go.dev/ref/spec#Function_declarations - if index_of_function_opening == -1: - # print("Function without body") - function_header = content.splitlines()[0] - else: - # print("Function WITH body") - function_header = content[:index_of_function_opening] - - # method with receiver - if function_header.startswith("func ("): - index_of_first_right_bracket = function_header.find(")") - if index_of_first_right_bracket == -1: - return "" - - skip_receiver_arg = function_header[index_of_first_right_bracket + 1:].strip() - index_of_first_left_bracket = skip_receiver_arg.find("(") - if index_of_first_left_bracket == -1: - parts = skip_receiver_arg.split() - return parts[0] if parts else "" - return skip_receiver_arg[:index_of_first_left_bracket].strip() - # regular or generic function - else: - if "(" in function_header: - index_of_first_left_bracket = function_header.find("(") - else: - index_of_first_left_bracket = function_header.find("[") - - if index_of_first_left_bracket == -1: - name = "" - else: - func_with_name = function_header[:index_of_first_left_bracket].strip() - parts = func_with_name.split() - if len(parts) > 1: - name = parts[1] - elif len(parts) == 1 and parts[0] != "func": - name = parts[0] - else: - name = "" - - # Fallback for anonymous or malformed functions - if not name or name in ("", "unknown", "func"): - try: - content_bytes = function.page_content.encode("utf-8") - short_hash = hashlib.sha256(content_bytes).hexdigest()[:8] - src = function.metadata.get("source", "unknown") - prefix = src.split("/")[-1].split(".")[0] - fallback_name = f"anon_{prefix}_{short_hash}" - return fallback_name - except Exception: - return "" - - return name + if function is None or getattr(function, "page_content", None) is None: + return GoConstants.UNKNOWN_FUNCTION + + content = function.page_content.strip() + if not content: + return self._generate_fallback_name(function) + + body_start_idx = content.find("{") + header = content if body_start_idx == -1 else content[:body_start_idx] + + parsing_strategies: list[Callable[[str], str | None]] = [ + self._try_parse_method, + self._try_parse_regular_function, + ] + + for strategy in parsing_strategies: + name = strategy(header) + if name: + return name + return self._generate_fallback_name(function) def search_for_called_function(self, caller_function: Document, callee_function_name: str, callee_function: Document, diff --git a/src/vuln_analysis/utils/go_segmenter_extended.py b/src/vuln_analysis/utils/go_segmenter_extended.py index 2a274ec0..e449ff31 100644 --- a/src/vuln_analysis/utils/go_segmenter_extended.py +++ b/src/vuln_analysis/utils/go_segmenter_extended.py @@ -1,31 +1,36 @@ from langchain_community.document_loaders.parsers.language.go import GoSegmenter CHUNK_QUERY_EXT = """ -[ - (function_declaration) @function - (method_declaration) @method - (func_literal) @anon - (type_declaration) @type -] +(source_file + [ + (function_declaration) @function + (method_declaration) @method + (type_declaration) @type + ] +) """.strip() class GoSegmenterExtended(GoSegmenter): def get_chunk_query(self) -> str: return CHUNK_QUERY_EXT - + def extract_functions_classes(self): - """Extract all functions, methods, anonymous functions, and types without filtering overlaps.""" + """ + Extracts all TOP-LEVEL functions, methods, and types. + Nested anonymous functions are kept inside their parent functions. + """ language = self.get_language() query = language.query(self.get_chunk_query()) - + parser = self.get_parser() - tree = parser.parse(bytes(self.code, encoding="UTF-8")) + tree = parser.parse(bytes(self.code, "UTF-8")) + captures = query.captures(tree.root_node) - + chunks = [] for node, _ in captures: chunk_text = node.text.decode("UTF-8") chunks.append(chunk_text) - + return chunks From 9261a8c2eef8962943ac241a831105fb6f0e3c05 Mon Sep 17 00:00:00 2001 From: Vladimir Belousov Date: Mon, 3 Nov 2025 11:48:11 +0200 Subject: [PATCH 3/3] refactor(go-parser): add unit tests for Go function naming Signed-off-by: Vladimir Belousov --- src/vuln_analysis/tools/tests/conftest.py | 14 ++ .../tools/tests/test_go_segmenter.py | 80 ++++++----- .../tests/test_golang_functions_parsers.py | 132 ++++++++++++++++++ 3 files changed, 189 insertions(+), 37 deletions(-) create mode 100644 src/vuln_analysis/tools/tests/conftest.py create mode 100644 src/vuln_analysis/tools/tests/test_golang_functions_parsers.py diff --git a/src/vuln_analysis/tools/tests/conftest.py b/src/vuln_analysis/tools/tests/conftest.py new file mode 100644 index 00000000..648bf7bc --- /dev/null +++ b/src/vuln_analysis/tools/tests/conftest.py @@ -0,0 +1,14 @@ +import pytest + +from vuln_analysis.utils.functions_parsers.golang_functions_parsers import ( + GoLanguageFunctionsParser, +) + + +@pytest.fixture(scope="module") +def go_parser() -> GoLanguageFunctionsParser: + """ + Provides a single instance of the GoLanguageFunctionsParser + for all tests in a module. + """ + return GoLanguageFunctionsParser() diff --git a/src/vuln_analysis/tools/tests/test_go_segmenter.py b/src/vuln_analysis/tools/tests/test_go_segmenter.py index 49656d2d..1f86cedf 100644 --- a/src/vuln_analysis/tools/tests/test_go_segmenter.py +++ b/src/vuln_analysis/tools/tests/test_go_segmenter.py @@ -1,65 +1,60 @@ +import textwrap + from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended def _extract(code: str): - seg = GoSegmenterExtended(code) - return [s.strip() for s in seg.extract_functions_classes()] + seg = GoSegmenterExtended(textwrap.dedent(code)) + return seg.extract_functions_classes() -def test_generic_method_basic(): +def test_segmenter_extracts_type_and_generic_method(): code = """ type Box[T any] struct { value T } func (b *Box[T]) Set(v T) { b.value = v } """ - chunks = _extract(code) - assert any("Set" in c for c in chunks), "generic method not extracted" + expected_chunks = [ + "type Box[T any] struct { value T }", + "func (b *Box[T]) Set(v T) { b.value = v }", + ] - -def test_generic_multiple_type_params(): - code = """ - func MapKeys[K comparable, V any](m map[K]V) []K { - keys := make([]K, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - return keys - } - """ - chunks = _extract(code) - assert any("MapKeys" in c for c in chunks), "multiple generics not parsed" + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks -def test_function_returning_func(): +def test_segmenter_extracts_toplevel_function_only_and_ignores_nested(): code = """ func makeAdder(x int) func(int) int { return func(y int) int { return x + y } } """ - chunks = _extract(code) - assert any("makeAdder" in c for c in chunks), "failed to parse func returning func" - + expected_chunks = [ + textwrap.dedent(""" + func makeAdder(x int) func(int) int { + return func(y int) int { return x + y } + } + """).strip() + ] -def test_inline_anonymous_func(): - code = """ - func Worker() { - defer func() { cleanup() }() - go func() { runTask() }() - } - """ - chunks = _extract(code) - assert any("Worker" in c for c in chunks), "missed inline anonymous func" + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks -def test_double_pointer_receiver(): +def test_segmenter_handles_double_pointer_receiver(): code = """ type Conn struct{} func (c **Conn) Reset() {} """ - chunks = _extract(code) - assert any("Reset" in c for c in chunks), "failed to detect pointer receiver" + expected_chunks = [ + "type Conn struct{}", + "func (c **Conn) Reset() {}", + ] + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks -def test_multiline_generic_method(): + +def test_segmenter_handles_multiline_generic_method(): code = """ func (r *Repo[ T any, @@ -68,5 +63,16 @@ def test_multiline_generic_method(): return nil, nil } """ - chunks = _extract(code) - assert any("Save" in c for c in chunks), "multiline generic method not parsed" + expected_chunks = [ + textwrap.dedent(""" + func (r *Repo[ + T any, + E error, + ]) Save(v T) (E, error) { + return nil, nil + } + """).strip() + ] + + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks diff --git a/src/vuln_analysis/tools/tests/test_golang_functions_parsers.py b/src/vuln_analysis/tools/tests/test_golang_functions_parsers.py new file mode 100644 index 00000000..ef024078 --- /dev/null +++ b/src/vuln_analysis/tools/tests/test_golang_functions_parsers.py @@ -0,0 +1,132 @@ + +import textwrap + +import pytest +from langchain_core.documents import Document + +HAPPY_PATH_CASES = [ + ("simple_function", "func DoSomething() {}", "DoSomething"), + ("with_parameters", "func DoSomething(p1 string, p2 int) {}", "DoSomething"), + ("with_return_value", "func DoSomething(v int) string {}", "DoSomething"), + ( + "with_named_return", + "func DoSomething(a, b float64) (q float64, e error) {}", + "DoSomething", + ), + ( + "method_with_receiver", + "func (p *Point) DoSomething() float64 {}", + "DoSomething", + ), +] + + +EDGE_CASES_TEST = [ + ("generic_function", "func DoSomething[T any](s []T) {}", "DoSomething"), + ( + "letter_or_underscores", + "func _internal_calculate_v2() {}", + "_internal_calculate_v2", + ), + ( + "receivers_double_pointer_function", + "func (c **Connection) Close() error {}", + "Close", + ), + ( + "receivers_without_the_name_function", + "func (*Point) IsOrigin() bool {}", + "IsOrigin", + ), + ( + "multiline_function", + """ + func (r *Repository[ + T Model, + K KeyType, + ]) FindByID(id K) (*T, error) {} + """, + "FindByID", + ), +] + +NEGATIVE_ANONYMOUS_CASES = [ + ( + "assigned_to_variable", + "var greeter = func(name string) { fmt.Println('Hello,', name) }", + ), + ( + "assigned_to_variable2", + textwrap.dedent( + """ + greet := func() { // Assigning anonymous function to a variable 'greet' + fmt.Println("Greetings from a variable-assigned anonymous function!") + } + """ + ), + ), + ( + "go_routine", + "go func() { fmt.Println('Running in background') }()", + ), + ( + "defer_statement", + "defer func() { file.Close() }()", + ), + ( + "callback_argument", + "http.HandleFunc('/', func(w http.ResponseWriter, r *http.Request) {})", + ), +] + +MALFORMED_INPUT_CASES = [ + ("empty_string", ""), + ("whitespace_only", " \n\t "), + ("just_the_keyword", "func"), + ("incomplete_header", "func myFunc("), + ("garbage_input", "a = b + c;"), +] + +@pytest.mark.parametrize("test_id, code_snippet, expected_name", HAPPY_PATH_CASES) +def test_happy_path_function_names(go_parser, test_id,code_snippet, expected_name): + doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"}) + actual_name = go_parser.get_function_name(doc) + assert actual_name == expected_name, f"Test case '{test_id}' failed" + +@pytest.mark.parametrize("test_id, code_snippet, expected_name", EDGE_CASES_TEST) +def test_edge_cases_function_names(go_parser, test_id, code_snippet, expected_name): + doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"}) + actual_name = go_parser.get_function_name(doc) + assert actual_name == expected_name, f"Test case '{test_id}' failed" + +@pytest.mark.parametrize("test_id, code_snippet", NEGATIVE_ANONYMOUS_CASES) +def test_negative_cases_anonymous_functions(go_parser, test_id, code_snippet): + doc = Document(page_content=code_snippet.strip(), metadata={"source": "proxy.go"}) + name = go_parser.get_function_name(doc) + assert name.startswith("anon_"), ( + f"[{test_id}] Expected name to start with 'anon_', but got '{name}'" + ) + parts = name.split("_") + assert len(parts) == 3, ( + f"[{test_id}] Expected name format 'anon__', but got '{name}'" + ) + + assert parts[1] == "proxy", ( + f"[{test_id}] Expected file prefix 'proxy', but got '{parts[1]}'" + ) + hash_part = parts[2] + assert len(hash_part) == 8, ( + f"[{test_id}] Hash part should be 8 characters, but got '{hash_part}'" + ) + assert all(c in "0123456789abcdef" for c in hash_part), ( + f"[{test_id}] Hash part should be hex, but got '{hash_part}'" + ) + +@pytest.mark.parametrize("test_id, code_snippet", MALFORMED_INPUT_CASES) +def test_malformed_input_graceful_failure(go_parser, test_id, code_snippet): + doc = Document(page_content=code_snippet, metadata={"source": "malformed.go"}) + name = go_parser.get_function_name(doc) + + assert name.startswith("anon_"), ( + f"[{test_id}] Failed to handle malformed input gracefully. Got: {name}" + )