diff --git a/.github/workflows/dec-tests.yml b/.github/workflows/dec-tests.yml index 4355e3a2..b0bd2ef4 100644 --- a/.github/workflows/dec-tests.yml +++ b/.github/workflows/dec-tests.yml @@ -41,19 +41,18 @@ jobs: python ./.github/binja/download_headless.py --serial ${{ env.BN_SERIAL }} --output .github/binja/BinaryNinja-headless.zip unzip .github/binja/BinaryNinja-headless.zip -d .github/binja/ python .github/binja/binaryninja/scripts/install_api.py --install-on-root --silent - - name: Set up Java 17 + - name: Set up Java 21 uses: actions/setup-java@v4 with: distribution: "oracle" - java-version: "17" + java-version: "21" - name: Install Ghidra uses: antoniovazquezblanco/setup-ghidra@v2.0.12 with: - version: "11.1" + version: "12.0" auth_token: ${{ secrets.GITHUB_TOKEN }} - name: Pytest run: | # these two test must be run in separate python environments, due to the way ghidra bridge works # you also must run these tests in the exact order shown here - TEST_BINARIES_DIR=/tmp/bs-artifacts/binaries pytest ./tests/test_remote_ghidra.py -s - TEST_BINARIES_DIR=/tmp/bs-artifacts/binaries pytest ./tests/test_decompilers.py -s \ No newline at end of file + TEST_BINARIES_DIR=/tmp/bs-artifacts/binaries pytest tests/test_decompilers.py tests/test_client_server.py -sv \ No newline at end of file diff --git a/README.md b/README.md index c84d8ecd..e450a011 100644 --- a/README.md +++ b/README.md @@ -13,14 +13,12 @@ pip install libbs ``` The minimum Python version is **3.10**. -**If you plan on using libbs alone (without installing some other plugin), -you must do `libbs --install` after pip install**. This will copy the appropriate files to your decompiler. ## Supported Decompilers - IDA Pro: **>= 8.4** (if you have an older version, use `v1.26.0`) - Binary Ninja: **>= 2.4** - angr-management: **>= 9.0** -- Ghidra: **>= 11.2** +- Ghidra: **>= 12.0** (started in PyGhidra mode) ## Usage LibBS exposes all decompiler API through the abstract class `DecompilerInterface`. The `DecompilerInterface` @@ -45,6 +43,9 @@ for addr in deci.functions: deci.functions[function.addr] = function ``` +Note that for Ghidra in UI mode you must first start it in PyGhidra mode. You can do this by going to your install dir +and running `./support/pyghidraRun`. + ### Headless Mode To use headless mode you must specify a decompiler to use. You can get the traditional interface using the following: @@ -54,8 +55,8 @@ from libbs.api import DecompilerInterface deci = DecompilerInterface.discover(force_decompiler="ghidra", headless=True) ``` -In the case of Ghidra, you must have the environment variable `GHIDRA_HEADLESS_PATH` set to the path of the Ghidra -headless binary. This is usually `ghidraRun` or `ghidraHeadlessAnalyzer`. +In the case of Ghidra, you must have the environment variable `GHIDRA_INSTALL_DIR` set to the path of the Ghidra +installation (the place the `ghidraRun` script is located). ### Artifact Access Caveats In designing the dictionaries that contain all Artifacts in a decompiler, we had a clash between ease-of-use and speed. @@ -85,7 +86,7 @@ loaded_func = Function.loads(json_str, fmt="json") ``` ## Sponsors -BinSync and it's associated projects would not be possible without sponsorship. +BinSync and its associated projects would not be possible without sponsorship. In no particular order, we'd like to thank all the organizations that have previously or are currently sponsoring one of the many BinSync projects. diff --git a/examples/decompiler_client_example.py b/examples/decompiler_client_example.py new file mode 100644 index 00000000..bc5a9cb9 --- /dev/null +++ b/examples/decompiler_client_example.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the RPyC-based DecompilerClient. + +This script shows how to use the new RPyC DecompilerClient which provides +identical API to DecompilerInterface but connects to a remote server. +""" + +import logging +import time +import sys +from typing import Optional + +# Set up logging +logging.basicConfig(level=logging.INFO) + +def example_with_local_decompiler(): + """Example using local DecompilerInterface""" + try: + from libbs.api import DecompilerInterface + + print("=== Using Local DecompilerInterface ===") + deci = DecompilerInterface.discover() + if deci is None: + print("No local decompiler found") + return + + demo_decompiler_operations(deci) + + except Exception as e: + print(f"Local decompiler error: {e}") + + +def example_with_remote_decompiler(server_url: str = "rpyc://localhost:18861"): + """Example using remote DecompilerClient""" + try: + from libbs.api.decompiler_client import DecompilerClient + + print(f"\n=== Using Remote DecompilerClient ({server_url}) ===") + with DecompilerClient.discover(server_url=server_url) as deci: + demo_decompiler_operations(deci) + + except Exception as e: + print(f"Remote decompiler error: {e}") + print("Make sure to start the server first with: libbs --server") + + +def demo_decompiler_operations(deci): + """ + Demo function that works identically with both DecompilerInterface and DecompilerClient. + + This shows the power of the unified API - the same code works regardless of whether + the decompiler is local or remote. + """ + print(f"Decompiler: {deci.name}") + print(f"Binary path: {deci.binary_path}") + print(f"Binary hash: {deci.binary_hash}") + print(f"Base address: 0x{deci.binary_base_addr:x}" if deci.binary_base_addr else "None") + print(f"Decompiler available: {deci.decompiler_available}") + + # Test fast collection operations (this is where RPyC shines) + print(f"\n=== Testing Fast Collection Operations ===") + + # This should be fast - single bulk request for all light artifacts + start_time = time.time() + functions = list(deci.functions.items()) + end_time = time.time() + print(f"Retrieved {len(functions)} functions in {end_time - start_time:.3f}s") + + # Test other collections + collections = [ + ("comments", deci.comments), + ("patches", deci.patches), + ("global_vars", deci.global_vars), + ("structs", deci.structs), + ("enums", deci.enums), + ("typedefs", deci.typedefs) + ] + + for name, collection in collections: + start_time = time.time() + items = list(collection.keys()) + end_time = time.time() + print(f" {name}: {len(items)} items in {end_time - start_time:.3f}s") + + # Test function access (if any functions exist) + if len(deci.functions) > 0: + print(f"\n=== Testing Individual Access ===") + first_addr = functions[0][0] + + # Test full artifact access via __getitem__ (standard behavior) + start_time = time.time() + full_func = deci.functions[first_addr] # This gets the full artifact + end_time = time.time() + print(f"Full artifact access via []: {end_time - start_time:.3f}s") + print(f"Function: {full_func.name} at 0x{full_func.addr:x} (size: {full_func.size})") + + # Test light artifact access (fast, cached) + if hasattr(deci.functions, 'get_light'): + start_time = time.time() + light_func = deci.functions.get_light(first_addr) + end_time = time.time() + print(f"Light artifact access via get_light(): {end_time - start_time:.6f}s") + + # Show first few functions + print("\nFirst 5 functions:") + for addr, func in functions[:5]: + print(f" 0x{addr:x}: {func.name} (size: {func.size})") + + # Test method calls + try: + print(f"\n=== Testing Method Calls ===") + if len(deci.functions) > 0: + first_addr = list(deci.functions.keys())[0] + light_func = deci.fast_get_function(first_addr) + if light_func: + print(f" fast_get_function(0x{first_addr:x}): {light_func.name}") + + func_size = deci.get_func_size(first_addr) + print(f" get_func_size(0x{first_addr:x}): {func_size}") + + # Test decompilation if available + if deci.decompiler_available: + start_time = time.time() + decomp = deci.decompile(first_addr) + end_time = time.time() + if decomp: + lines = decomp.text.split('\n') + print(f" decompile(0x{first_addr:x}): {len(lines)} lines in {end_time - start_time:.3f}s") + print(f" First line: {lines[0][:80]}...") + else: + print(" No functions available for testing") + + except Exception as e: + print(f" Method call error: {e}") + + +def discover_decompiler(prefer_remote: bool = False, server_url: str = "rpyc://localhost:18861"): + """ + Smart discovery function that tries remote first if preferred, then falls back to local. + + This demonstrates how you can write code that seamlessly works with either + local or remote decompilers based on availability. + """ + if prefer_remote: + # Try remote first + try: + from libbs.api.decompiler_client import DecompilerClient + return DecompilerClient.discover(server_url=server_url) + except Exception: + pass + + # Fall back to local + try: + from libbs.api import DecompilerInterface + return DecompilerInterface.discover() + except Exception: + return None + else: + # Try local first + try: + from libbs.api import DecompilerInterface + return DecompilerInterface.discover() + except Exception: + pass + + # Fall back to remote + try: + from libbs.api.decompiler_client import DecompilerClient + return DecompilerClient.discover(server_url=server_url) + except Exception: + return None + + +def main(): + if len(sys.argv) > 1: + server_url = sys.argv[1] + else: + server_url = "rpyc://localhost:18861" + + print("LibBS DecompilerClient Example") + print("==============================") + + # Demo 1: Try local decompiler + example_with_local_decompiler() + + # Demo 2: Try remote decompiler + example_with_remote_decompiler(server_url) + + # Demo 3: Smart discovery + print(f"\n=== Smart Discovery (prefer remote) ===") + deci = discover_decompiler(prefer_remote=True, server_url=server_url) + if deci: + print(f"Discovered: {type(deci).__name__}") + demo_decompiler_operations(deci) + if hasattr(deci, 'shutdown'): + deci.shutdown() + else: + print("No decompiler available (local or remote)") + + print(f"\n=== Smart Discovery (prefer local) ===") + deci = discover_decompiler(prefer_remote=False, server_url=server_url) + if deci: + print(f"Discovered: {type(deci).__name__}") + demo_decompiler_operations(deci) + if hasattr(deci, 'shutdown'): + deci.shutdown() + else: + print("No decompiler available (local or remote)") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/libbs/__init__.py b/libbs/__init__.py index 0ce16f58..a85f86d1 100644 --- a/libbs/__init__.py +++ b/libbs/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.16.5" +__version__ = "3.0.0" import logging diff --git a/libbs/__main__.py b/libbs/__main__.py index 69221177..a85769df 100644 --- a/libbs/__main__.py +++ b/libbs/__main__.py @@ -1,9 +1,6 @@ import argparse import sys import logging -from pathlib import Path -import importlib -import importlib.resources from libbs.plugin_installer import LibBSPluginInstaller @@ -14,15 +11,100 @@ def install(): LibBSPluginInstaller().install() +def start_server(socket_path=None, decompiler=None, binary_path=None, headless=False): + """Start the DecompilerServer (AF_UNIX socket-based)""" + try: + from libbs.api.decompiler_server import DecompilerServer + from libbs.api.decompiler_interface import DecompilerInterface + + # Configure logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Prepare interface kwargs + interface_kwargs = {} + if decompiler: + interface_kwargs['force_decompiler'] = decompiler + if binary_path: + interface_kwargs['binary_path'] = binary_path + if headless: + interface_kwargs['headless'] = headless + + # Create and start server + if socket_path: + _l.info(f"Starting AF_UNIX DecompilerServer on {socket_path}") + else: + _l.info("Starting AF_UNIX DecompilerServer with auto-generated socket path") + if interface_kwargs: + _l.info(f"Interface options: {interface_kwargs}") + + with DecompilerServer(socket_path=socket_path, **interface_kwargs) as server: + _l.info("Server started successfully. Press Ctrl+C to stop.") + _l.info("Connect with: DecompilerClient.discover('unix://{}')".format(server.socket_path)) + try: + server.wait_for_shutdown() + except KeyboardInterrupt: + _l.info("Shutting down server...") + + except ImportError as e: + _l.error(f"Failed to import required modules: {e}") + sys.exit(1) + except Exception as e: + _l.error(f"Failed to start server: {e}") + sys.exit(1) + + +def test_client(server_url=None): + """Test the DecompilerClient connection""" + try: + from libbs.api.decompiler_client import DecompilerClient + + # Configure logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + if server_url: + _l.info(f"Testing connection to DecompilerServer at {server_url}") + else: + _l.info("Testing connection to auto-discovered DecompilerServer") + + with DecompilerClient.discover(server_url=server_url) as client: + _l.info(f"Successfully connected to {client.name} decompiler") + _l.info(f"Binary path: {client.binary_path}") + _l.info(f"Binary hash: {client.binary_hash}") + _l.info(f"Decompiler available: {client.decompiler_available}") + + # Test fast artifact collections (benchmark performance) + import time + start_time = time.time() + functions = list(client.functions.items()) + end_time = time.time() + _l.info(f"Retrieved {len(functions)} functions in {end_time - start_time:.3f}s") + + start_time = time.time() + comments = list(client.comments.keys()) + end_time = time.time() + _l.info(f"Retrieved {len(comments)} comment keys in {end_time - start_time:.3f}s") + + _l.info("✅ Client test completed successfully!") + + except ImportError as e: + _l.error(f"Failed to import required modules: {e}") + sys.exit(1) + except Exception as e: + _l.error(f"Client test failed: {e}") + sys.exit(1) + + def main(): parser = argparse.ArgumentParser( description=""" The LibBS Command Line Util. This is the script interface to LibBS that allows you to install and run - the Ghidra UI for running plugins. + the Ghidra UI for running plugins, and start the DecompilerServer. """, epilog=""" Examples: - libbs --install + libbs --install | + libbs --server --socket-path /tmp/my_server.sock | + libbs --server --decompiler ghidra --binary-path /path/to/binary --headless """ ) parser.add_argument( @@ -33,6 +115,37 @@ def main(): parser.add_argument( "--single-decompiler-install", nargs=2, metavar=('decompiler', 'path'), help="Install DAILA into a single decompiler. Decompiler must be one of: ida, ghidra, binja, angr." ) + parser.add_argument( + "--server", action="store_true", help=""" + Start the DecompilerServer to expose DecompilerInterface APIs over AF_UNIX sockets. + """ + ) + parser.add_argument( + "--server-url", help=""" + URL of the DecompilerServer to connect to (e.g., unix:///tmp/server.sock). + If not specified, will auto-discover running servers. + """ + ) + parser.add_argument( + "--socket-path", help=""" + Path for the AF_UNIX socket (default: auto-generated in temp directory). + """ + ) + parser.add_argument( + "--decompiler", choices=["ida", "ghidra", "binja", "angr"], help=""" + Force a specific decompiler for the server. If not specified, auto-detection will be used. + """ + ) + parser.add_argument( + "--binary-path", help=""" + Path to the binary file to analyze (required for headless mode). + """ + ) + parser.add_argument( + "--headless", action="store_true", help=""" + Run the decompiler in headless mode (no GUI). Requires --binary-path. + """ + ) args = parser.parse_args() if args.single_decompiler_install: @@ -40,6 +153,17 @@ def main(): LibBSPluginInstaller().install(interactive=False, paths_by_target={decompiler: path}) elif args.install: install() + elif args.server: + if args.headless and not args.binary_path: + parser.error("--headless requires --binary-path") + start_server( + socket_path=args.socket_path, + decompiler=args.decompiler, + binary_path=args.binary_path, + headless=args.headless + ) + else: + parser.print_help() if __name__ == "__main__": diff --git a/libbs/api/decompiler_client.py b/libbs/api/decompiler_client.py new file mode 100644 index 00000000..85eeec7a --- /dev/null +++ b/libbs/api/decompiler_client.py @@ -0,0 +1,1128 @@ +# Note to reader: most of this code was generated by Claude 4.5. It may contain errors and was designed +# in tandem with decompiler_server.py and the tests/test_client_server.py file. This comment will be +# removed when the majority of the file is owned by a human author. + +import logging +import socket +import time +import os +import glob +import tempfile +from typing import Dict, Any, Optional, List, Union, Callable +from collections import defaultdict +import threading + +from libbs.artifacts import ( + Artifact, Function, Comment, Patch, GlobalVariable, + Struct, Enum, Typedef, Context, Decompilation +) +from libbs.api.decompiler_server import SocketProtocol +from libbs.api.type_parser import CTypeParser +from libbs.configuration import LibbsConfig + +_l = logging.getLogger(__name__) + + +class ArtLifterProxy: + """ + A proxy for the ArtifactLifter that delegates all operations to the remote server. + This maintains API compatibility with the local ArtifactLifter while sending + requests to the remote decompiler server. + """ + SCOPE_DELIMITER = "::" + + def __init__(self, client: 'DecompilerClient'): + self.client = client + + def lift(self, artifact: Artifact): + """Lift an artifact using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift", + "args": [artifact] + }) + + def lower(self, artifact: Artifact): + """Lower an artifact using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower", + "args": [artifact] + }) + + def lift_addr(self, addr: int) -> int: + """Lift an address using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift_addr", + "args": [addr] + }) + + def lower_addr(self, addr: int) -> int: + """Lower an address using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower_addr", + "args": [addr] + }) + + def lift_type(self, type_str: str) -> str: + """Lift a type string using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift_type", + "args": [type_str] + }) + + def lower_type(self, type_str: str) -> str: + """Lower a type string using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower_type", + "args": [type_str] + }) + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + """Lift a stack offset using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift_stack_offset", + "args": [offset, func_addr] + }) + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + """Lower a stack offset using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower_stack_offset", + "args": [offset, func_addr] + }) + + @staticmethod + def parse_scoped_type(type_str: str) -> tuple[str, str | None]: + """ + Parse a scoped type string into its base type and scope. + This is a static method that doesn't need remote decompiler access. + """ + if not type_str: + return "", None + + # check if the type is scoped + scope = None + deli = ArtLifterProxy.SCOPE_DELIMITER + if deli in type_str: + scope_parts = type_str.split(deli) + base_type = scope_parts[-1] + scope = deli.join(scope_parts[:-1]) + else: + base_type = type_str + + return base_type, scope + + @staticmethod + def scoped_type_to_str(name: str, scope: str | None = None) -> str: + """ + Convert a name and scope into a scoped type string. + This is a static method that doesn't need remote decompiler access. + """ + return name if not scope else f"{scope}::{name}" + + +class FastClientArtifactDict(dict): + """ + A fast client-side proxy for ArtifactDict that communicates with DecompilerServer via AF_UNIX sockets. + + This class mimics the behavior of ArtifactDict but uses sockets for bulk operations + and maintains the same performance characteristics as the local version by using + the _lifted_art_lister pattern. + """ + + def __init__(self, collection_name: str, artifact_class, client: 'DecompilerClient'): + super().__init__() + self.collection_name = collection_name + self.artifact_class = artifact_class + self.client = client + self._light_cache = {} + self._light_cache_timestamp = 0 + self._cache_ttl = 10.0 # Cache for 10 seconds + + def _get_light_artifacts(self) -> Dict: + """Get all light artifacts using the server's fast bulk operation""" + current_time = time.time() + if current_time - self._light_cache_timestamp > self._cache_ttl: + # Cache expired, fetch from server using bulk endpoint + try: + _l.debug(f"Fetching light artifacts for {self.collection_name}") + request = { + "type": "get_light_artifacts", + "collection_name": self.collection_name + } + serialized_artifacts = self.client._send_request(request) + + # Reconstruct artifacts from serialized format + reconstructed_artifacts = {} + for addr, artifact_info in serialized_artifacts.items(): + try: + # Import the artifact class dynamically + module_name = artifact_info['module'] + class_name = artifact_info['type'] + serialized_data = artifact_info['data'] + + # Import the module and get the class + module = __import__(module_name, fromlist=[class_name]) + artifact_class = getattr(module, class_name) + + # Reconstruct the artifact using its loads method + artifact = artifact_class.loads(serialized_data) + reconstructed_artifacts[addr] = artifact + + except Exception as e: + _l.warning(f"Failed to reconstruct artifact at 0x{addr:x}: {e}") + # Skip problematic artifacts rather than failing completely + continue + + self._light_cache = reconstructed_artifacts + self._light_cache_timestamp = current_time + except Exception as e: + _l.warning(f"Failed to fetch light artifacts for {self.collection_name}: {e}") + + return self._light_cache + + def _invalidate_cache(self): + """Invalidate the light artifact cache""" + self._light_cache.clear() + self._light_cache_timestamp = 0 + + def __len__(self): + """Return the number of items in the collection""" + light_items = self._get_light_artifacts() + return len(light_items) + + def __iter__(self): + """Iterate over keys in the collection""" + light_items = self._get_light_artifacts() + return iter(light_items.keys()) + + def keys(self): + """Return an iterator over the keys (fast bulk operation)""" + light_items = self._get_light_artifacts() + return light_items.keys() + + def values(self): + """Return an iterator over the values (light artifacts, fast bulk operation)""" + light_items = self._get_light_artifacts() + return light_items.values() + + def items(self): + """Return an iterator over (key, value) pairs (fast bulk operation)""" + light_items = self._get_light_artifacts() + return light_items.items() + + def __getitem__(self, key): + """Get a full artifact by key (same behavior as local ArtifactDict)""" + # First, check if the key exists by looking at light artifacts + light_items = self._get_light_artifacts() + if key not in light_items: + raise KeyError(f"Key {key} not found in {self.collection_name}") + + # Key exists, get the full artifact from server + try: + request = { + "type": "get_full_artifact", + "collection_name": self.collection_name, + "key": key + } + return self.client._send_request(request) + except Exception as e: + if "not found" in str(e).lower(): + raise KeyError(f"Key {key} not found in {self.collection_name}") + else: + raise + + def get_light(self, key): + """Get a light artifact by key (fast, cached access)""" + light_items = self._get_light_artifacts() + if key not in light_items: + raise KeyError(f"Key {key} not found in {self.collection_name}") + return light_items[key] + + def get_full(self, key): + """Explicitly get a full artifact (with expensive operations like decompilation)""" + try: + request = { + "type": "get_full_artifact", + "collection_name": self.collection_name, + "key": key + } + return self.client._send_request(request) + except Exception as e: + if "not found" in str(e).lower(): + raise KeyError(f"Key {key} not found in {self.collection_name}") + else: + raise + + def __setitem__(self, key, value): + """Set an artifact by key on the server""" + if not isinstance(value, self.artifact_class): + raise ValueError(f"Expected {self.artifact_class.__name__}, got {type(value).__name__}") + + # Use the direct decompiler interface for setting artifacts + success = self.client.set_artifact(value) + + # Invalidate cache since we modified the collection + self._invalidate_cache() + + if not success: + raise RuntimeError(f"Failed to set artifact") + + def __delitem__(self, key): + """Delete an artifact by key (not implemented in decompiler interfaces)""" + raise NotImplementedError("Deletion not supported by DecompilerInterface") + + def __contains__(self, key): + """Check if a key exists in the collection""" + light_items = self._get_light_artifacts() + return key in light_items + + def get(self, key, default=None): + """Get a full artifact with a default value""" + try: + return self[key] # Use __getitem__ which returns full artifact + except KeyError: + return default + + +class DecompilerClient: + """ + A client that connects to DecompilerServer via AF_UNIX sockets and provides the same interface as DecompilerInterface. + + This class acts as a transparent proxy to a remote DecompilerInterface, allowing users to + write code that works identically whether using a local or remote decompiler. + """ + + def __init__(self, + socket_path: str, + timeout: float = 30.0): + """ + Initialize the DecompilerClient. + + Args: + socket_path: Path to the AF_UNIX socket + timeout: Connection timeout in seconds + """ + self.socket_path = socket_path + self.timeout = timeout + + # Connection state + self._socket = None + self._connected = False + self._server_info = None + self._socket_lock = threading.Lock() + + # Try to connect + self._connect() + + # Initialize fast artifact collections + self.functions = FastClientArtifactDict("functions", Function, self) + self.comments = FastClientArtifactDict("comments", Comment, self) + self.patches = FastClientArtifactDict("patches", Patch, self) + self.global_vars = FastClientArtifactDict("global_vars", GlobalVariable, self) + self.structs = FastClientArtifactDict("structs", Struct, self) + self.enums = FastClientArtifactDict("enums", Enum, self) + self.typedefs = FastClientArtifactDict("typedefs", Typedef, self) + + # Initialize callback attributes to match DecompilerInterface + self.artifact_change_callbacks = defaultdict(list) + self.decompiler_closed_callbacks = [] + self.decompiler_opened_callbacks = [] + self.undo_event_callbacks = [] + self._thread_artifact_callbacks = True + + # Create a proxy art_lifter that delegates to server + # art_lifter is typically used for address lifting operations + self.art_lifter = ArtLifterProxy(self) + + # Additional public attributes to match DecompilerInterface + self.type_parser = CTypeParser() # Local type parser + self.artifact_write_lock = threading.Lock() # Thread safety lock + self.config = LibbsConfig.update_or_make() # Configuration object + self.gui_plugin = None # GUI plugin reference + self.artifact_watchers_started = False # Watcher state + + # Event listener state for receiving callbacks from server + self._event_listener_running = False + self._subscribed_to_events = False + self._event_listener_thread = None + self._event_socket = None + self._event_socket_lock = threading.Lock() + + # These attributes will be fetched from server on first access + self._supports_undo = None + self._supports_type_scopes = None + self._qt_version = None + self._default_func_prefix = None + self._headless = None + self._force_click_recording = None + self._track_mouse_moves = None + + _l.info(f"DecompilerClient connected to {socket_path}") + + def _connect(self): + """Establish connection to the server""" + try: + _l.debug(f"Attempting to connect to AF_UNIX socket at {self.socket_path}") + + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._socket.settimeout(self.timeout) + self._socket.connect(self.socket_path) + + _l.debug("Socket connection established") + + # Test the connection by getting server info first + self._server_info = self._send_request({"type": "server_info"}) + _l.debug(f"Got server info: {self._server_info}") + + self._connected = True + + _l.info(f"Connected to {self._server_info.get('name', 'DecompilerServer')} " + f"using {self._server_info.get('decompiler', 'unknown')} decompiler") + except Exception as e: + _l.error(f"Failed to connect to DecompilerServer at {self.socket_path}: {e}") + + # Provide helpful error messages for common issues + if "No such file or directory" in str(e): + raise ConnectionError(f"Cannot connect to DecompilerServer at {self.socket_path}. " + f"Make sure the server is running with: libbs --server") + elif "Connection refused" in str(e): + raise ConnectionError(f"Cannot connect to DecompilerServer at {self.socket_path}. " + f"Make sure the server is running.") + else: + raise ConnectionError(f"Cannot connect to DecompilerServer: {e}") + + def _send_request(self, request: Dict[str, Any]) -> Any: + """Send a request to the server and return the response""" + with self._socket_lock: + try: + SocketProtocol.send_message(self._socket, request) + response = SocketProtocol.recv_message(self._socket) + + # Check if response is an error + if isinstance(response, dict) and "error" in response: + error_type = response.get("type", "Exception") + error_msg = response.get("error", "Unknown error") + + # Try to reconstruct the original exception type + if error_type == "KeyError": + raise KeyError(error_msg) + elif error_type == "ValueError": + raise ValueError(error_msg) + elif error_type == "AttributeError": + raise AttributeError(error_msg) + else: + raise RuntimeError(f"{error_type}: {error_msg}") + + # Check if response is a serialized artifact + if isinstance(response, dict) and response.get("is_artifact"): + try: + # Reconstruct the artifact + module_name = response['module'] + class_name = response['type'] + serialized_data = response['data'] + + # Import the module and get the class + module = __import__(module_name, fromlist=[class_name]) + artifact_class = getattr(module, class_name) + + # Reconstruct the artifact using its loads method + artifact = artifact_class.loads(serialized_data) + return artifact + + except Exception as e: + _l.warning(f"Failed to reconstruct artifact response: {e}") + # Fall back to returning the raw response + return response + + return response + except Exception as e: + _l.error(f"Request failed: {e} for {request}") + raise + + # Properties - mirror DecompilerInterface properties + @property + def name(self) -> str: + """Name of the decompiler""" + return self._server_info.get('decompiler', 'remote') + + @property + def binary_base_addr(self) -> int: + """Base address of the binary""" + return self._send_request({"type": "property_get", "property_name": "binary_base_addr"}) + + @property + def binary_hash(self) -> str: + """Hash of the binary""" + return self._send_request({"type": "property_get", "property_name": "binary_hash"}) + + @property + def binary_path(self) -> Optional[str]: + """Path to the binary""" + return self._send_request({"type": "property_get", "property_name": "binary_path"}) + + @property + def decompiler_available(self) -> bool: + """Whether decompiler is available""" + return self._send_request({"type": "property_get", "property_name": "decompiler_available"}) + + @property + def default_pointer_size(self) -> int: + """Default pointer size""" + return self._send_request({"type": "property_get", "property_name": "default_pointer_size"}) + + # GUI API methods - delegate to remote decompiler + def gui_active_context(self) -> Optional[Context]: + """Get the active context from the GUI""" + return self._send_request({"type": "method_call", "method_name": "gui_active_context"}) + + def gui_goto(self, func_addr) -> None: + """Go to an address in the GUI""" + return self._send_request({"type": "method_call", "method_name": "gui_goto", "args": [func_addr]}) + + def gui_show_type(self, type_name: str) -> None: + """Show a type in the GUI""" + return self._send_request({"type": "method_call", "method_name": "gui_show_type", "args": [type_name]}) + + def gui_ask_for_string(self, question: str, title: str = "Plugin Question") -> str: + """Ask for a string input""" + return self._send_request({"type": "method_call", "method_name": "gui_ask_for_string", "args": [question, title]}) + + def gui_ask_for_choice(self, question: str, choices: list, title: str = "Plugin Question") -> str: + """Ask for a choice from a list""" + return self._send_request({"type": "method_call", "method_name": "gui_ask_for_choice", "args": [question, choices, title]}) + + def gui_popup_text(self, text: str, title: str = "Plugin Message") -> bool: + """Show a popup message""" + return self._send_request({"type": "method_call", "method_name": "gui_popup_text", "args": [text, title]}) + + # Core decompiler API methods - delegate to remote decompiler + def fast_get_function(self, func_addr) -> Optional[Function]: + """Get a light version of a function""" + return self._send_request({"type": "method_call", "method_name": "fast_get_function", "args": [func_addr]}) + + def get_func_size(self, func_addr) -> int: + """Get the size of a function""" + return self._send_request({"type": "method_call", "method_name": "get_func_size", "args": [func_addr]}) + + def decompile(self, addr: int, map_lines=False, **kwargs) -> Optional[Decompilation]: + """Decompile a function""" + return self._send_request({"type": "method_call", "method_name": "decompile", "args": [addr], "kwargs": {"map_lines": map_lines, **kwargs}}) + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + """Get cross-references to an artifact""" + return self._send_request({"type": "method_call", "method_name": "xrefs_to", "args": [artifact], "kwargs": {"decompile": decompile, "only_code": only_code}}) + + def get_callgraph(self, only_names=False): + """Get the call graph""" + return self._send_request({"type": "method_call", "method_name": "get_callgraph", "kwargs": {"only_names": only_names}}) + + def get_dependencies(self, artifact: Artifact, decompile=True, max_resolves=50, **kwargs) -> List[Artifact]: + """Get dependencies for an artifact""" + return self._send_request({"type": "method_call", "method_name": "get_dependencies", "args": [artifact], "kwargs": {"decompile": decompile, "max_resolves": max_resolves, **kwargs}}) + + def get_func_containing(self, addr: int) -> Optional[Function]: + """Get the function containing an address""" + return self._send_request({"type": "method_call", "method_name": "get_func_containing", "args": [addr]}) + + def get_decompilation_object(self, function: Function, **kwargs): + """Get the decompilation object for a function""" + return self._send_request({"type": "method_call", "method_name": "get_decompilation_object", "args": [function], "kwargs": kwargs}) + + def set_artifact(self, artifact: Artifact, lower=True, **kwargs) -> bool: + """Set an artifact in the decompiler""" + return self._send_request({"type": "method_call", "method_name": "set_artifact", "args": [artifact], "kwargs": {"lower": lower, **kwargs}}) + + def get_defined_type(self, type_str: str): + """Get a defined type by string""" + return self._send_request({"type": "method_call", "method_name": "get_defined_type", "args": [type_str]}) + + # Optional API methods - delegate to remote decompiler + def undo(self) -> None: + """Undo the last operation""" + return self._send_request({"type": "method_call", "method_name": "undo"}) + + def local_variable_names(self, func: Function) -> List[str]: + """Get local variable names for a function""" + return self._send_request({"type": "method_call", "method_name": "local_variable_names", "args": [func]}) + + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + """Rename local variables by name map""" + return self._send_request({"type": "method_call", "method_name": "rename_local_variables_by_names", "args": [func, name_map], "kwargs": kwargs}) + + # Logging methods - delegate to remote decompiler + def print(self, msg: str, **kwargs) -> None: + """Print a message""" + return self._send_request({"type": "method_call", "method_name": "print", "args": [msg], "kwargs": kwargs}) + + def info(self, msg: str, **kwargs) -> None: + """Log an info message""" + return self._send_request({"type": "method_call", "method_name": "info", "args": [msg], "kwargs": kwargs}) + + def debug(self, msg: str, **kwargs) -> None: + """Log a debug message""" + return self._send_request({"type": "method_call", "method_name": "debug", "args": [msg], "kwargs": kwargs}) + + def warning(self, msg: str, **kwargs) -> None: + """Log a warning message""" + return self._send_request({"type": "method_call", "method_name": "warning", "args": [msg], "kwargs": kwargs}) + + def error(self, msg: str, **kwargs) -> None: + """Log an error message""" + return self._send_request({"type": "method_call", "method_name": "error", "args": [msg], "kwargs": kwargs}) + + def _start_event_listener(self) -> None: + """Start the event listener thread to receive callbacks from server""" + if self._event_listener_running: + _l.debug("Event listener already running") + return + + _l.debug("Starting event listener") + + # Create a separate socket connection for receiving events + try: + self._event_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._event_socket.settimeout(self.timeout) + self._event_socket.connect(self.socket_path) + + # Send subscription request to server + SocketProtocol.send_message(self._event_socket, {"type": "subscribe_events"}) + response = SocketProtocol.recv_message(self._event_socket) + + if response.get("status") == "subscribed": + self._subscribed_to_events = True + _l.debug("Successfully subscribed to events") + + # Start event listener thread + self._event_listener_running = True + self._event_listener_thread = threading.Thread( + target=self._event_listener_loop, + daemon=True + ) + self._event_listener_thread.start() + _l.info("Event listener started") + else: + _l.error(f"Failed to subscribe to events: {response}") + self._event_socket.close() + self._event_socket = None + + except Exception as e: + _l.error(f"Failed to start event listener: {e}") + if self._event_socket: + self._event_socket.close() + self._event_socket = None + + def _stop_event_listener(self) -> None: + """Stop the event listener thread""" + if not self._event_listener_running: + _l.debug("Event listener not running") + return + + _l.debug("Stopping event listener") + self._event_listener_running = False + + # Send unsubscribe request + if self._event_socket and self._subscribed_to_events: + try: + SocketProtocol.send_message(self._event_socket, {"type": "unsubscribe_events"}) + except: + pass + + # Close event socket + if self._event_socket: + try: + self._event_socket.close() + except: + pass + self._event_socket = None + + # Wait for thread to finish + if self._event_listener_thread and self._event_listener_thread.is_alive(): + self._event_listener_thread.join(timeout=2.0) + + self._subscribed_to_events = False + _l.info("Event listener stopped") + + def _event_listener_loop(self) -> None: + """Event listener thread loop that receives events from server""" + _l.debug("Event listener loop started") + + try: + while self._event_listener_running: + try: + # Set a timeout so we can periodically check if we should stop + self._event_socket.settimeout(1.0) + event = SocketProtocol.recv_message(self._event_socket) + + # Process the event + self._process_event(event) + + except socket.timeout: + # Normal timeout, continue loop + continue + except ConnectionError as e: + _l.warning(f"Event listener connection error: {e}") + break + except Exception as e: + _l.error(f"Error in event listener loop: {e}") + break + + except Exception as e: + _l.error(f"Fatal error in event listener loop: {e}") + finally: + _l.debug("Event listener loop ended") + self._event_listener_running = False + + def _process_event(self, event: Dict[str, Any]) -> None: + """Process an event received from the server""" + try: + event_type = event.get("event_type") + artifact_data = event.get("artifact") + + if not event_type or not artifact_data: + _l.warning(f"Invalid event received: {event}") + return + + # Reconstruct the artifact from serialized data + if isinstance(artifact_data, dict) and artifact_data.get("is_artifact"): + module_name = artifact_data['module'] + class_name = artifact_data['type'] + serialized_data = artifact_data['data'] + + # Import the module and get the class + module = __import__(module_name, fromlist=[class_name]) + artifact_class = getattr(module, class_name) + + # Reconstruct the artifact + artifact = artifact_class.loads(serialized_data) + + # Extract additional kwargs + kwargs = event.get("kwargs", {}) + + # Dispatch to appropriate handler based on event type + if event_type == "comment_changed": + self.comment_changed(artifact, **kwargs) + elif event_type == "function_header_changed": + self.function_header_changed(artifact, **kwargs) + elif event_type == "stack_variable_changed": + self.stack_variable_changed(artifact, **kwargs) + elif event_type == "struct_changed": + self.struct_changed(artifact, **kwargs) + elif event_type == "enum_changed": + self.enum_changed(artifact, **kwargs) + elif event_type == "typedef_changed": + self.typedef_changed(artifact, **kwargs) + elif event_type == "global_variable_changed": + self.global_variable_changed(artifact, **kwargs) + else: + _l.warning(f"Unknown event type: {event_type}") + + except Exception as e: + _l.error(f"Error processing event: {e}") + + # Lifecycle methods + def shutdown(self) -> None: + """Shutdown the client""" + _l.info("DecompilerClient shutting down") + + # Stop event listener first + if self._event_listener_running: + self._stop_event_listener() + + if self._socket: + try: + # Send shutdown request to server + self._send_request({"type": "shutdown_deci"}) + except: + pass + self._socket.close() + self._connected = False + _l.info("DecompilerClient shut down complete") + + def is_connected(self) -> bool: + """Check if connected to the server""" + return self._connected and self._socket + + def reconnect(self) -> None: + """Reconnect to the server""" + if self._socket: + self._socket.close() + self._connect() + + def ping(self) -> bool: + """Ping the server to check connectivity""" + try: + self._send_request({"type": "server_info"}) + return True + except Exception: + return False + + # Context manager support + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + # Static methods for compatibility + @staticmethod + def discover(server_url: str = None, binary_hash: str = None, **kwargs) -> 'DecompilerClient': + """ + Discover and connect to a DecompilerServer. + + This method provides a similar interface to DecompilerInterface.discover() + but connects to a remote server instead. It intelligently handles: + - Stale socket files from previous server instances + - Multiple running servers + - Binary hash matching to connect to the correct server + + Args: + server_url: URL of the server (e.g., "unix:///tmp/libbs_server_abc123/decompiler.sock") + binary_hash: Optional binary hash to match against server's binary_hash + **kwargs: Additional arguments for DecompilerClient constructor + + Returns: + Connected DecompilerClient instance + + Raises: + ConnectionError: If no suitable server is found or connection fails + """ + if server_url: + # Parse server URL + if "://" in server_url: + protocol, path = server_url.split("://", 1) + if protocol != "unix": + _l.warning(f"Expected unix:// protocol, got {protocol}://") + socket_path = path + else: + # Assume it's a direct path + socket_path = server_url + + # If binary_hash is provided, validate it matches + if binary_hash: + try: + client = DecompilerClient(socket_path=socket_path, **kwargs) + server_hash = client.binary_hash + if server_hash != binary_hash: + client.shutdown() + raise ConnectionError( + f"Server at {socket_path} has binary_hash={server_hash}, " + f"but expected {binary_hash}" + ) + return client + except Exception as e: + raise ConnectionError(f"Failed to connect to server at {socket_path}: {e}") + else: + return DecompilerClient(socket_path=socket_path, **kwargs) + else: + # Auto-discovery: find all socket files and try to connect to each + temp_dir = tempfile.gettempdir() + pattern = os.path.join(temp_dir, "libbs_server_*/decompiler.sock") + matches = glob.glob(pattern) + + if not matches: + raise ConnectionError("No DecompilerServer found. Start one with: libbs --server") + + # Sort by modification time (newest first) to prefer recently started servers + matches.sort(key=lambda p: os.path.getmtime(p), reverse=True) + + _l.debug(f"Found {len(matches)} potential server socket(s)") + + # Try each socket, filtering by binary_hash if provided + successful_connections = [] + for socket_path in matches: + try: + _l.debug(f"Attempting connection to {socket_path}") + test_client = DecompilerClient(socket_path=socket_path, **kwargs) + + # Successfully connected, now check binary_hash if needed + if binary_hash: + try: + server_hash = test_client.binary_hash + if server_hash == binary_hash: + _l.info(f"Auto-discovered server at {socket_path} with matching binary_hash") + return test_client + else: + _l.debug(f"Server at {socket_path} has binary_hash={server_hash}, skipping") + test_client.shutdown() + except Exception as e: + _l.debug(f"Failed to get binary_hash from {socket_path}: {e}") + test_client.shutdown() + else: + # No binary_hash filter, use the first working server + _l.info(f"Auto-discovered server at {socket_path}") + return test_client + + except ConnectionError as e: + # This socket is defunct (server stopped), skip it + _l.debug(f"Failed to connect to {socket_path}: {e}") + continue + except Exception as e: + _l.debug(f"Unexpected error connecting to {socket_path}: {e}") + continue + + # No suitable server found + if binary_hash: + raise ConnectionError( + f"No DecompilerServer found with binary_hash={binary_hash}. " + f"Found {len(matches)} socket(s) but none matched." + ) + else: + raise ConnectionError( + f"No working DecompilerServer found. Found {len(matches)} socket(s) " + f"but all connections failed. Start a new server with: libbs --server" + ) + + # Properties that fetch values from server on first access + @property + def supports_undo(self) -> bool: + """Check if the decompiler supports undo operations""" + if self._supports_undo is None: + self._supports_undo = self._send_request({"type": "property_get", "property_name": "supports_undo"}) + return self._supports_undo + + @property + def supports_type_scopes(self) -> bool: + """Check if the decompiler supports type scopes""" + if self._supports_type_scopes is None: + self._supports_type_scopes = self._send_request({"type": "property_get", "property_name": "supports_type_scopes"}) + return self._supports_type_scopes + + @property + def qt_version(self) -> str: + """Get the Qt version used by the decompiler""" + if self._qt_version is None: + self._qt_version = self._send_request({"type": "property_get", "property_name": "qt_version"}) + return self._qt_version + + @property + def default_func_prefix(self) -> str: + """Get the default function prefix used by the decompiler""" + if self._default_func_prefix is None: + self._default_func_prefix = self._send_request({"type": "property_get", "property_name": "default_func_prefix"}) + return self._default_func_prefix + + @property + def headless(self) -> bool: + """Check if the decompiler is running in headless mode""" + if self._headless is None: + self._headless = self._send_request({"type": "property_get", "property_name": "headless"}) + return self._headless + + @property + def force_click_recording(self) -> bool: + """Check if click recording is forced""" + if self._force_click_recording is None: + self._force_click_recording = self._send_request({"type": "property_get", "property_name": "force_click_recording"}) + return self._force_click_recording + + @property + def track_mouse_moves(self) -> bool: + """Check if mouse moves are tracked""" + if self._track_mouse_moves is None: + self._track_mouse_moves = self._send_request({"type": "property_get", "property_name": "track_mouse_moves"}) + return self._track_mouse_moves + + @property + def default_pointer_size(self) -> int: + """Get default pointer size""" + return self._send_request({"type": "property_get", "property_name": "default_pointer_size"}) + + # Artifact watcher methods + def start_artifact_watchers(self) -> None: + """Start artifact watchers on the remote decompiler""" + result = self._send_request({"type": "method_call", "method_name": "start_artifact_watchers"}) + self.artifact_watchers_started = True + + # Start event listener to receive callbacks from server + self._start_event_listener() + + return result + + def stop_artifact_watchers(self) -> None: + """Stop artifact watchers on the remote decompiler""" + # Stop event listener first + self._stop_event_listener() + + result = self._send_request({"type": "method_call", "method_name": "stop_artifact_watchers"}) + self.artifact_watchers_started = False + return result + + def should_watch_artifacts(self) -> bool: + """Check if artifacts should be watched""" + return self._send_request({"type": "method_call", "method_name": "should_watch_artifacts"}) + + # GUI registration methods (stubs since we can't proxy GUI operations) + def gui_register_ctx_menu(self, name: str, action_string: str, callback_func: Callable, category=None) -> bool: + """Register a context menu item (not supported in remote mode)""" + _l.warning("GUI context menu registration is not supported in remote decompiler mode") + return False + + def gui_register_ctx_menu_many(self, actions: dict) -> None: + """Register multiple context menu items (not supported in remote mode)""" + _l.warning("GUI context menu registration is not supported in remote decompiler mode") + + def gui_run_on_main_thread(self, func: Callable, *args, **kwargs): + """Run function on main thread (not supported in remote mode)""" + _l.warning("GUI main thread operations are not supported in remote decompiler mode") + raise NotImplementedError("GUI main thread operations not supported in remote mode") + + def gui_attach_qt_window(self, qt_window, title: str, target_window=None, position=None, *args, **kwargs) -> bool: + """Attach Qt window (not supported in remote mode)""" + _l.warning("GUI window attachment is not supported in remote decompiler mode") + return False + + # Event callback methods (these trigger callbacks locally but don't send to server) + def decompiler_opened_event(self, **kwargs): + """Handle decompiler opened event""" + for callback in self.decompiler_opened_callbacks: + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, kwargs=kwargs) + thread.start() + else: + callback(**kwargs) + except Exception as e: + _l.error(f"Error in decompiler opened callback: {e}") + + def decompiler_closed_event(self, **kwargs): + """Handle decompiler closed event""" + for callback in self.decompiler_closed_callbacks: + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, kwargs=kwargs) + thread.start() + else: + callback(**kwargs) + except Exception as e: + _l.error(f"Error in decompiler closed callback: {e}") + + def gui_undo_event(self, **kwargs): + """Handle GUI undo event""" + for callback in self.undo_event_callbacks: + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, kwargs=kwargs) + thread.start() + else: + callback(**kwargs) + except Exception as e: + _l.error(f"Error in undo event callback: {e}") + + def gui_context_changed(self, ctx: Context, **kwargs) -> Context: + """Handle GUI context changed event""" + # This would typically be handled by GUI callbacks locally + return ctx + + # Artifact change event methods (these handle local callbacks) + def function_header_changed(self, fheader, **kwargs): + """Handle function header changed event""" + for callback in self.artifact_change_callbacks.get(type(fheader), []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(fheader,), kwargs=kwargs) + thread.start() + else: + callback(fheader, **kwargs) + except Exception as e: + _l.error(f"Error in function header change callback: {e}") + return fheader + + def stack_variable_changed(self, svar, **kwargs): + """Handle stack variable changed event""" + for callback in self.artifact_change_callbacks.get(type(svar), []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(svar,), kwargs=kwargs) + thread.start() + else: + callback(svar, **kwargs) + except Exception as e: + _l.error(f"Error in stack variable change callback: {e}") + return svar + + def comment_changed(self, comment: Comment, deleted=False, **kwargs) -> Comment: + """Handle comment changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Comment, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(comment,), kwargs=kwargs) + thread.start() + else: + callback(comment, **kwargs) + except Exception as e: + _l.error(f"Error in comment change callback: {e}") + return comment + + def struct_changed(self, struct: Struct, deleted=False, **kwargs) -> Struct: + """Handle struct changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Struct, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(struct,), kwargs=kwargs) + thread.start() + else: + callback(struct, **kwargs) + except Exception as e: + _l.error(f"Error in struct change callback: {e}") + return struct + + def enum_changed(self, enum: Enum, deleted=False, **kwargs) -> Enum: + """Handle enum changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Enum, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(enum,), kwargs=kwargs) + thread.start() + else: + callback(enum, **kwargs) + except Exception as e: + _l.error(f"Error in enum change callback: {e}") + return enum + + def typedef_changed(self, typedef: Typedef, deleted=False, **kwargs) -> Typedef: + """Handle typedef changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Typedef, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(typedef,), kwargs=kwargs) + thread.start() + else: + callback(typedef, **kwargs) + except Exception as e: + _l.error(f"Error in typedef change callback: {e}") + return typedef + + def global_variable_changed(self, gvar: GlobalVariable, **kwargs) -> GlobalVariable: + """Handle global variable changed event""" + for callback in self.artifact_change_callbacks.get(GlobalVariable, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(gvar,), kwargs=kwargs) + thread.start() + else: + callback(gvar, **kwargs) + except Exception as e: + _l.error(f"Error in global variable change callback: {e}") + return gvar \ No newline at end of file diff --git a/libbs/api/decompiler_interface.py b/libbs/api/decompiler_interface.py index daa5ab62..a592b3f8 100644 --- a/libbs/api/decompiler_interface.py +++ b/libbs/api/decompiler_interface.py @@ -991,6 +991,13 @@ def find_current_decompiler(force: str = None) -> Optional[str]: if "License is not valid" in str(e): _l.warning("Binary Ninja license is invalid, skipping...") + # Ghidra + this_obj = DecompilerInterface._find_global_in_call_frames("__this__") + if (this_obj is not None) and (hasattr(this_obj, "currentProgram")): + available.add(GHIDRA_DECOMPILER) + if not force: + return GHIDRA_DECOMPILER + # angr-management try: import angr @@ -1015,25 +1022,6 @@ def find_current_decompiler(force: str = None) -> Optional[str]: except Exception: pass - # Ghidra - # It is always available, and we don't have an import check, because when started in headless mode we create - # the interface by which ghidra can now be imported. - available.add(GHIDRA_DECOMPILER) - import socket - from libbs.decompiler_stubs.ghidra_libbs.libbs_vendored.ghidra_bridge_port import DEFAULT_SERVER_PORT - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) # 2 Second Timeout - try: - if sock.connect_ex(('127.0.0.1', DEFAULT_SERVER_PORT)) == 0: - if not force: - return GHIDRA_DECOMPILER - except ConnectionError: - pass - this_obj = DecompilerInterface._find_global_in_call_frames("__this__") - if (this_obj is not None) and (hasattr(this_obj, "currentProgram")): - if not force: - return GHIDRA_DECOMPILER - if not available: _l.critical("LibBS was unable to find the current decompiler you are running in or any headless instances!") return None diff --git a/libbs/api/decompiler_server.py b/libbs/api/decompiler_server.py new file mode 100644 index 00000000..9d1482f4 --- /dev/null +++ b/libbs/api/decompiler_server.py @@ -0,0 +1,546 @@ +# Note to reader: most of this code was generated by Claude 4.5. It may contain errors and was designed +# in tandem with decompiler_client.py and the tests/test_client_server.py file. This comment will be +# removed when the majority of the file is owned by a human author. + +import logging +import pickle +import socket +import struct +import threading +import time +import tempfile +import os +from typing import Optional, Dict, Any, List + +from libbs.api.decompiler_interface import DecompilerInterface + +_l = logging.getLogger(__name__) + + +class SocketProtocol: + """Helper class for socket protocol message framing""" + + @staticmethod + def send_message(sock: socket.socket, data: Any) -> None: + """Send a pickled message with length prefix""" + try: + pickled_data = pickle.dumps(data) + msg_len = len(pickled_data) + + # Send 4-byte length prefix + sock.sendall(struct.pack('!I', msg_len)) + # Send pickled data + sock.sendall(pickled_data) + except (ConnectionError, BrokenPipeError, OSError) as e: + # Expected during shutdown when socket is closed, just re-raise + raise + except Exception as e: + # Unexpected error - log it + _l.error(f"Failed to send message (pickle.dumps): {e}") + _l.error(f"Data type: {type(data)}") + if hasattr(data, '__dict__'): + _l.error(f"Data dict: {data.__dict__}") + raise + + @staticmethod + def recv_message(sock: socket.socket) -> Any: + """Receive a pickled message with length prefix""" + pickled_data = b'' + try: + # Receive 4-byte length prefix + len_data = sock.recv(4) + if len(len_data) != 4: + raise ConnectionError("Failed to receive message length") + + msg_len = struct.unpack('!I', len_data)[0] + + # Receive the pickled data + while len(pickled_data) < msg_len: + chunk = sock.recv(msg_len - len(pickled_data)) + if not chunk: + raise ConnectionError("Connection closed while receiving message") + pickled_data += chunk + + return pickle.loads(pickled_data) + except (ConnectionError, socket.timeout): + # Expected during shutdown or normal timeout, just re-raise without logging + raise + except Exception as e: + # Unexpected error - log it + _l.error(f"Failed to receive message (pickle.loads): {e}") + if pickled_data: + _l.error(f"Received {len(pickled_data)} bytes of pickle data") + raise + + +class SocketServerHandler: + """Handler for individual client connections""" + + def __init__(self, deci: DecompilerInterface, server: 'DecompilerServer' = None): + self.deci = deci + self.server = server + self._light_caches = {} + self._cache_lock = threading.Lock() + self._cache_ttl = 10.0 + + def handle_client(self, client_socket: socket.socket, addr: str): + """Handle a client connection""" + _l.info(f"Client connected: {addr}") + + try: + while True: + try: + request = SocketProtocol.recv_message(client_socket) + response = self._process_request(request, client_socket=client_socket) + SocketProtocol.send_message(client_socket, response) + except ConnectionError: + # Client disconnected + break + except Exception as e: + # Send error response + error_response = {"error": str(e), "type": type(e).__name__} + try: + SocketProtocol.send_message(client_socket, error_response) + except: + break + finally: + # Remove from event subscribers if subscribed + if self.server: + with self.server._event_subscribers_lock: + if client_socket in self.server._event_subscribers: + self.server._event_subscribers.remove(client_socket) + _l.debug("Removed client from event subscribers") + + client_socket.close() + _l.info(f"Client disconnected: {addr}") + + def _process_request(self, request: Dict[str, Any], client_socket: socket.socket = None) -> Any: + """Process a client request and return response""" + request_type = request.get("type") + + if request_type == "subscribe_events": + # Client wants to subscribe to artifact change events + if self.server and client_socket: + with self.server._event_subscribers_lock: + if client_socket not in self.server._event_subscribers: + self.server._event_subscribers.append(client_socket) + _l.info(f"Client subscribed to events (total subscribers: {len(self.server._event_subscribers)})") + return {"status": "subscribed"} + else: + return {"status": "error", "message": "Server not available"} + + elif request_type == "unsubscribe_events": + # Client wants to unsubscribe from events + if self.server and client_socket: + with self.server._event_subscribers_lock: + if client_socket in self.server._event_subscribers: + self.server._event_subscribers.remove(client_socket) + _l.info(f"Client unsubscribed from events (total subscribers: {len(self.server._event_subscribers)})") + return {"status": "unsubscribed"} + else: + return {"status": "error", "message": "Server not available"} + + elif request_type == "server_info": + return { + "name": "LibBS DecompilerServer (AF_UNIX)", + "version": "3.0.0", + "decompiler": self.deci.name if self.deci else "unknown", + "protocol": "unix_socket", + "binary_hash": self.deci.binary_hash if self.deci else None + } + + elif request_type == "get_light_artifacts": + collection_name = request.get("collection_name") + return self._get_light_artifacts(collection_name) + + elif request_type == "get_full_artifact": + collection_name = request.get("collection_name") + key = request.get("key") + collection = getattr(self.deci, collection_name) + artifact = collection[key] + + # Serialize the full artifact safely + if hasattr(artifact, 'dumps') and hasattr(artifact, '__class__'): + try: + return { + 'type': artifact.__class__.__name__, + 'module': artifact.__class__.__module__, + 'data': artifact.dumps(), + 'is_artifact': True + } + except Exception as e: + _l.warning(f"Failed to serialize full artifact: {e}") + # Fall back to direct return, which might fail with pickle + return artifact + else: + return artifact + + elif request_type == "method_call": + method_name = request.get("method_name") + args = request.get("args", []) + kwargs = request.get("kwargs", {}) + + # Handle dotted method names like "art_lifter.lift" + if "." in method_name: + obj = self.deci + for attr in method_name.split("."): + obj = getattr(obj, attr) + method = obj + else: + # Get the method from the decompiler interface + method = getattr(self.deci, method_name) + result = method(*args, **kwargs) + + # Check if result is an artifact and serialize it properly + if hasattr(result, 'dumps') and hasattr(result, '__class__'): + # This looks like an artifact, serialize it safely + try: + return { + 'type': result.__class__.__name__, + 'module': result.__class__.__module__, + 'data': result.dumps(), + 'is_artifact': True + } + except Exception as e: + _l.warning(f"Failed to serialize result artifact: {e}") + # Fall back to direct return, which might fail with pickle + return result + else: + # Not an artifact, return as-is + return result + + elif request_type == "property_get": + property_name = request.get("property_name") + return getattr(self.deci, property_name) + + elif request_type == "shutdown_deci": + if self.deci: + self.deci.shutdown() + return {"status": "shutdown"} + + else: + raise ValueError(f"Unknown request type: {request_type}") + + def _get_light_artifacts(self, collection_name: str) -> Dict: + """Get light artifacts for a collection, computing and caching on first request""" + with self._cache_lock: + cache_entry = self._light_caches.get(collection_name) + + # Check if we have a valid cache entry + if cache_entry and time.time() - cache_entry["timestamp"] < self._cache_ttl: + return cache_entry["items"] + + # Cache miss or stale - compute light artifacts on-demand + _l.debug(f"Computing light artifacts for {collection_name} on-demand") + try: + collection = getattr(self.deci, collection_name) + if hasattr(collection, '_lifted_art_lister'): + start_time = time.time() + light_items = collection._lifted_art_lister() + end_time = time.time() + + # Convert artifacts to serializable format using their own serialization + serializable_items = {} + for addr, artifact in light_items.items(): + try: + # Use the artifact's built-in serialization which handles complex objects + serialized = artifact.dumps() + # Store as a tuple of (type_name, serialized_data) for reconstruction + serializable_items[addr] = { + 'type': artifact.__class__.__name__, + 'module': artifact.__class__.__module__, + 'data': serialized + } + except Exception as e: + _l.warning(f"Failed to serialize {artifact.__class__.__name__} at 0x{addr:x}: {e}") + # Skip problematic artifacts rather than failing completely + continue + + # Cache the serializable artifacts + self._light_caches[collection_name] = { + "items": serializable_items, + "timestamp": time.time() + } + + _l.info(f"Computed {len(serializable_items)} light {collection_name} in {end_time - start_time:.3f}s") + return serializable_items + else: + _l.warning(f"Collection {collection_name} does not support light artifacts") + return {} + + except Exception as e: + _l.warning(f"Failed to compute light artifacts for {collection_name}: {e}") + # Return stale cache if available, otherwise empty dict + if cache_entry: + _l.debug(f"Returning stale cache for {collection_name} due to error") + return cache_entry["items"] + return {} + + +class DecompilerServer: + """ + A server that exposes DecompilerInterface APIs over AF_UNIX sockets. + + This server wraps a DecompilerInterface instance and provides network access + to all its public methods and artifact collections through AF_UNIX sockets. + """ + + def __init__(self, + decompiler_interface: Optional[DecompilerInterface] = None, + socket_path: Optional[str] = None, + **interface_kwargs): + """ + Initialize the DecompilerServer. + + Args: + decompiler_interface: An existing DecompilerInterface instance. If None, + one will be created using DecompilerInterface.discover() + socket_path: Path for the AF_UNIX socket. If None, a temporary path will be used + **interface_kwargs: Arguments passed to DecompilerInterface.discover() if + decompiler_interface is None + """ + + self.socket_path = socket_path + self._server_socket = None + self._server_thread = None + self._running = False + self._clients = [] + self._client_threads = [] + + # Event subscription tracking + self._event_subscribers = [] # List of sockets subscribed to events + self._event_subscribers_lock = threading.Lock() + + # Initialize the decompiler interface + if decompiler_interface is not None: + self.deci = decompiler_interface + else: + if interface_kwargs and interface_kwargs.get("headless", False): + forced_decompiler = interface_kwargs.get("force_decompiler", None) + if forced_decompiler is None: + _l.warning(f"Using a headless interface without setting a decompiler has unpredictable behavior!") + _l.info(f"Using headless interface utilizing %s", forced_decompiler) + else: + _l.info("Discovering decompiler interface...") + + self.deci = DecompilerInterface.discover(**interface_kwargs) + if self.deci is None: + raise RuntimeError("Failed to discover decompiler interface") + + # Create socket handler + self.handler = SocketServerHandler(self.deci, server=self) + + # Register artifact change callbacks to broadcast events + self._register_artifact_callbacks() + + # Generate socket path if not provided + if self.socket_path is None: + temp_dir = tempfile.mkdtemp(prefix="libbs_server_") + self.socket_path = os.path.join(temp_dir, "decompiler.sock") + self._temp_dir = temp_dir + else: + self._temp_dir = None + + _l.info(f"DecompilerServer initialized with {self.deci.name} interface") + _l.info(f"Socket path: {self.socket_path}") + + def _register_artifact_callbacks(self): + """Register callbacks to broadcast artifact changes to subscribed clients""" + from libbs.artifacts import Comment, Struct, Enum, Typedef, GlobalVariable, FunctionHeader, StackVariable + + # Register callbacks for different artifact types + self.deci.artifact_change_callbacks[Comment].append( + lambda artifact, **kwargs: self._broadcast_event("comment_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[Struct].append( + lambda artifact, **kwargs: self._broadcast_event("struct_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[Enum].append( + lambda artifact, **kwargs: self._broadcast_event("enum_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[Typedef].append( + lambda artifact, **kwargs: self._broadcast_event("typedef_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[GlobalVariable].append( + lambda artifact, **kwargs: self._broadcast_event("global_variable_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[FunctionHeader].append( + lambda artifact, **kwargs: self._broadcast_event("function_header_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[StackVariable].append( + lambda artifact, **kwargs: self._broadcast_event("stack_variable_changed", artifact, **kwargs) + ) + + def _broadcast_event(self, event_type: str, artifact, **kwargs): + """Broadcast an artifact change event to all subscribed clients""" + with self._event_subscribers_lock: + if not self._event_subscribers: + _l.debug(f"No subscribers for event: {event_type}") + return + + # Serialize the artifact + try: + serialized_artifact = { + 'type': artifact.__class__.__name__, + 'module': artifact.__class__.__module__, + 'data': artifact.dumps(), + 'is_artifact': True + } + + event_message = { + "event_type": event_type, + "artifact": serialized_artifact, + "kwargs": kwargs + } + + # Send to all subscribers + dead_subscribers = [] + for subscriber_socket in self._event_subscribers: + try: + SocketProtocol.send_message(subscriber_socket, event_message) + _l.debug(f"Broadcasted {event_type} to subscriber") + except Exception as e: + _l.warning(f"Failed to send event to subscriber: {e}") + dead_subscribers.append(subscriber_socket) + + # Remove dead subscribers + for dead_socket in dead_subscribers: + self._event_subscribers.remove(dead_socket) + _l.debug("Removed dead subscriber") + + except Exception as e: + _l.error(f"Failed to broadcast event {event_type}: {e}") + + def start(self): + """Start the server in a separate thread""" + if self._running: + _l.warning("Server is already running") + return + + _l.info(f"Starting DecompilerServer on {self.socket_path}") + + # Create AF_UNIX socket + self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + # Set timeout so accept() doesn't block forever + self._server_socket.settimeout(1.0) + + # Remove socket file if it exists + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + + # Bind and listen + self._server_socket.bind(self.socket_path) + self._server_socket.listen(5) + + # Set running flag before starting thread + self._running = True + + # Start server in a separate thread + self._server_thread = threading.Thread(target=self._server_loop, daemon=True) + self._server_thread.start() + + _l.info(f"DecompilerServer started successfully on unix://{self.socket_path}") + _l.info("Connect with: DecompilerClient.discover('unix://{}')".format(self.socket_path)) + + def _server_loop(self): + """Main server loop""" + try: + while self._running: + try: + client_socket, addr = self._server_socket.accept() + self._clients.append(client_socket) + + # Handle client in separate thread + client_thread = threading.Thread( + target=self.handler.handle_client, + args=(client_socket, str(addr)), + daemon=True + ) + self._client_threads.append(client_thread) + client_thread.start() + + except socket.timeout: + # Normal timeout, continue loop to check if we should stop + continue + except OSError: + # Socket was closed + break + except Exception as e: + _l.error(f"Error accepting client: {e}") + + except Exception as e: + _l.error(f"Server loop error: {e}") + finally: + _l.info("Server loop ended") + + def stop(self): + """Stop the server""" + if not self._running: + _l.warning("Server is not running") + return + + _l.info("Stopping DecompilerServer...") + self._running = False + + # Close all client connections + for client in self._clients: + try: + client.close() + except: + pass + + # Close server socket + if self._server_socket: + self._server_socket.close() + + # Wait for threads to finish (short timeout since we use daemon threads) + if self._server_thread and self._server_thread.is_alive(): + self._server_thread.join(timeout=2.0) + + for thread in self._client_threads: + if thread.is_alive(): + thread.join(timeout=0.5) + + # Clean up socket file and temp directory + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + + if self._temp_dir and os.path.exists(self._temp_dir): + try: + os.rmdir(self._temp_dir) + except: + pass + + # Shutdown the decompiler interface + if self.deci: + try: + self.deci.shutdown() + except Exception as e: + _l.warning(f"Error shutting down decompiler: {e}") + + _l.info("DecompilerServer stopped") + + def is_running(self) -> bool: + """Check if the server is currently running""" + return self._running + + def wait_for_shutdown(self): + """Wait for the server to be shut down (blocking)""" + if self._server_thread and self._server_thread.is_alive(): + try: + self._server_thread.join() + except KeyboardInterrupt: + _l.info("Received interrupt signal, stopping server...") + self.stop() + + def __enter__(self): + """Context manager entry""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit""" + self.stop() + if self.deci: + self.deci.shutdown() \ No newline at end of file diff --git a/libbs/artifacts/artifact.py b/libbs/artifacts/artifact.py index debfafe7..dca48ac3 100644 --- a/libbs/artifacts/artifact.py +++ b/libbs/artifacts/artifact.py @@ -6,6 +6,7 @@ from .formatting import ArtifactFormat, TomlHexEncoder +from toml.tz import TomlTz class Artifact: """ @@ -30,12 +31,42 @@ def __init__(self, last_change: Optional[datetime.datetime] = None, scope: Optio self.scope = scope self._attr_ignore_set = set() + @staticmethod + def _normalize_datetime(dt): + """ + Convert TomlTz datetime objects to standard Python datetime objects. + TomlTz objects from TOML deserialization don't pickle correctly. + """ + if not isinstance(dt, datetime.datetime): + return dt + + # If the datetime has a TomlTz tzinfo, convert it to standard timezone + if dt.tzinfo is not None and isinstance(dt.tzinfo, TomlTz): + # Get the offset and convert to standard timezone + offset = dt.utcoffset() + if offset is not None: + std_tz = datetime.timezone(offset) + # Replace the TomlTz with standard timezone + return dt.replace(tzinfo=std_tz) + + return dt + def __getstate__(self) -> Dict: - return dict( - (k, getattr(self, k)) for k in self.slots - ) + state = {} + for k in self.slots: + value = getattr(self, k) + # Normalize datetime objects to ensure they pickle correctly + if isinstance(value, datetime.datetime): + value = self._normalize_datetime(value) + state[k] = value + return state def __setstate__(self, state): + # When pickle calls __setstate__, __init__ is never called, so we need to + # initialize _attr_ignore_set before accessing self.slots (which uses it) + if not hasattr(self, '_attr_ignore_set'): + self._attr_ignore_set = set() + for k in self.slots: if k in state: setattr(self, k, state[k]) diff --git a/libbs/artifacts/func.py b/libbs/artifacts/func.py index d26fc9dc..b68b1393 100644 --- a/libbs/artifacts/func.py +++ b/libbs/artifacts/func.py @@ -70,6 +70,7 @@ def __getstate__(self): return data_dict def __setstate__(self, state): + # Pop nested object data and reconstruct in local variable args_dict = state.pop("args", {}) new_args_dict = {} for k, v in args_dict.items(): @@ -77,7 +78,10 @@ def __setstate__(self, state): fa.__setstate__(v) new_args_dict[int(k, 0)] = fa - self.args = new_args_dict + # Put reconstructed objects back in state + state["args"] = new_args_dict + + # Let super set all attributes at once super().__setstate__(state) def diff(self, other, **kwargs) -> Dict: @@ -239,38 +243,50 @@ def __getstate__(self): return state def __setstate__(self, state): + # When pickle calls __setstate__, __init__ is never called + # Initialize _attr_ignore_set and add dec_obj to it (as done in __init__) + if not hasattr(self, '_attr_ignore_set'): + self._attr_ignore_set = set() + self._attr_ignore_set.add("dec_obj") + # XXX: this is a backport of the old state format. Remove this after a few releases. if "metadata" in state: metadata: Dict = state.pop("metadata") metadata.update(state) state = metadata + # Pop nested object data and reconstruct in local variables header_dat = state.pop("header", None) if header_dat: header = FunctionHeader() header.__setstate__(header_dat) else: header = None - self.header = header - # alias for name overrides header if it exists - if "name" in state: - self.name = state.pop("name") - # alias for type overrides header if it exists - if "type" in state: - self.type = state.pop("type") + # Handle name/type aliases that override header values + # We modify the header object directly instead of using property setters + # to avoid accessing self.header and self.addr before they're initialized + name_override = state.pop("name", None) + type_override = state.pop("type", None) + + if name_override is not None and header is not None: + header.name = name_override + if type_override is not None and header is not None: + header.type = type_override stack_vars_dat = state.pop("stack_vars", {}) + stack_vars = {} if stack_vars_dat: - stack_vars = {} for off, stack_var in stack_vars_dat.items(): sv = StackVariable() sv.__setstate__(stack_var) stack_vars[int(off, 0)] = sv - else: - stack_vars = None - self.stack_vars = stack_vars or {} + # Put reconstructed objects back in state + state["header"] = header + state["stack_vars"] = stack_vars + + # Let super set all attributes at once super().__setstate__(state) def diff(self, other, **kwargs) -> Dict: diff --git a/libbs/artifacts/patch.py b/libbs/artifacts/patch.py index ad39a9eb..3c8f8d95 100644 --- a/libbs/artifacts/patch.py +++ b/libbs/artifacts/patch.py @@ -36,7 +36,14 @@ def __getstate__(self): return data_dict def __setstate__(self, state): + # Pop and decode bytes data bytes_dat = state.pop("bytes", None) + decoded_bytes = None if bytes_dat: - self.bytes = codecs.decode(bytes_dat, "hex") + decoded_bytes = codecs.decode(bytes_dat, "hex") + + # Put decoded bytes back in state + state["bytes"] = decoded_bytes + + # Let super set all attributes at once super().__setstate__(state) diff --git a/libbs/artifacts/struct.py b/libbs/artifacts/struct.py index 13aabd02..236a11d3 100644 --- a/libbs/artifacts/struct.py +++ b/libbs/artifacts/struct.py @@ -83,14 +83,19 @@ def __setstate__(self, state): metadata.update(state) state = metadata + # Pop nested object data and reconstruct in local variable members_dat = state.pop("members", None) + members = {} if members_dat: for off, member in members_dat.items(): sm = StructMember() sm.__setstate__(member) - self.members[int(off, 0)] = sm - else: - self.members = {} + members[int(off, 0)] = sm + + # Put reconstructed objects back in state + state["members"] = members + + # Let super set all attributes at once super().__setstate__(state) def add_struct_member(self, mname, moff, mtype, size): diff --git a/libbs/decompiler_stubs/ghidra_libbs/__init__.py b/libbs/decompiler_stubs/ghidra_libbs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs.py b/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs.py deleted file mode 100644 index 02675bdb..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs.py +++ /dev/null @@ -1,10 +0,0 @@ -# Starts the LibBS backend for Ghidra scripts. -# @author LibBS -# @category LibBS -# @menupath Tools.LibBS.Start LibBS Backend - -import subprocess -from libbs_vendored.ghidra_bridge_server import GhidraBridgeServer - -if __name__ == "__main__": - GhidraBridgeServer.run_server(background=True) diff --git a/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs_mainthread_server.py b/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs_mainthread_server.py deleted file mode 100644 index f184583c..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs_mainthread_server.py +++ /dev/null @@ -1,8 +0,0 @@ -# Starts the LibBS backend for Ghidra scripts. -# @author LibBS -# @category LibBS - -from libbs_vendored.ghidra_bridge_server import GhidraBridgeServer - -if __name__ == "__main__": - GhidraBridgeServer.run_server(background=False) diff --git a/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs_shutdown.py b/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs_shutdown.py deleted file mode 100644 index 093fdf85..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/ghidra_libbs_shutdown.py +++ /dev/null @@ -1,15 +0,0 @@ -# Shutdown the LibBS backend server. -# @author LibBS -# @category LibBS -# @menupath Tools.LibBS.Shutdown LibBS Backend - -from libbs_vendored.jfx_bridge import bridge -from libbs_vendored.ghidra_bridge_port import DEFAULT_SERVER_PORT - -if __name__ == "__main__": - print("Requesting server shutdown...") - b = bridge.BridgeClient( - connect_to_host="127.0.0.1", connect_to_port=DEFAULT_SERVER_PORT - ) - - print(b.remote_shutdown()) diff --git a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/__init__.py b/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/ghidra_bridge_port.py b/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/ghidra_bridge_port.py deleted file mode 100644 index bd3d4f30..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/ghidra_bridge_port.py +++ /dev/null @@ -1 +0,0 @@ -DEFAULT_SERVER_PORT = 4768 \ No newline at end of file diff --git a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/ghidra_bridge_server.py b/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/ghidra_bridge_server.py deleted file mode 100644 index df21e6e8..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/ghidra_bridge_server.py +++ /dev/null @@ -1,203 +0,0 @@ -import logging -import subprocess -import sys -from .jfx_bridge import bridge -from .ghidra_bridge_port import DEFAULT_SERVER_PORT - -# NOTE: we definitely DON'T want to exclude ghidra from ghidra_bridge :P -import ghidra - - -class GhidraBridgeServer(object): - """ Class mostly used to collect together functions and variables that we don't want contaminating the global namespace - variables set in remote clients - - NOTE: this class needs to be excluded from ghidra_bridge - it doesn't need to be in the globals, if people want it and - know what they're doing, they can get it from the BridgedObject for the main module - """ - - class PrintAccumulator(object): - """ Class to handle capturing print output so we can send it across the bridge, by hooking sys.stdout.write(). - Not multithreading aware, it'll just capture whatever is printed from the moment it hooks to the moment - it stops. - """ - - output = None - old_stdout = None - - def __init__(self): - self.output = "" - - def write(self, output): - self.output += output - - def get_output(self): - return self.output - - def hook(self): - self.old_stdout = sys.stdout - sys.stdout = self - - def unhook(self): - if self.old_stdout is not None: - sys.stdout = self.old_stdout - - def __enter__(self): - self.hook() - - return self - - def __exit__(self, type, value, traceback): - self.unhook() - - @staticmethod - def ghidra_help(param=None): - """ call the ghidra help method, capturing the print output with PrintAccumulator, and return it as a string """ - with GhidraBridgeServer.PrintAccumulator() as help_output: - help(param) - - return help_output.get_output() - - class InteractiveListener(ghidra.framework.model.ToolListener): - """ Class to handle registering for plugin events associated with the GUI - environment, and sending them back to clients running in interactive mode - so they can update their variables - - We define the interactive listener on the server end, so it can - cleanly recover from bridge failures when trying to send messages back. If we - let it propagate exceptions up into Ghidra, the GUI gets unhappy and can stop - sending tool events out - """ - - tool = None - callback_fn = None - - def __init__(self, tool, callback_fn): - """ Create with the tool to listen to (from state.getTool() - won't change during execution) - and the callback function to notify on the client end (should be the update_vars function) """ - self.tool = tool - self.callback_fn = callback_fn - - # register the listener against the remote tool - tool.addToolListener(self) - - def stop_listening(self): - # we're done, make sure we remove the tool listener - self.tool.removeToolListener(self) - - def processToolEvent(self, plugin_event): - """ Called by the ToolListener interface """ - try: - self.callback_fn._bridge_conn.logger.debug( - "InteractiveListener got event: " + str(plugin_event) - ) - - event_name = plugin_event.getEventName() - if "Location" in event_name: - self.callback_fn( - currentProgram=plugin_event.getProgram(), - currentLocation=plugin_event.getLocation(), - ) - elif "Selection" in event_name: - self.callback_fn( - currentProgram=plugin_event.getProgram(), - currentSelection=plugin_event.getSelection(), - ) - elif "Highlight" in event_name: - self.callback_fn( - currentProgram=plugin_event.getProgram(), - currentHighlight=plugin_event.getHighlight(), - ) - except Exception as e: - # any exception, we just want to bail and shut down the listener. - # most likely case is the bridge connection has gone down. - self.stop_listening() - self.callback_fn._bridge_conn.logger.error( - "InteractiveListener failed trying to callback client: " + str(e) - ) - - @staticmethod - def run_server( - server_host=bridge.DEFAULT_HOST, - server_port=DEFAULT_SERVER_PORT, - response_timeout=bridge.DEFAULT_RESPONSE_TIMEOUT, - background=True, - ): - """ Run a ghidra_bridge_server (forever) - server_host - what address the server should listen on - server_port - what port the server should listen on - response_timeout - default timeout in seconds before a response is treated as "failed" - background - false to run the server in this thread (script popup will stay), true for a new thread (script popup disappears) - """ - server = bridge.BridgeServer( - server_host=server_host, - server_port=server_port, - loglevel=logging.INFO, - response_timeout=response_timeout, - ) - - if background: - server.start() - server.logger.info( - "Server launching in background - will continue to run after launch script finishes..." - ) - else: - server.run() - - @staticmethod - def run_script_across_ghidra_bridge(script_file, python="python", argstring=""): - """ Spin up a ghidra_bridge_server and spawn the script in external python to connect back to it. Useful in scripts being triggered from - inside ghidra that need to use python3 or packages that don't work in jython - - The called script needs to handle the --connect_to_host and --connect_to_port command-line arguments and use them to start - a ghidra_bridge client to talk back to the server. - - Specify python to control what the script gets run with. Defaults to whatever python is in the shell - if changing, specify a path - or name the shell can find. - Specify argstring to pass further arguments to the script when it starts up. - """ - - # spawn a ghidra bridge server - use server port 0 to pick a random port - server = bridge.BridgeServer( - server_host="127.0.0.1", server_port=0, loglevel=logging.INFO - ) - # start it running in a background thread - server.start() - - try: - # work out where we're running the server - server_host, server_port = server.server.bridge.get_server_info() - - print("Running " + script_file) - - # spawn an external python process to run against it - - try: - output = subprocess.check_output( - "{python} {script} --connect_to_host={host} --connect_to_port={port} {argstring}".format( - python=python, - script=script_file, - host=server_host, - port=server_port, - argstring=argstring, - ), - stderr=subprocess.STDOUT, - shell=True, - ) - print(output) - except subprocess.CalledProcessError as exc: - print("Failed ({}):{}".format(exc.returncode, exc.output)) - - print(script_file + " completed") - - finally: - # when we're done with the script, shut down the server - server.shutdown() - - -if __name__ == "__main__": - # legacy version - run the server in the foreground, so we don't break people's expectations - GhidraBridgeServer.run_server( - response_timeout=bridge.DEFAULT_RESPONSE_TIMEOUT, background=False - ) - diff --git a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/jfx_bridge/__init__.py b/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/jfx_bridge/__init__.py deleted file mode 100644 index 2ed416f4..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/jfx_bridge/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .bridge import __version__ diff --git a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/jfx_bridge/bridge.py b/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/jfx_bridge/bridge.py deleted file mode 100644 index 83d7aef8..00000000 --- a/libbs/decompiler_stubs/ghidra_libbs/libbs_vendored/jfx_bridge/bridge.py +++ /dev/null @@ -1,2234 +0,0 @@ -""" Handles converting artifacts back and forward between 2 and 3 """ - -from __future__ import unicode_literals # string literals are all unicode - -try: - import SocketServer as socketserver # py2 -except Exception: - import socketserver # py3 - -import logging -import json -import base64 -import uuid -import threading -import importlib -import socket -import struct -import sys -import time -import traceback -import weakref -import functools -import operator -import warnings -import inspect -import random -import textwrap -import types - -__version__ = "1.0.0" # automatically patched by setup.py when packaging - -# from six.py's strategy -INTEGER_TYPES = None -try: - INTEGER_TYPES = (int, long) -except NameError: # py3 has no long - INTEGER_TYPES = (int,) - -STRING_TYPES = None -try: - STRING_TYPES = (str, unicode) -except NameError: # py3 has no unicode - STRING_TYPES = (str,) - -# need to pick up java.lang.Throwable as an exception type if we're in a jython context -EXCEPTION_TYPES = None -try: - import java - - EXCEPTION_TYPES = (Exception, java.lang.Throwable) -except ImportError: - # Nope, just normal python here - EXCEPTION_TYPES = (Exception,) - -ENUM_TYPE = () -try: - from enum import Enum - - ENUM_TYPE = (Enum,) -except ImportError: # py2 has no enum - pass - -if sys.version_info[0] == 2: - from socket import ( - error as ConnectionError, - ) # ConnectionError not defined in python2, this is next closest thing - from socket import error as ConnectionResetError # as above - - -class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): - # prevent server threads hanging around and stopping python from closing - daemon_threads = True - - -DEFAULT_HOST = "127.0.0.1" -DEFAULT_SERVER_PORT = 27238 # 0x6a66 = "jf" - -VERSION = "v" -MAX_VERSION = "max_v" -MIN_VERSION = "min_v" -COMMS_VERSION_5 = 5 -TYPE = "type" -VALUE = "value" -KEY = "key" -TUPLE = "tuple" -LIST = "list" -DICT = "dict" -INT = "int" -FLOAT = "float" -BOOL = "bool" -STR = "str" -BYTES = "bytes" -NONE = "none" -PARTIAL = "partial" -SLICE = "slice" -NOTIMPLEMENTED = "notimp" -BRIDGED = "bridged" -EXCEPTION = "exception" -OBJ = "obj" -CALLABLE_OBJ = "callable_obj" -BASES = "bases" -REPR = "repr" - -MESSAGE = "message" -CMD = "cmd" -ID = "ID" -ARGS = "args" -GET = "get" -GET_ALL = "get_all" -CREATE_TYPE = "create_type" -SET = "set" -ISINSTANCE = "isinstance" -CALL = "call" -IMPORT = "import" -DEL = "del" -EVAL = "eval" -EXEC = "exec" -EXPR = "expr" -RESULT = "result" -ERROR = "error" -SHUTDOWN = "shutdown" -RESPOND = "respond" - -HANDLE = "handle" -NAME = "name" -ATTRS = "attrs" - -KWARGS = "kwargs" - -BRIDGE_PREFIX = "_bridge" - -# Comms v5 (alpha) adds slices to the serialization - one day, I'll support backwards compatibility -MIN_SUPPORTED_COMMS_VERSION = COMMS_VERSION_5 -MAX_SUPPORTED_COMMS_VERSION = COMMS_VERSION_5 - -DEFAULT_RESPONSE_TIMEOUT = 2 # seconds - -GLOBAL_BRIDGE_SHUTDOWN = False - -# BridgedObjects have a little trouble with class methods (e.g., where the method of accessing is not instance.doThing(), but more like -# type(instance).doThing(instance) - such as __lt__, len(), str(). -# To handle this, we define a list of class methods that we want to expose - this is a little gross, I'd like to dynamically do this based on the methods in the -# bridged object's type, but need to come up with a blacklist of things like __class__, __new__, etc which will interfere with the local objects first -BRIDGED_CLASS_METHODS = ["__str__", "__len__", "__iter__", "__hash__"] -# extract methods from operator, so I don't have to type out all the different options -for operator_name in dir(operator): - # only do the methods that start and end with __, and exclude __new__ - if ( - operator_name.startswith("__") - and operator_name.endswith("__") - and operator_name != "__new__" - and "builtin_function_or_method" in str(type(getattr(operator, operator_name))) - ): - BRIDGED_CLASS_METHODS.append(operator_name) - - -class BridgeException(Exception): - """An exception happened on the other side of the bridge and has been proxied back here - The bridge is fine, but the remote code you ran might have had an issue. - """ - - pass - - -class BridgeOperationException(Exception): - """Some issue happened with the operation of the bridge itself. The bridge may not be in a good state""" - - pass - - -class BridgeClosedException(Exception): - """The bridge has closed""" - - pass - - -class BridgeTimeoutException(Exception): - """A command we tried to run across the bridge took too long. You might need to increase the response timeout, check the command isn't - causing a deadlock, or make sure the network connection to the other end of the bridge is okay. - """ - - pass - - -def stats_hit(func): - """Decorate a function to record how many times it gets hit. Assumes the function is in a class with a stats attribute (can be set to None to - disable stats recording - """ - - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - if self.stats is not None: - self.stats.add_hit(func.__name__) - return func(self, *args, **kwargs) - - return wrapper - - -def stats_time(func): - """Decorate a function to record how long it takes to execute. Assumes the function is in a class with a stats attribute (can be set to None to - disable stats recording - """ - - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - start_time = time.time() - return_val = func(self, *args, **kwargs) - total_time = time.time() - start_time - - if self.stats is not None: - self.stats.add_time(func.__name__, total_time) - - return return_val - - return wrapper - - -class Stats: - """Class to record the number of hits of particular points (e.g., function calls) and - times (e.g., execution times) for gathering statistics. - """ - - def __init__(self): - self.lock = threading.Lock() - self.hits = dict() # name -> hit count - self.times = dict() # name -> (hit count, cumulative_time) - - def add_hit(self, hit_name): - with self.lock: - hit_count = self.hits.get(hit_name, 0) - self.hits[hit_name] = hit_count + 1 - - def add_time(self, time_name, time): - with self.lock: - hit_count, cumulative_time = self.times.get(time_name, (0, 0)) - self.times[time_name] = (hit_count + 1, cumulative_time + time) - - def total_hits(self): - total = 0 - with self.lock: - for value in self.hits.values(): - total += value - - return total - - def total_time(self): - total_time_hits = 0 - total_time = 0 - with self.lock: - for hits, cumulative_time in self.times.values(): - total_time_hits += hits - total_time += cumulative_time - - return (total_time_hits, total_time) - - def __str__(self): - return "Stats(total_hits={},hits={},total_time={},times={})".format( - self.total_hits(), self.hits, self.total_time(), self.times - ) - - def copy(self): - """Take a copy of the stats at the current time""" - copy_stats = Stats() - with self.lock: - copy_stats.hits = self.hits.copy() - copy_stats.times = self.times.copy() - - return copy_stats - - def __sub__(self, other): - if not isinstance(other, Stats): - raise Exception("Can't subtract non-Stats object from a Stats object") - - # take a copy of this stats, then subtract the other from the copy - new_stats = self.copy() - - # subtract the value of each key in other hits from the corresponding key in new_stats - # if new_stats doesn't have the key, treat it as 0 - # nuke any values which end up as 0 - for key, value in other.hits.items(): - new_stats.hits[key] = new_stats.hits.get(key, 0) - value - if new_stats.hits[key] == 0: - del new_stats.hits[key] - - # as above, but for times - for key, value in other.times.items(): - hit_count, cumulative_time = new_stats.times.get(key, (0, 0)) - new_stats.times[key] = (hit_count - value[0], cumulative_time - value[1]) - if new_stats.times[key][0] == 0: - del new_stats.times[key] - - return new_stats - - -SIZE_FORMAT = "!I" - - -def write_size_and_data_to_socket(sock, data): - """Utility function to pack the size in front of artifacts and send it off - - Note: not thread safe - sock.send can return before all the artifacts is sent, python can swap active threads, and another thread can start sending its artifacts halfway through - the first one's. Call from BridgeConn.send_data() - """ - - # pack the size as network-endian - data_size = len(data) - size_bytes = struct.pack(SIZE_FORMAT, len(data)) - package = size_bytes + data - total_size = len(size_bytes) + data_size - - sent = 0 - # noted errors sending large blobs of artifacts with sendall, so we'll send as much as send() allows and keep trying - while sent < total_size: - # send it all off - bytes_sent = sock.send(package[sent:]) - sent = sent + bytes_sent - - -def read_exactly(sock, num_bytes): - """Utility function to keep reading from the socket until we get the desired number of bytes""" - data = b"" - while num_bytes > 0: - new_data = sock.recv(num_bytes) - if new_data is None: - # most likely reason for a none here is the socket being closed on the remote end - raise BridgeClosedException() - num_bytes = num_bytes - len(new_data) - data += new_data - - return data - - -def read_size_and_data_from_socket(sock): - """Utility function to read the size of a artifacts block, followed by all of that artifacts""" - - size_bytes = read_exactly(sock, struct.calcsize(SIZE_FORMAT)) - size = struct.unpack(SIZE_FORMAT, size_bytes)[0] - - data = read_exactly(sock, size) - data = data.strip() - - return data - - -def can_handle_version(message_dict): - """Utility function for checking we know about this version""" - return (message_dict[VERSION] <= MAX_SUPPORTED_COMMS_VERSION) and ( - message_dict[VERSION] >= MIN_SUPPORTED_COMMS_VERSION - ) - - -class BridgeCommandHandlerThread(threading.Thread): - """Thread that checks for commands to handle and serves them""" - - bridge_conn = None - threadpool = None - - def __init__(self, threadpool): - super(BridgeCommandHandlerThread, self).__init__() - - self.bridge_conn = threadpool.bridge_conn - # make sure this thread doesn't keep the threadpool alive - self.threadpool = weakref.proxy(threadpool) - - # don't let the command handlers keep us alive - self.daemon = True - - def run(self): - try: - cmd = self.threadpool.get_command() # block, waiting for first command - while cmd is not None: # get_command returns none if we should shut down - # handle a command and write back the response - # TODO make this return an error tied to the cmd_id, so it goes in the response mgr - result = None - - # see if the command wants a response - want_response = cmd.get(RESPOND, True) - - try: - result = self.bridge_conn.handle_command( - cmd, want_response=want_response - ) - except EXCEPTION_TYPES as e: - self.bridge_conn.logger.error( - "Unexpected exception for {}: {}\n{}".format( - cmd, e, traceback.format_exc() - ) - ) - # pack a minimal error, so the other end doesn't have to wait for a timeout - result = json.dumps( - { - VERSION: COMMS_VERSION_5, - TYPE: ERROR, - ID: cmd[ID], - } - ).encode("utf-8") - - # only reply if the command wants a response - if want_response: - try: - self.bridge_conn.send_data(result) - except socket.error: - # Other end has closed the socket before we can respond. That's fine, just ask me to do something then ignore me. Jerk. Don't bother staying around, they're probably dead - break - - cmd = self.threadpool.get_command() # block, waiting for next command - except ReferenceError: - # expected, means the connection has been closed and the threadpool cleaned up - pass - - -class BridgeCommandHandlerThreadPool(object): - """Takes commands and handles spinning up threads to run them. Will keep the threads that are started and reuse them before creating new ones""" - - bridge_conn = None - # semaphore indicating how many threads are ready right now to grab a command - ready_threads = None - command_list = None # store the commands that need to be handled - command_list_read_lock = None # just for reading the list - command_list_write_lock = None # for writing the list - shutdown_flag = False - - def __init__(self, bridge_conn): - self.thread_count = 0 - self.bridge_conn = bridge_conn - self.ready_threads = threading.Semaphore(0) # start the ready threads at 0 - self.command_list = list() - self.command_list_read_lock = threading.Lock() - self.command_list_write_lock = threading.Lock() - - def handle_command(self, msg_dict): - """Give the threadpool a command to handle""" - # test if there are ready_threads waiting - if not self.ready_threads.acquire(blocking=False): - # no ready threads waiting - create a new one - self.thread_count += 1 - self.bridge_conn.logger.debug( - "Creating thread - now {} threads".format(self.thread_count) - ) - new_handler = BridgeCommandHandlerThread(self) - new_handler.start() - else: - self.ready_threads.release() - - # take out the write lock, we're adding to the list - with self.command_list_write_lock: - self.command_list.append(msg_dict) - # the next ready thread will grab the command - - def get_command(self): - """Threads ask for commands to handle - a thread stuck waiting here is counted in the ready threads""" - # release increments the ready threads count - self.ready_threads.release() - - try: - while not self.shutdown_flag and not GLOBAL_BRIDGE_SHUTDOWN: - # get the read lock, so we can see if there's anything to do - with self.command_list_read_lock: - if len(self.command_list) > 0: - # yes! grab the write lock (only thing that can have the write lock without the read lock is commands being added, so we won't deadlock/have to wait long) - with self.command_list_write_lock: - # yes! give back the first command - return self.command_list.pop() - # wait a little before we try again - time.sleep(0.01) - finally: - # make sure the thread "acquires" the semaphore (decrements the ready_threads count) - self.ready_threads.acquire(blocking=False) - - # if we make it here, we're shutting down. return none and the thread will pack it in - return None - - def __del__(self): - """We're done with this threadpool, tell the threads to start packing it in""" - self.shutdown_flag = True - - -class BridgeReceiverThread(threading.Thread): - """class to handle running a thread to receive bridge commands/responses and direct accordingly""" - - # If we don't know how to handle the version, reply back with an error and the highest version we do support - ERROR_UNSUPPORTED_VERSION = json.dumps( - { - ERROR: True, - MAX_VERSION: MAX_SUPPORTED_COMMS_VERSION, - MIN_VERSION: MIN_SUPPORTED_COMMS_VERSION, - } - ) - - def __init__(self, bridge_conn): - super(BridgeReceiverThread, self).__init__() - - self.bridge_conn = bridge_conn - - # don't let the recv loop keep us alive - self.daemon = True - - def run(self): - # threadpool to handle creating/running threads to handle commands - threadpool = BridgeCommandHandlerThreadPool(self.bridge_conn) - - while not GLOBAL_BRIDGE_SHUTDOWN: - try: - data = read_size_and_data_from_socket(self.bridge_conn.get_socket()) - except socket.timeout: - # client didn't have anything to say - just wait some more - time.sleep(0.1) - continue - - try: - msg_dict = json.loads(data.decode("utf-8")) - self.bridge_conn.logger.debug("Recv loop received {}".format(msg_dict)) - - if can_handle_version(msg_dict): - if msg_dict[TYPE] in [RESULT, ERROR]: - # handle a response or error - self.bridge_conn.add_response(msg_dict) - else: - # queue this and hand off to a worker threadpool - threadpool.handle_command(msg_dict) - else: - # bad version - self.bridge_conn.send_data( - BridgeReceiverThread.ERROR_UNSUPPORTED_VERSION - ) - except EXCEPTION_TYPES as e: - # eat exceptions and continue, don't want a bad message killing the recv loop - self.bridge_conn.logger.exception(e) - - self.bridge_conn.logger.debug("Receiver thread shutdown") - - -class BridgeCommandHandler(socketserver.BaseRequestHandler): - def handle(self): - """handle a new client connection coming in - continue trying to read/service requests in a loop until we fail to send/recv""" - self.server.bridge.logger.warn( - "Handling connection from {}".format(self.request.getpeername()) - ) - try: - # run the recv loop directly - BridgeReceiverThread( - BridgeConn( - self.server.bridge, - self.request, - response_timeout=self.server.bridge.response_timeout, - ) - ).run() - - # only get here if the client has requested we shutdown the bridge - self.server.bridge.logger.debug( - "Receiver thread exited - bridge shutdown requested" - ) - self.server.bridge.shutdown() - except (BridgeClosedException, ConnectionResetError): - pass # expected - the client has closed the connection - except EXCEPTION_TYPES as e: - # something weird went wrong? - self.server.bridge.logger.exception(e) - finally: - self.server.bridge.logger.warn( - "Closing connection from {}".format(self.request.getpeername()) - ) - # we're out of the loop now, so the connection object will get told to delete itself, which will remove its references to any objects its holding onto - - -class BridgeHandle(object): - def __init__(self, local_obj): - self.handle = str(uuid.uuid4()) - self.local_obj = local_obj - self.attrs = dir(local_obj) - - def to_dict(self): - # extract the type name from the repr for the type - type_repr = repr(type(self.local_obj)) - # expect it to be something like or - if "'" in type_repr: - type_name = type_repr.split("'")[1] - else: - # just use the repr straight up - type_name = type_repr - return { - HANDLE: self.handle, - TYPE: type_name, - ATTRS: self.attrs, - REPR: repr(self.local_obj), - } - - def __str__(self): - return "BridgeHandle({}: {})".format(self.handle, self.local_obj) - - -class BridgeResponse(object): - """Utility class for waiting for and receiving responses""" - - event = None # used to flag whether the response is ready - response = None - - def __init__(self, response_id): - self.event = threading.Event() - self.response_id = response_id # just for tracking, so we can report it in timeout exception if needed - - def set(self, response): - """store response artifacts, and let anyone waiting know it's ready""" - self.response = response - # trigger the event - self.event.set() - - def get(self, timeout=None): - """wait for the response""" - if timeout is not None and timeout < 0: - # can't pass in None higher up reliably, as it gets used to indicate "default timeout". - # Instead, treat a negative timeout as "wait forever", and set timeout to None, so event.wait - # will wait forever. - timeout = None - - # patch: make the timeout much longer - if not self.event.wait(60 * 3): - raise BridgeTimeoutException( - "Didn't receive response {} before timeout".format(self.response_id) - ) - - return self.response - - -class BridgeResponseManager(object): - """Handles waiting for and receiving responses""" - - response_dict = None # maps response ids to a BridgeResponse - response_lock = None - - def __init__(self): - self.response_dict = dict() - self.response_lock = threading.Lock() - - def add_response(self, response_dict): - """response received - register it, then set the event for it""" - with self.response_lock: - response_id = response_dict[ID] - if response_id not in self.response_dict: - # response hasn't been waited for yet. create the entry - self.response_dict[response_id] = BridgeResponse(response_id) - - # set the artifacts and trigger the event - self.response_dict[response_id].set(response_dict) - - def get_response(self, response_id, timeout=None): - """Register for a response and wait until received""" - with self.response_lock: - if response_id not in self.response_dict: - # response hasn't been waited for yet. create the entry - self.response_dict[response_id] = BridgeResponse(response_id) - response = self.response_dict[response_id] - - # wait for the artifacts - will throw a BridgeTimeoutException if doesn't get it by timeout - data = response.get(timeout) - - if TYPE in data: - if data[TYPE] == ERROR: - # problem with the bridge itself, raise an exception - raise BridgeOperationException(data) - - with self.response_lock: - # delete the entry, we're done here - del self.response_dict[response_id] - - return data - - -class BridgeConn(object): - """Internal class, representing a connection to a remote bridge that serves our requests""" - - stats = None - - def __init__( - self, - bridge, - sock=None, - connect_to_host=None, - connect_to_port=None, - response_timeout=DEFAULT_RESPONSE_TIMEOUT, - record_stats=False, - ): - """Set up the bridge connection - only instantiates a connection as needed""" - self.host = connect_to_host - self.port = connect_to_port - - # get a reference to the bridge's logger for the connection - self.logger = bridge.logger - - self.handle_dict = {} - # list of tuples of (handle, time) that have been marked for deletion and the time they were marked at - # list will always be in order of earliest marked to latest - self.delay_delete_handles = [] - - self.sock = sock - self.comms_lock = threading.RLock() - self.handle_lock = threading.RLock() - - self.response_mgr = BridgeResponseManager() - self.response_timeout = response_timeout - - # keep a cache of types of objects we've created - # we'll keep all the types forever (including handles to bridgedcallables in them) because types are super-likely - # to be reused regularly, and we don't want to keep deleting them and then having to recreate them all the time. - self.cached_bridge_types = dict() - - # if the bridge has requested a local_call_hook/local_eval_hook, record that - self.local_call_hook = bridge.local_call_hook - self.local_eval_hook = bridge.local_eval_hook - self.local_exec_hook = bridge.local_exec_hook - - if record_stats: - self.stats = Stats() - - def __del__(self): - """On teardown, make sure we close our socket to the remote bridge""" - with self.comms_lock: - if self.sock is not None: - self.sock.close() - - def create_handle(self, obj): - bridge_handle = BridgeHandle(obj) - - with self.handle_lock: - self.handle_dict[bridge_handle.handle] = bridge_handle - - self.logger.debug("Handle created {} for {}".format(bridge_handle.handle, obj)) - - return bridge_handle - - def get_object_by_handle(self, handle): - with self.handle_lock: - if handle not in self.handle_dict: - raise Exception("Old/unknown handle {}".format(handle)) - - return self.handle_dict[handle].local_obj - - def release_handle(self, handle): - with self.handle_lock: - if handle in self.handle_dict: - # don't release the handle just yet - put it in the delay list - # this is because some remote_evals end up with objects being released remotely (causing - # a delete command to be sent) before they're sent back in a response. The delete command - # beats the response back, and the handle is removed before it can be used in the response - # causing an error. - # To avoid this, we'll delay for a response_timeout period to make sure that we got our - # response back post-delete. - self.delay_delete_handles.append((handle, time.time())) - - # use this as a good time to purge delayed handles - self.purge_delay_delete_handles() - - def purge_delay_delete_handles(self): - """Actually remove deleted handles from the handle dict once they've exceeded the timeout""" - with self.handle_lock: - # work out the cutoff time for when we'd delete delayed handles - delay_exceeded_time = time.time() - self.response_timeout - # run over delay_delete_handles until it's empty or the times are later than the delay_exceeded_time - while ( - len(self.delay_delete_handles) > 0 - and self.delay_delete_handles[0][1] <= delay_exceeded_time - ): - handle = self.delay_delete_handles[0][0] - # actually remove the handle - del self.handle_dict[handle] - # remove this entry from the list - self.delay_delete_handles.pop(0) - - def serialize_to_dict(self, data): - serialized_dict = None - - # note: this needs to come before int, because apparently bools are instances of int (but not vice versa) - if isinstance(data, bool): - serialized_dict = {TYPE: BOOL, VALUE: str(data)} - # don't treat py3 enums as ints - pass them as objects - elif isinstance(data, INTEGER_TYPES) and not isinstance(data, ENUM_TYPE): - serialized_dict = {TYPE: INT, VALUE: str(data)} - elif isinstance(data, float): - serialized_dict = {TYPE: FLOAT, VALUE: str(data)} - elif isinstance( - data, STRING_TYPES - ): # all strings are coerced to unicode when serialized - serialized_dict = { - TYPE: STR, - VALUE: base64.b64encode(data.encode("utf-8")).decode("utf-8"), - } - elif isinstance(data, bytes): # py3 only, bytestring in 2 is str - serialized_dict = { - TYPE: BYTES, - VALUE: base64.b64encode(data).decode("utf-8"), - } - elif isinstance(data, list): - serialized_dict = { - TYPE: LIST, - VALUE: [self.serialize_to_dict(v) for v in data], - } - elif isinstance(data, tuple): - serialized_dict = { - TYPE: TUPLE, - VALUE: [self.serialize_to_dict(v) for v in data], - } - elif isinstance(data, dict): - serialized_dict = { - TYPE: DICT, - VALUE: [ - {KEY: self.serialize_to_dict(k), VALUE: self.serialize_to_dict(v)} - for k, v in data.items() - ], - } - elif isinstance(data, slice): - serialized_dict = { - TYPE: SLICE, - VALUE: [ - self.serialize_to_dict(data.start), - self.serialize_to_dict(data.stop), - self.serialize_to_dict(data.step), - ], - } - elif isinstance( - data, EXCEPTION_TYPES - ): # will also catch java.lang.Throwable in jython context - # treat the exception object as an object - value = self.create_handle(data).to_dict() - # then wrap the exception specifics around it - serialized_dict = { - TYPE: EXCEPTION, - VALUE: value, - MESSAGE: self.serialize_to_dict(getattr(data, "message", "")), - } - elif isinstance(data, BridgedObject): - # passing back a reference to an object on the other side - # e.g., bridge_obj1.do_thing(bridge_obj2) - serialized_dict = {TYPE: BRIDGED, VALUE: data._bridge_handle} - elif isinstance(data, type(None)): - serialized_dict = {TYPE: NONE} - elif isinstance(data, type(NotImplemented)): - serialized_dict = {TYPE: NOTIMPLEMENTED} - elif isinstance(data, functools.partial) and isinstance( - data.func, BridgedCallable - ): - # if it's a partial, possible that it's against a remote function - in that case, instead of sending it back as a BridgedCallable - # to get remote called back here where we'll issue a call to the original function, we'll send it with the partial's details so - # it can be reconstructed on the other side (0 round-trips instead of 2 round-trips) - # TODO do we have to worry about artifacts.func being from a different bridge connection? - serialized_dict = { - TYPE: PARTIAL, - VALUE: self.serialize_to_dict(data.func), - ARGS: self.serialize_to_dict(data.args), - KWARGS: self.serialize_to_dict(data.keywords), - } - else: - # it's an object. assign a reference - obj_type = CALLABLE_OBJ if callable(data) else OBJ - serialized_dict = { - TYPE: obj_type, - VALUE: self.create_handle(data).to_dict(), - } - - return serialized_dict - - def deserialize_from_dict(self, serial_dict): - if serial_dict[TYPE] == INT: # int, long - return int(serial_dict[VALUE]) - elif serial_dict[TYPE] == FLOAT: - return float(serial_dict[VALUE]) - elif serial_dict[TYPE] == BOOL: - return serial_dict[VALUE] == "True" - elif serial_dict[TYPE] == STR: - result = base64.b64decode(serial_dict[VALUE]).decode("utf-8") - # if we're in python 2, result is now a unicode string. - if sys.version_info[0] == 2: - try: - # We'll try and force it down to a plain string, because there are plenty of cases where plain strings - # are expected instead of unicode (e.g., type/module names). If that fails, we'll keep it as unicode. - result = str(result) - except UnicodeEncodeError: - # couldn't make it ascii, keep as unicode - pass - return result - - elif serial_dict[TYPE] == BYTES: - return base64.b64decode(serial_dict[VALUE]) - elif serial_dict[TYPE] == LIST: - return [self.deserialize_from_dict(v) for v in serial_dict[VALUE]] - elif serial_dict[TYPE] == TUPLE: - return tuple(self.deserialize_from_dict(v) for v in serial_dict[VALUE]) - elif serial_dict[TYPE] == DICT: - result = dict() - for kv in serial_dict[VALUE]: - key = self.deserialize_from_dict(kv[KEY]) - value = self.deserialize_from_dict(kv[VALUE]) - result[key] = value - - return result - elif ( - serial_dict[TYPE] == SLICE - ): # we create local slice objects so isinstance(slice) in __getitem__/etc works - start, stop, step = [ - self.deserialize_from_dict(v) for v in serial_dict[VALUE] - ] - result = slice(start, stop, step) - return result - elif serial_dict[TYPE] == EXCEPTION: - raise BridgeException( - self.deserialize_from_dict(serial_dict[MESSAGE]), - self.build_bridged_object(serial_dict[VALUE]), - ) - elif serial_dict[TYPE] == BRIDGED: - return self.get_object_by_handle(serial_dict[VALUE]) - elif serial_dict[TYPE] == NONE: - return None - elif serial_dict[TYPE] == NOTIMPLEMENTED: - return NotImplemented - elif serial_dict[TYPE] == PARTIAL: - func = self.deserialize_from_dict(serial_dict[VALUE]) - args = self.deserialize_from_dict(serial_dict[ARGS]) - if args is None: - args = () - keywords = self.deserialize_from_dict(serial_dict[KWARGS]) - if keywords is None: - keywords = {} - return functools.partial(func, *args, **keywords) - elif serial_dict[TYPE] == OBJ or serial_dict[TYPE] == CALLABLE_OBJ: - return self.build_bridged_object( - serial_dict[VALUE], callable=(serial_dict[TYPE] == CALLABLE_OBJ) - ) - - raise Exception("Unhandled artifacts {}".format(serial_dict)) - - def get_socket(self): - with self.comms_lock: - if self.sock is None: - self.logger.debug( - "Creating socket to {}:{}".format(self.host, self.port) - ) - # Create a socket (SOCK_STREAM means a TCP socket) - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(10) - self.sock.connect((self.host, self.port)) - # spin up the recv loop thread in the background - BridgeReceiverThread(self).start() - - return self.sock - - def send_data(self, data): - """Handle shipping the artifacts across the bridge. Locked to prevent multiple sends - interleaving with each other (e.g., one is halfway through sending it artifacts when - it returns, GIL gives it up and the other begins sending - causing decode errors - on the other side""" - with self.comms_lock: - sock = self.get_socket() - # send the artifacts - write_size_and_data_to_socket(sock, data) - - @stats_time - def send_cmd(self, command_dict, get_response=True, timeout_override=None): - """Package and send a command off. If get_response set, wait for the response and return it. Else return none. - If timeout override set, wait that many seconds, else wait for default response timeout - """ - cmd_id = str(uuid.uuid4()) # used to link commands and responses - envelope_dict = { - VERSION: COMMS_VERSION_5, - ID: cmd_id, - TYPE: CMD, - CMD: command_dict, - RESPOND: get_response, - } - self.logger.debug("Sending {}".format(envelope_dict)) - data = json.dumps(envelope_dict).encode("utf-8") - - self.send_data(data) - - if get_response: - result = {} - # wait for the response - response_dict = self.response_mgr.get_response( - cmd_id, - timeout=timeout_override - if timeout_override is not None - else self.response_timeout, - ) - - if response_dict is not None: - if RESULT in response_dict: - result = response_dict[RESULT] - return result - else: - return None - - @stats_hit - def remote_get(self, handle, name): - self.logger.debug("remote_get: {}.{}".format(handle, name)) - command_dict = {CMD: GET, ARGS: {HANDLE: handle, NAME: name}} - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_get(self, args_dict): - handle = args_dict[HANDLE] - name = args_dict[NAME] - self.logger.debug("local_get: {}.{}".format(handle, name)) - - target = self.get_object_by_handle(handle) - try: - result = getattr(target, name) - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() - - return result - - @stats_hit - def remote_set(self, handle, name, value): - self.logger.debug("remote_set: {}.{} = {}".format(handle, name, value)) - command_dict = { - CMD: SET, - ARGS: {HANDLE: handle, NAME: name, VALUE: self.serialize_to_dict(value)}, - } - self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_set(self, args_dict): - handle = args_dict[HANDLE] - name = args_dict[NAME] - value = self.deserialize_from_dict(args_dict[VALUE]) - - if self.logger.getEffectiveLevel() <= logging.DEBUG: - try: - # we want to get log the deserialized values, because they're useful. - # but this also means a bad repr can break things. So we get ready to - # catch that and fallback to undeserialized values - self.logger.debug("local_set: {}.{} = {}".format(handle, name, value)) - except EXCEPTION_TYPES as e: - self.logger.debug( - "Failed to log deserialized arguments: {}\n{}".format( - e, traceback.format_exc() - ) - ) - self.logger.debug( - "Falling back:\n\tlocal_set: {}.{} = {}".format( - handle, name, args_dict[VALUE] - ) - ) - - target = self.get_object_by_handle(handle) - result = None - try: - result = setattr(target, name, value) - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() # TODO - this and other tracebacks, log with info about what's happening - - return result - - @stats_hit - def remote_call(self, handle, *args, **kwargs): - self.logger.debug("remote_call: {}({},{})".format(handle, args, kwargs)) - - serial_args = self.serialize_to_dict(args) - serial_kwargs = self.serialize_to_dict(kwargs) - command_dict = { - CMD: CALL, - ARGS: {HANDLE: handle, ARGS: serial_args, KWARGS: serial_kwargs}, - } - - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def remote_call_nonreturn(self, handle, *args, **kwargs): - """As per remote_call, but without expecting a response""" - self.logger.debug( - "remote_call_nonreturn: {}({},{})".format(handle, args, kwargs) - ) - - serial_args = self.serialize_to_dict(args) - serial_kwargs = self.serialize_to_dict(kwargs) - command_dict = { - CMD: CALL, - ARGS: {HANDLE: handle, ARGS: serial_args, KWARGS: serial_kwargs}, - } - - self.send_cmd(command_dict, get_response=False) - - @stats_hit - def local_call(self, args_dict): - handle = args_dict[HANDLE] - - args = self.deserialize_from_dict(args_dict[ARGS]) - kwargs = self.deserialize_from_dict(args_dict[KWARGS]) - - if self.logger.getEffectiveLevel() <= logging.DEBUG: - try: - # we want to get log the deserialized values, because they're useful. - # but this also means a bad repr can break things. So we get ready to - # catch that and fallback to undeserialized values - self.logger.debug("local_call: {}({},{})".format(handle, args, kwargs)) - except EXCEPTION_TYPES as e: - self.logger.debug( - "Failed to log deserialized arguments: {}\n{}".format( - e, traceback.format_exc() - ) - ) - self.logger.debug( - "Falling back:\n\tlocal_call: {}({},{})".format( - handle, args_dict[ARGS], args_dict[KWARGS] - ) - ) - - result = None - try: - target_callable = self.get_object_by_handle(handle) - # call the target function, or the hook if we've registered one - if self.local_call_hook is None: - result = target_callable(*args, **kwargs) - else: - result = self.local_call_hook(self, target_callable, *args, **kwargs) - except EXCEPTION_TYPES as e: - result = e - if not isinstance(e, Exception): - # not an exception type, so it'll be a java throwable - # just output the string representation at the moment - # if you want the stack trace, here's where you'd get it from. - self.logger.warning("Got java.lang.Throwable: {}".format(e)) - # also, don't display StopIteration exceptions, they're totally normal - elif not isinstance(e, StopIteration): - traceback.print_exc() - - return result - - @stats_hit - def remote_del(self, handle): - self.logger.debug("remote_del {}".format(handle)) - command_dict = {CMD: DEL, ARGS: {HANDLE: handle}} - try: - self.send_cmd(command_dict, get_response=False) - except (ConnectionError, OSError): - # get a lot of these when shutting down if the bridge connection has already been torn down before the bridged objects are deleted - # just ignore - we want to know if the other operations fail, but deleting failing we can probably get away with - pass - - @stats_hit - def local_del(self, args_dict): - handle = args_dict[HANDLE] - self.logger.debug("local_del {}".format(handle)) - self.release_handle(handle) - - @stats_hit - def remote_import(self, module_name): - self.logger.debug("remote_import {}".format(module_name)) - command_dict = {CMD: IMPORT, ARGS: {NAME: module_name}} - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_import(self, args_dict): - name = args_dict[NAME] - - self.logger.debug("local_import {}".format(name)) - result = None - try: - result = importlib.import_module(name) - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() - - return result - - @stats_hit - def remote_get_type(self, handle): - self.logger.debug("remote_get_type {}".format(handle)) - command_dict = {CMD: TYPE, ARGS: {HANDLE: handle}} - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_get_type(self, args_dict): - handle = args_dict[HANDLE] - self.logger.debug("local_get_type {}".format(handle)) - - target_obj = self.get_object_by_handle(handle) - - try: - result = type(target_obj) - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() - - return result - - @stats_hit - def remote_create_type(self, name, bases, dct): - self.logger.debug("remote_create_type {}, {}, {}".format(name, bases, dct)) - command_dict = { - CMD: CREATE_TYPE, - ARGS: { - NAME: name, - BASES: self.serialize_to_dict(bases), - DICT: self.serialize_to_dict(dct), - }, - } - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_create_type(self, args_dict): - name = str( - args_dict[NAME] - ) # type name can't be unicode string in python2 - force to string - bases = self.deserialize_from_dict(args_dict[BASES]) - dct = self.deserialize_from_dict(args_dict[DICT]) - - if self.logger.getEffectiveLevel() <= logging.DEBUG: - try: - # we want to get log the deserialized values, because they're useful. - # but this also means a bad repr can break things. So we get ready to - # catch that and fallback to undeserialized values - self.logger.debug( - "local_create_type {}, {}, {}".format(name, bases, dct) - ) - except EXCEPTION_TYPES as e: - self.logger.debug( - "Failed to log deserialized arguments: {}\n{}".format( - e, traceback.format_exc() - ) - ) - self.logger.debug( - "Falling back:\n\tlocal_create_type {}, {}, {}".format( - name, args_dict[BASES], args_dict[DICT] - ) - ) - - result = None - - try: - result = type(name, bases, dct) - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() - - return result - - @stats_hit - def remote_get_all(self, handle): - self.logger.debug("remote_get_all {}".format(handle)) - command_dict = {CMD: GET_ALL, ARGS: {HANDLE: handle}} - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_get_all(self, args_dict): - handle = args_dict[HANDLE] - self.logger.debug("local_get_all {}".format(handle)) - - target_obj = self.get_object_by_handle(handle) - result = {name: getattr(target_obj, name) for name in dir(target_obj)} - - return result - - @stats_hit - def remote_isinstance(self, test_object, class_or_tuple): - self.logger.debug( - "remote_isinstance({}, {})".format(test_object, class_or_tuple) - ) - - check_class_tuple = None - # if we're not checking against a tuple, force it into one - if not _is_bridged_object(class_or_tuple): - # local - probably a tuple already - if not isinstance(class_or_tuple, tuple): - # it's not :X - raise Exception( - "Can't use remote_isinstance on a non-bridged class: {}".format( - class_or_tuple - ) - ) - else: - check_class_tuple = class_or_tuple - else: - # single bridged, just wrap in a tuple - check_class_tuple = (class_or_tuple,) - - command_dict = { - CMD: ISINSTANCE, - ARGS: self.serialize_to_dict({OBJ: test_object, TUPLE: check_class_tuple}), - } - return self.deserialize_from_dict(self.send_cmd(command_dict)) - - @stats_hit - def local_isinstance(self, args_dict): - args = self.deserialize_from_dict(args_dict) - test_object = args[OBJ] - check_class_tuple = args[TUPLE] - - if self.logger.getEffectiveLevel() <= logging.DEBUG: - try: - # we want to get log the deserialized values, because they're useful. - # but this also means a bad repr can break things. So we get ready to - # catch that and fallback to undeserialized values - self.logger.debug( - "local_isinstance({},{})".format(test_object, check_class_tuple) - ) - except EXCEPTION_TYPES as e: - self.logger.debug( - "Failed to log deserialized arguments: {}\n{}".format( - e, traceback.format_exc() - ) - ) - self.logger.debug( - "Falling back:\n\tlocal_isinstance({})".format(args_dict) - ) - - # make sure every element is a local object on this side - if _is_bridged_object(test_object): - raise Exception( - "Can't use local_isinstance on a bridged object: {}".format(test_object) - ) - - for clazz in check_class_tuple: - if _is_bridged_object(clazz): - raise Exception( - "Can't use local_isinstance on a bridged class: {}".format(clazz) - ) - - return isinstance(test_object, check_class_tuple) - - @stats_hit - def remote_eval(self, eval_string, timeout_override=None, **kwargs): - self.logger.debug("remote_eval({}, {})".format(eval_string, kwargs)) - - command_dict = { - CMD: EVAL, - ARGS: self.serialize_to_dict({EXPR: eval_string, KWARGS: kwargs}), - } - # Remote eval commands might take a while, so override the timeout value, factor 100 is arbitrary unless an override specified by caller - if timeout_override is None: - timeout_override = self.response_timeout * 100 - result = self.send_cmd(command_dict, timeout_override=timeout_override) - - return self.deserialize_from_dict(result) - - @stats_hit - def local_eval(self, args_dict): - args = self.deserialize_from_dict(args_dict) - - result = None - - if self.logger.getEffectiveLevel() <= logging.DEBUG: - try: - # we want to get log the deserialized values, because they're useful. - # but this also means a bad repr can break things. So we get ready to - # catch that and fallback to undeserialized values - self.logger.debug("local_eval({},{})".format(args[EXPR], args[KWARGS])) - except EXCEPTION_TYPES as e: - self.logger.debug( - "Failed to log deserialized arguments: {}\n{}".format( - e, traceback.format_exc() - ) - ) - self.logger.debug("Falling back:\nlocal_eval {}".format(args_dict)) - - try: - """the import __main__ trick allows accessing all the variables that the bridge imports, - so evals will run within the global context of what started the bridge, and the arguments - supplied as kwargs will override that""" - eval_expr = args[EXPR] - eval_globals = importlib.import_module("__main__").__dict__ - eval_locals = args[KWARGS] - # do the eval, or defer to the hook if we've registered one - if self.local_eval_hook is None: - result = eval(eval_expr, eval_globals, eval_locals) - else: - result = self.local_eval_hook( - self, eval_expr, eval_globals, eval_locals - ) - self.logger.debug("local_eval: Finished evaluating") - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() - - return result - - @stats_hit - def remote_exec(self, exec_string, timeout_override=None, **kwargs): - self.logger.debug("remote_exec({}, {})".format(exec_string, kwargs)) - - command_dict = { - CMD: EXEC, - ARGS: self.serialize_to_dict({EXPR: exec_string, KWARGS: kwargs}), - } - # Remote exec commands might take a while, so override the timeout value, factor 100 is arbitrary unless an override specified by caller - if timeout_override is None: - timeout_override = self.response_timeout * 100 - result = self.send_cmd(command_dict, timeout_override=timeout_override) - - return self.deserialize_from_dict(result) - - @stats_hit - def local_exec(self, args_dict): - args = self.deserialize_from_dict(args_dict) - - result = None - - if self.logger.getEffectiveLevel() <= logging.DEBUG: - try: - # we want to get log the deserialized values, because they're useful. - # but this also means a bad repr can break things. So we get ready to - # catch that and fallback to undeserialized values - self.logger.debug("local_exec({},{})".format(args[EXPR], args[KWARGS])) - except EXCEPTION_TYPES as e: - self.logger.debug( - "Failed to log deserialized arguments: {}\n{}".format( - e, traceback.format_exc() - ) - ) - self.logger.debug("Falling back:\nlocal_exec {}".format(args_dict)) - - try: - """the import __main__ trick allows accessing all the variables that the bridge imports, - so execs will run within the global context of what started the bridge, and the arguments - supplied as kwargs will override that""" - exec_expr = args[EXPR] - exec_globals = importlib.import_module("__main__").__dict__ - # unlike remote_eval, we add the kwargs to the globals, because the most common use of remote_exec is to define a function/class, and locals aren't accessible in those definitions - exec_globals.update(args[KWARGS]) - # do the exec, or defer to the hook if we've registered one - if self.local_exec_hook is None: - exec(exec_expr, exec_globals) - else: - self.local_exec_hook(self, exec_expr, exec_globals) - self.logger.debug("local_exec: Finished executing") - except EXCEPTION_TYPES as e: - result = e - traceback.print_exc() - - return result - - def remoteify(self, module_class_or_function, **kwargs): - """Push a module, class or function definition into the remote python interpreter, and return a handle to it. - - Notes: - * requires that the class or function code is able to be understood by the remote interpreter (e.g., if it's running python2, the source must be python2 compatible) - * If remoteify-ing a class, the class can't be defined in a REPL (a limitation of inspect.getsource). You need to define it in a file somewhere. - * If remoteify-ing a module, it can't do relative imports - they require a package structure which won't exist - * If remoteify-ing a module, you only get the handle back - it's not installed into the remote or local sys.modules, you need to do that yourself. - * You can't remoteify a decorated function/class - it'll only get the source for the decorator wrapper, not the original. - """ - source_string = inspect.getsource(module_class_or_function) - name = module_class_or_function.__name__ - - # random name that'll appear in the __main__ globals to retrieve the remote definition. - # Used to avoid colliding with other uses of the name that might be there, or other clients - temp_name = "_bridge_remoteify_temp_result" + "".join( - [random.choice("0123456789ABCDEF") for _ in range(0, 8)] - ) - - if isinstance(module_class_or_function, types.ModuleType): - """Modules need a bit of extra love and care.""" - # We'll use the temp_name to store the source of the module (makes it easier than patching it into the format string below and escaping everything), - # and pass it as a global to the exec - kwargs[temp_name] = source_string - - # We create a new module context to execute the module code in, then run a second exec from - # the first exec inside the new module's __dict__, so imports are set correctly as globals of the module (not globals of the exec) - # Note that we need to force the module name to be a string - python2 doesn't support unicode module names - source_string = "import types\nnew_mod = types.ModuleType(str('{name}'))\nexec({temp_name}, new_mod.__dict__)\n".format( - name=name, temp_name=temp_name - ) - # update name to capture the new module object we've created - name = "new_mod" - - elif ( - source_string[0] in " \t" - ): # modules won't be indented, only a class/function issue - # source is indented to some level, so dedent it to avoid an indentation error - source_string = textwrap.dedent(source_string) - - retrieval_string = "\nglobals()['{temp_name}'] = {name}".format( - temp_name=temp_name, name=name - ) - - # run the exec - self.remote_exec(source_string + retrieval_string, **kwargs) - - # retrieve from __main__ with remote_eval - result = self.remote_eval(temp_name) - - # nuke the temp name - the remote handle will keep the module/class/function around - self.remote_exec( - "global {temp_name}\ndel {temp_name}".format(temp_name=temp_name) - ) - - return result - - @stats_hit - def remote_shutdown(self): - self.logger.debug("remote_shutdown") - result = self.deserialize_from_dict(self.send_cmd({CMD: SHUTDOWN})) - print(result) - if SHUTDOWN in result and result[SHUTDOWN]: - # shutdown received - as a gross hack, send a followup that we don't expect to return, to unblock some loops and actually let things shutdown - self.send_cmd({CMD: SHUTDOWN}, get_response=False) - - return result - - @stats_hit - def local_shutdown(self): - global GLOBAL_BRIDGE_SHUTDOWN - - self.logger.debug("local_shutdown") - - GLOBAL_BRIDGE_SHUTDOWN = True - - return {SHUTDOWN: True} - - def handle_command(self, message_dict, want_response=True): - response_dict = { - VERSION: COMMS_VERSION_5, - ID: message_dict[ID], - TYPE: RESULT, - RESULT: {}, - } - - command_dict = message_dict[CMD] - - if command_dict[CMD] == DEL: - self.local_del(command_dict[ARGS]) # no result required - else: - result = None - if command_dict[CMD] == GET: - result = self.local_get(command_dict[ARGS]) - elif command_dict[CMD] == SET: - result = self.local_set(command_dict[ARGS]) - elif command_dict[CMD] == CALL: - result = self.local_call(command_dict[ARGS]) - elif command_dict[CMD] == IMPORT: - result = self.local_import(command_dict[ARGS]) - elif command_dict[CMD] == TYPE: - result = self.local_get_type(command_dict[ARGS]) - elif command_dict[CMD] == CREATE_TYPE: - result = self.local_create_type(command_dict[ARGS]) - elif command_dict[CMD] == GET_ALL: - result = self.local_get_all(command_dict[ARGS]) - elif command_dict[CMD] == ISINSTANCE: - result = self.local_isinstance(command_dict[ARGS]) - elif command_dict[CMD] == EVAL: - result = self.local_eval(command_dict[ARGS]) - elif command_dict[CMD] == EXEC: - result = self.local_exec(command_dict[ARGS]) - elif command_dict[CMD] == SHUTDOWN: - result = self.local_shutdown() - - if want_response: # only serialize if we want a response - response_dict[RESULT] = self.serialize_to_dict(result) - - if want_response: - self.logger.debug("Responding with {}".format(response_dict)) - return json.dumps(response_dict).encode("utf-8") - else: - return None - - def get_bridge_type(self, bridged_obj_dict, callable=False): - # Get a dynamic bridging type from the cache based on the type name, or create it based on the type recovered from the instance bridge handle - bridge_handle = bridged_obj_dict[HANDLE] - type_name = bridged_obj_dict[TYPE] - - # short circuit - any function-like thing, as well as any type (or java.lang.Class) becomes a BridgedCallable (need to invoke types/classes, so they're callable) - if type_name in [ - "type", - "java.lang.Class", - "function", - "builtin_function_or_method", - "instancemethod", - "method_descriptor", - "wrapper_descriptor", - "reflectedfunction", # jython - e.g. jarray.zeros() - ]: - return BridgedCallable - elif type_name in ["module", "javapackage"]: - return BridgedModule - - # if we've already handled this type, use the old one - if type_name in self.cached_bridge_types: - return self.cached_bridge_types[type_name] - - self.logger.debug("Creating type " + type_name) - # need to create a type - # grab the remote type for the instance. - remote_type = self.remote_get_type(bridge_handle) - - # create the class dict by getting any of the methods we're interested in - class_dict = {} - for method_name in BRIDGED_CLASS_METHODS: - if method_name in remote_type._bridge_attrs: - class_dict[method_name] = remote_type._bridged_get(method_name) - - # handle a python2/3 compatibility issue - 3 uses truediv for /, 2 uses div unless you've imported - # __future__.division. Allow falling back to __div__ if __truediv__ requested but not present - if ( - "__div__" in remote_type._bridge_attrs - and "__truediv__" not in remote_type._bridge_attrs - ): - class_dict["__truediv__"] = remote_type._bridged_get("__div__") - - # create the bases - any class level method which requires special implementation needs to add the relevant type - bases = (BridgedObject,) - - if callable: - bases = (BridgedCallable,) - elif ( - "__next__" in remote_type._bridge_attrs - or "next" in remote_type._bridge_attrs - ): - bases = (BridgedIterator,) - - local_type = type( - str("_bridged_" + type_name), bases, class_dict - ) # str to force it to non-unicode in py2 (is unicode thanks to unicode_literals) - self.cached_bridge_types[type_name] = local_type - - return local_type - - def build_bridged_object(self, obj_dict, callable=False): - # construct a bridgedobject, including getting/creating a local dynamic type for its type - bridge_type = self.get_bridge_type(obj_dict, callable=callable) - - return bridge_type(self, obj_dict) - - def get_stats(self): - """Get a copy of the statistics accumulated in the run of this connection so far. Requires that __init__ was called with - record_stats=True - """ - stats = None - if self.stats is not None: - stats = self.stats.copy() - - return stats - - @stats_hit - def add_response(self, msg_dict): - # Just a wrapper to allow us to record this stat - self.response_mgr.add_response(msg_dict) - - -class BridgeServer( - threading.Thread -): # TODO - have BridgeServer and BridgeClient share a class - """Python2Python RPC bridge server - - Like a thread, so call run() to run directly, or start() to run on a background thread - """ - - is_serving = False - local_call_hook = None - local_eval_hook = None - local_exec_hook = None - - def __init__( - self, - server_host=DEFAULT_HOST, - server_port=0, - loglevel=None, - response_timeout=DEFAULT_RESPONSE_TIMEOUT, - local_call_hook=None, - local_eval_hook=None, - local_exec_hook=None, - ): - """Set up the bridge. - - server_host/port: host/port to listen on to serve requests. If not specified, defaults to 127.0.0.1:0 (random port - use get_server_info() to find out where it's serving) - loglevel - what messages to log - response_timeout - how long to wait for a response before throwing an exception, in seconds - """ - global GLOBAL_BRIDGE_SHUTDOWN - - super(BridgeServer, self).__init__() - - # init the server - self.server = ThreadingTCPServer( - (server_host, server_port), BridgeCommandHandler - ) - # the server needs to be able to get back to the bridge to handle commands, but we don't want that reference keeping the bridge alive - self.server.bridge = weakref.proxy(self) - self.server.timeout = 1 - self.daemon = True - - logging.basicConfig() - self.logger = logging.getLogger(__name__) - if loglevel is None: # we don't want any logging - ignore everything - loglevel = logging.CRITICAL + 1 - - self.logger.setLevel(loglevel) - self.response_timeout = response_timeout - - # if we're starting the server, we need to make sure the flag is set to false - GLOBAL_BRIDGE_SHUTDOWN = False - - # specify a callable to local_call_hook(bridge_conn, target_callable, *args, **kwargs) or - # local_eval_hook(bridge_conn, eval_expression, eval_globals_dict, eval_locals_dict) to - # hook local_call/local_eval to allow inspection/modification of calls/evals (e.g., forcing them onto a particular thread) - self.local_call_hook = local_call_hook - self.local_eval_hook = local_eval_hook - self.local_exec_hook = local_exec_hook - - def get_server_info(self): - """return where the server is serving on""" - return self.server.socket.getsockname() - - def run(self): - self.logger.info( - "serving! (jfx_bridge v{}, Python {}.{}.{})".format( - __version__, - sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro, - ) - ) - self.is_serving = True - self.server.serve_forever() - self.logger.info("stopped serving") - - def __del__(self): - self.shutdown() - - def shutdown(self): - if self.is_serving: - self.logger.info("Shutting down bridge") - self.is_serving = False - self.server.shutdown() - self.server.server_close() - - -class BridgeClient(object): - """Python2Python RPC bridge client""" - - local_call_hook = None - local_eval_hook = None - local_exec_hook = None - _bridge = None - - def __init__( - self, - connect_to_host=DEFAULT_HOST, - connect_to_port=DEFAULT_SERVER_PORT, - loglevel=None, - response_timeout=DEFAULT_RESPONSE_TIMEOUT, - hook_import=False, - record_stats=False, - ): - """Set up the bridge client - connect_to_host/port - host/port to connect to run commands. - loglevel - what messages to log (e.g., logging.INFO, logging.DEBUG) - response_timeout - how long to wait for a response before throwing an error, in seconds - hook_import - set to True to add a hook to the import system to allowing importing remote modules - """ - logging.basicConfig() - self.logger = logging.getLogger(__name__) - if loglevel is None: # we don't want any logging - ignore everything - loglevel = logging.CRITICAL + 1 - - self.logger.setLevel(loglevel) - - self.client = BridgeConn( - self, - sock=None, - connect_to_host=connect_to_host, - connect_to_port=connect_to_port, - response_timeout=response_timeout, - record_stats=record_stats, - ) - - if hook_import: - # add a path_hook for this bridge - sys.path_hooks.append(BridgedModuleFinderLoader(self).path_hook_fn) - # add an entry for this bridge client's bridge connection to the paths. - # We add it at the end, so we only catch imports that no one else wants to handle - sys.path.append(repr(self.client)) - # TODO make sure we remove the finder when the client is torn down? - - self._bridge = self - - @property - def bridge(self): - """for backwards compatibility with old examples using external_bridge.bridge.remote_import/etc, - before the external bridges just inherited from BridgeClient - Allow access, but warn about it - """ - warnings.warn( - "Using .bridge to get to remote_import/eval/shutdown is deprecated - just do .remote_import/etc.", - DeprecationWarning, - ) - return self._bridge - - def remote_import(self, module_name): - return self.client.remote_import(module_name) - - def remote_eval(self, eval_string, timeout_override=None, **kwargs): - """ - Takes an expression as an argument and evaluates it entirely on the server. - Example: b.bridge.remote_eval('[ f.name for f in currentProgram.functionManager.getFunctions(True)]') - If this expression would be evaluated on the client, it would take 2-3 minutes for a binary with ~8k functions due to ~8k roundtrips to call __next__ and ~8k roundtrips to access the name attribute - - Caveats: - - The expression `[ f for f in currentProgram.functionManager.getFunctions(True)]` still takes roughly a 1 minute to finish. Almost the entire time is spent sending the message to the client. This issue requires a deeper change in the RPC implementation to increase throughput or reduce message size - - To provide arguments into the eval context, supply them as keyword arguments with names matching the names used in the eval string (e.g., remote_eval("x+1", x=2)) - """ - return self.client.remote_eval( - eval_string, timeout_override=timeout_override, **kwargs - ) - - def remote_exec(self, exec_string, timeout_override=None, **kwargs): - """Takes python script as a string and executes it entirely on the server. - - To provide arguments into the exec context, supply them as keyword arguments with names matching the names used in the exec string (e.g., remote_exec("print(x)", x="helloworld")). - - Note: the python script must be able to be understood by the remote interpreter (e.g., if it's running python2, the script must be python2 compatible) - """ - return self.client.remote_exec( - exec_string, timeout_override=timeout_override, **kwargs - ) - - def remoteify(self, module_class_or_function, **kwargs): - """Push a module, class or function definition into the remote python interpreter, and return a handle to it. - - Notes: - * requires that the class or function code is able to be understood by the remote interpreter (e.g., if it's running python2, the source must be python2 compatible) - * If remoteify-ing a class, the class can't be defined in a REPL (a limitation of inspect.getsource). You need to define it in a file somewhere. - * If remoteify-ing a module, it can't do relative imports - they require a package structure which won't exist - * If remoteify-ing a module, you only get the handle back - it's not installed into the remote or local sys.modules, you need to do that yourself. - * You can't remoteify a decorated function/class - it'll only get the source for the decorator wrapper, not the original. - """ - return self.client.remoteify(module_class_or_function, **kwargs) - - def remote_shutdown(self): - return self.client.remote_shutdown() - - def get_stats(self): - """Get the statistics recorded across the run of this BridgeClient""" - return self.client.get_stats() - - -def _is_bridged_object(object): - """Utility function to detect if an object is bridged or not. - - Not recommended for use outside this class, because it breaks the goal that you shouldn't - need to know if something is bridged or not - """ - return hasattr(object, "_bridge_type") - - -def bridged_isinstance(test_object, class_or_tuple): - """Utility function to wrap isinstance to handle bridged objects. Behaves as isinstance, but if all the objects/classes - are bridged, will direct the call over the bridge. - - Currently, don't have a good way of handling a mix of bridge/non-bridge, so will just return false - """ - # make sure we have the real isinstance, just in case we've overridden it (e.g., with ghidra_bridge namespace) - builtin_isinstance = None - try: - from builtins import isinstance as builtin_isinstance # python3 - except: - # try falling back to python2 syntax - from __builtin__ import isinstance as builtin_isinstance - - result = False - - # force class_or_tuple to be a tuple - just easier that way - if _is_bridged_object(class_or_tuple): - # bridged object, so not a tuple - class_or_tuple = (class_or_tuple,) - if not builtin_isinstance(class_or_tuple, tuple): - # local clazz, not a tuple - class_or_tuple = (class_or_tuple,) - - # now is the test_object bridged or not? - if _is_bridged_object(test_object): - # yes - we need to handle. - # remove any non-bridged classes in the tuple - new_tuple = tuple( - clazz for clazz in class_or_tuple if _is_bridged_object(clazz) - ) - - if ( - new_tuple - ): # make sure there's still some things left to check - otherwise, just return false without shooting it over the bridge - result = test_object._bridge_isinstance(new_tuple) - else: - # test_object isn't bridged - remove any bridged classes in the tuple and palm it off to isinstance - new_tuple = tuple( - clazz for clazz in class_or_tuple if not _is_bridged_object(clazz) - ) - - result = builtin_isinstance(test_object, new_tuple) - - return result - - -class BridgedObject(object): - """An object you can only interact with on the opposite side of a bridge""" - - _bridge_conn = None - _bridge_handle = None - _bridge_type = None - _bridge_attrs = None - # overrides allow you to make changes just in the local bridge object, not against the remote object (e.g., to avoid conflicts with interactive fixups to the remote __main__) - _bridge_overrides = None - - # list of methods which we don't bridge, but need to have specific names (so we can't use the _bridge prefix for them) - # TODO decorator to mark a function as local, don't bridge it - then have it automatically fill this out (also needs to work for subclasses) - _LOCAL_METHODS = [ - "__del__", - "__str__", - "__repr__", - "__dir__", - "__bool__", - "__nonzero__", - "getdoc", - ] - - # list of attrs that we don't want to waste bridge calls on - _DONT_BRIDGE = [ - "__mro_entries__", # ignore mro entries - only being called if we're creating a class based off a bridged object - # associated with ipython - "_ipython_canary_method_should_not_exist_", - "__sizeof__", - ] - - # list of attrs that we don't want to waste bridge calls on, unless they really are defined in the bridged object - _DONT_BRIDGE_UNLESS_IN_ATTRS = [ - # associated with ipython - "_repr_mimebundle_", - "__init_subclass__", - # javapackage objects (like the ghidra module) don't have a __delattr__ - "__delattr__", - # for fmagin's ipyghidra - "__signature__", - "__annotations__", - "__objclass__", - "__wrapped__", - ] - - def __init__(self, bridge_conn, obj_dict): - self._bridge_conn = bridge_conn - self._bridge_handle = obj_dict[HANDLE] - self._bridge_type = obj_dict[TYPE] - self._bridge_attrs = obj_dict[ATTRS] - self._bridge_repr = obj_dict[REPR] - self._bridge_overrides = dict() - - def __getattribute__(self, attr): - if ( - attr.startswith(BRIDGE_PREFIX) - or attr == "__class__" - or attr in BridgedObject._DONT_BRIDGE - or attr in BridgedObject._LOCAL_METHODS - or ( - attr in BridgedObject._DONT_BRIDGE_UNLESS_IN_ATTRS - and attr not in self._bridge_attrs - ) - ): - # we don't want to bridge this for one reason or another (including it may not exist on the other end), - # so get the local version, or accept the AttributeError that we'll get if it's not present locally. - result = object.__getattribute__(self, attr) - else: - try: - result = self._bridged_get(attr) - except BridgeException as be: - # unwrap AttributeErrors if they occurred on the other side of the bridge - if be.args[1]._bridge_type.endswith("AttributeError"): - raise AttributeError(be.args[0]) - else: - # some other cause - just reraise the exception - raise - - return result - - def __setattr__(self, attr, value): - if attr.startswith(BRIDGE_PREFIX): - object.__setattr__(self, attr, value) - else: - self._bridged_set(attr, value) - - def _bridged_get(self, name): - if name in self._bridge_overrides: - return self._bridge_overrides[name] - - return self._bridge_conn.remote_get(self._bridge_handle, name) - - def _bridged_get_all(self): - """As an optimisation, get all of the attributes at once and store them as overrides. - - Should only use this for objects that are unlikely to have their attributes change values (e.g., imported modules), - otherwise you won't be able to get the updated values without clearing the override - """ - attrs_dict = self._bridge_conn.remote_get_all(self._bridge_handle) - - # the result is a dictionary of attributes and their bridged objects. set them as overrides in the bridged object - for name, value in attrs_dict.items(): - self._bridge_set_override(name, value) - - def _bridged_set(self, name, value): - if name in self._bridge_overrides: - self._bridge_overrides[name] = value - else: - self._bridge_conn.remote_set(self._bridge_handle, name, value) - - def _bridged_get_type(self): - """Get a bridged object representing the type of this object""" - return self._bridge_conn.remote_get_type(self._bridge_handle) - - def _bridge_set_override(self, name, value): - self._bridge_overrides[name] = value - - def _bridge_clear_override(self, name): - del self._bridge_overrides[name] - - def _bridge_isinstance(self, bridged_class_or_tuple): - """check whether this object is an instance of the bridged class (or tuple of bridged classes)""" - # enforce that the bridged_class_or_tuple elements are actually bridged - if not _is_bridged_object(bridged_class_or_tuple): - # might be a tuple - if isinstance(bridged_class_or_tuple, tuple): - # check all the elements of the tuple - for clazz in bridged_class_or_tuple: - if not _is_bridged_object(clazz): - raise Exception( - "Can't use _bridge_isinstance with non-bridged class {}".format( - clazz - ) - ) - else: - # nope :x - raise Exception( - "Can't use _bridge_isinstance with non-bridged class {}".format( - bridged_class_or_tuple - ) - ) - - # cool, arguments are valid - return self._bridge_conn.remote_isinstance(self, bridged_class_or_tuple) - - def __del__(self): - if ( - self._bridge_conn is not None - ): # only need to del if this was properly init'd - self._bridge_conn.remote_del(self._bridge_handle) - - def __repr__(self): - return "<{}('{}', type={}, handle={})>".format( - type(self).__name__, - self._bridge_repr, - self._bridge_type, - self._bridge_handle, - ) - - def __dir__(self): - return dir(super(type(self))) + ( - self._bridge_attrs if self._bridge_attrs else [] - ) - - def __bool__(self): - # py3 vs 2 - __bool__ vs __nonzero__ - return self._bridge_conn.remote_eval("bool(x)", x=self) - - __nonzero__ = __bool__ # handle being run in a py2 environment - - -class BridgedCallable(BridgedObject): - # TODO can we further make BridgedClass a subclass of BridgedCallable? How can we detect? Allow us to pull this class/type hack further away from normal calls - def __new__(cls, bridge_conn, obj_dict, class_init=None): - """BridgedCallables can also be classes, which means they might be used as base classes for other classes. If this happens, - you'll essentially get BridgedCallable.__new__ being called with 4 arguments to create the new class - (instead of 3, for an instance of BridgedCallable). - - We handle this by creating the class remotely, and returning the BridgedCallable to that remote class. Note that the class methods - (including __init__) will be bridged on the remote end, back to us. - - TODO: note sure what might happen if you define __new__ in a class that has a BridgedCallable as the base class - """ - if class_init is None: - # instance __new__ - return super(BridgedCallable, cls).__new__(cls) - else: - # want to create a class that's based off the remote class represented by a BridgedCallable (in the bases) - # [Assumption: BridgedCallable base always first? Not sure what would happen if you had multiple inheritance] - # ignore cls, it's just BridgedCallable - # name is the name we want to call the class - name = bridge_conn - # bases are what the class inherits from. Assuming the first one is the BridgedCallable - bases = obj_dict - # dct is the class dictionary - dct = class_init - assert isinstance(bases[0], BridgedCallable) - # create the class remotely, and return the BridgedCallable back to it - return bases[0]._bridge_conn.remote_create_type(name, bases, dct) - - def __init__(self, bridge_conn, obj_dict, class_init=None): - """As with __new__, __init__ may be called as part of a class creation, not just an instance of BridgedCallable. We just ignore that case""" - if class_init is None: - super(BridgedCallable, self).__init__(bridge_conn, obj_dict) - if "_bridge_nonreturn" in self._bridge_attrs: - # if the attribute is present (even if set to False/None), assume it's nonreturning. Shouldn't be present on anything else - self._bridge_nonreturn = True - - def __call__(self, *args, **kwargs): - # if we've marked this callable with _bridge_nonreturn, don't wait for a response - if getattr(self, "_bridge_nonreturn", False): - return self._bridge_call_nonreturn(*args, **kwargs) - - return self._bridge_conn.remote_call(self._bridge_handle, *args, **kwargs) - - def _bridge_call_nonreturn(self, *args, **kwargs): - """Explicitly invoke the call without expecting a response""" - return self._bridge_conn.remote_call_nonreturn( - self._bridge_handle, *args, **kwargs - ) - - def __get__(self, instance, owner): - """Implement descriptor get so that we can bind the BridgedCallable to an object if it's defined as part of a class - Use functools.partial to return a wrapper to the BridgedCallable with the instance object as the first arg - """ - return functools.partial(self, instance) - - -class BridgedIterator(BridgedObject): - def __next__(self): - # py2 vs 3 - next vs __next__ - try: - return self._bridged_get( - "__next__" if "__next__" in self._bridge_attrs else "next" - )() - except BridgeException as e: - # we expect the StopIteration exception - check to see if that's what we got, and if so, raise locally - if e.args[1]._bridge_type.endswith("StopIteration"): - raise StopIteration - # otherwise, something went bad - reraise - raise - - next = __next__ # handle being run in a py2 environment - - -class BridgedModule(BridgedObject): - """Represent a remote module (or javapackage) to allow for doing normal imports""" - - def __init__(self, bridge_conn, obj_dict): - BridgedObject.__init__(self, bridge_conn, obj_dict) - # python3 needs __path__ set (to anything) to treat a module as a package for doing "from foo.bar import flam" - # we mark the __path__ as the bridge_conn to allow easier detection of the package as a bridged one we might be responsible for - # strictly speaking, only packages need __path__, but we set it for modules as well so that we don't get heaps of errors on the server - # side when the import machinery tries to get __path__ for them - self._bridge_set_override("__path__", [repr(bridge_conn)]) - # allow a spec to be set. javapackages resist having attributes added, so we handle it here - self._bridge_set_override("__spec__", None) - - -class BridgedModuleFinderLoader: - """Add to sys.meta_path - returns itself if it can find a remote module to satisfy the import - - Note: position in sys.meta_path is important - you almost certainly want to add it to the end. Adding it at the start - could have it say it can load everything, and imports of local modules will instead be filled with remote modules - """ - - def __init__(self, bridge_client): - """Record the bridge client to use for remote importing""" - self.bridge_client = bridge_client - - def path_hook_fn(self, path): - """Called when the import machinery runs over path_hooks - returns itself as a finder if its this bridge connection""" - if path == repr(self.bridge_client.client): - return self - # not us, don't play along - raise ImportError() - - def find_module(self, fullname, path=None): - """called by import machinery - fullname is the dotted module name to load. If the module is part of a package, __path__ is from - the parent package - """ - if path is not None: - if repr(self.bridge_client.client) in path: - # this is coming from a bridged package in our bridge - return self - # parent isn't bridged, or is bridged but isn't from our bridge - we can't do anything with this - return None - - # package/module with no parent. See if it exists on the other side before we get excited - try: - self.bridge_client.remote_import(fullname) - # got something back, so yeah, we can fill it - return self - except BridgeException as be: - exception_type = be.args[1]._bridge_type - if exception_type.endswith( - "ModuleNotFoundError" - ) or exception_type.endswith("ImportError"): - # ModuleNotFoundError in py3, just ImportError in py2 - # module doesn't exist remotely, we can't help - return None - return None - else: - # something else went wrong with the bridge - reraise the exception so the user can deal with it - raise be - - def load_module(self, fullname): - """Called by import machinery - fullname is the dotted module name to load""" - # if the module is already loaded, just give that back - if fullname in sys.modules: - return sys.modules[fullname] - - # get the remote module - target = self.bridge_client.remote_import(fullname) - - # split out the name so we know - components = fullname.rsplit(".", 1) - parent = components[0] - if len(components) > 1: - child = components[1] - # set the child as an override on the parent, so the importlib machinery can set it as an attribute without stuffing up - needed for javapackage - if parent in sys.modules: - sys.modules[parent]._bridge_set_override(child, None) - - # set some import machinery fields - target._bridge_set_override("__loader__", self) - target._bridge_set_override("__package__", parent) - # ensure we have an override set on __spec__ for everything, including non-modules (e.g., BridgedCallables on java classes) - # otherwise, __spec__ gets set by import machinery later, leading to a client handle being pushed into the server, where other - # clients might get it if they import the same module - # TODO probably need to check there's nothing else being set against the modules? Or is there a way to reload modules for each new client? - target._bridge_set_override("__spec__", None) - - # add the module to sys.modules - sys.modules[fullname] = target - - # hand back the module - return target - - -def nonreturn(func): - """Decorator to simplying marking a function as nonreturning for the bridge""" - func._bridge_nonreturn = True - return func diff --git a/libbs/decompilers/ghidra/README.md b/libbs/decompilers/ghidra/README.md deleted file mode 100644 index 6e37a17b..00000000 --- a/libbs/decompilers/ghidra/README.md +++ /dev/null @@ -1,8 +0,0 @@ -## Installation -To install our Ghidra backend you need to do the following steps: -1. Install BinSync with the extra Ghidra dependencies: `pip install binsync[ghidra]` -2. Install the BinSync python stubs: `binsync --install` -3. Open a binary in Ghidra, then open `Window -> Script Manager` tab -4. Click on the left hand side `BinSync`, then click the check marks next to both scripts in the folder - -BinSync is now under your `Tools -> BinSync`. Use the `Start BinSync` to connect and go! \ No newline at end of file diff --git a/libbs/decompilers/ghidra/compat/bridge.py b/libbs/decompilers/ghidra/compat/bridge.py deleted file mode 100644 index ce3ae560..00000000 --- a/libbs/decompilers/ghidra/compat/bridge.py +++ /dev/null @@ -1,117 +0,0 @@ -import importlib -import inspect -import logging -import re -import time -from functools import wraps -import typing -from typing import Optional - -import ghidra_bridge - -if typing.TYPE_CHECKING: - from ..interface import GhidraDecompilerInterface - -_l = logging.getLogger(name=__name__) - - -def connect_to_bridge(connection_timeout=20) -> Optional[ghidra_bridge.GhidraBridge]: - start_time = time.time() - bridge = None - while time.time() - start_time < connection_timeout: - try: - bridge = ghidra_bridge.GhidraBridge( - namespace=globals(), interactive_mode=True, response_timeout=30 - ) - except ConnectionError as e: - _l.debug("Failed to connect to GhidraBridge: %s", e) - time.sleep(1) - - if bridge is not None: - break - - return bridge - - -def shutdown_bridge(bridge: ghidra_bridge.GhidraBridge): - if bridge is None: - return False - - return bool(bridge.remote_shutdown()) - - -def _ping_bridge(bridge: ghidra_bridge.GhidraBridge) -> bool: - connected = False - if bridge is not None: - try: - bridge.remote_eval("True") - connected = True - except Exception: - pass - - return connected - - -def is_bridge_alive(bridge: ghidra_bridge.GhidraBridge) -> bool: - return _ping_bridge(bridge) - - -class FlatAPIWrapper: - def __getattr__(self, name): - g = globals() - if name in g: - return g[name] - else: - raise AttributeError(f"No global import named {name}") - - -def ui_remote_eval(f): - @wraps(f) - def _ui_remote_eval(self: "GhidraDecompilerInterface", *args, **kwargs): - # exit early, no analysis needed - if self.headless: - return f(self, *args, **kwargs) - - # extract every argument name from the function signature - code_args = list(inspect.getfullargspec(f).args)[1:len(args)+1] - args_by_name = { - arg: val for arg, val in zip(code_args, args) - } - args_by_name["_self"] = self - args_by_name.update(kwargs) - - # update the code that uses self to use the _self variable - f_code = inspect.getsource(f) - f_code = f_code.replace("self.", "_self.") - - # extract all (from * imports) with a regex, and import them - import_pairs = re.findall("from (.*?) import (.*?)\n", f_code) - imported_objs = {} - for module, objs in import_pairs: - module_obj = importlib.import_module( - module, package="libbs.decompilers.ghidra" if module.startswith(".") else None - ) - for obj in objs.split(","): - obj_name = obj.strip() - imported_objs[obj_name] = getattr(module_obj, obj_name) - - namespace = args_by_name - namespace.update(imported_objs) - - # extract the remote code - remote_codes = re.findall(r"return (\[.*])", f_code.replace("\n", " ")) - if len(remote_codes) != 1: - raise ValueError(f"Failed to extract remote code from function {f}! This must be a bug in writing.") - - remote_code = remote_codes[0] - try: - val = self._bridge.remote_eval(remote_code, **namespace) - except Exception as e: - self.error(f"Failed to evaluate remote code: {remote_code}") - val = [] - - return val - - return _ui_remote_eval - - diff --git a/libbs/decompilers/ghidra/compat/headless.py b/libbs/decompilers/ghidra/compat/headless.py index a05bd1a3..c81aaf99 100644 --- a/libbs/decompilers/ghidra/compat/headless.py +++ b/libbs/decompilers/ghidra/compat/headless.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Union, Optional, Tuple -from pyhidra.core import _analyze_program, _get_language, _get_compiler_spec +from pyghidra.core import _analyze_program, _get_language, _get_compiler_spec from jpype import JClass _l = logging.getLogger(__name__) @@ -22,12 +22,12 @@ def open_program( Taken from Pyhidra, but updated to also return the project associated with the program: https://github.com/dod-cyber-crime-center/pyhidra/blob/c878e91b53498f65f2eb0255e22189a6d172917c/pyhidra/core.py#L178 """ - from pyhidra.launcher import PyhidraLauncher, HeadlessPyhidraLauncher + from pyghidra.launcher import PyGhidraLauncher, HeadlessPyGhidraLauncher if binary_path is None and project_location is None: raise ValueError("You must provide either a binary path or a project location.") - if not PyhidraLauncher.has_launched(): - HeadlessPyhidraLauncher().start() + if not PyGhidraLauncher.has_launched(): + HeadlessPyGhidraLauncher().start() from ghidra.app.script import GhidraScriptUtil from ghidra.program.flatapi import FlatProgramAPI @@ -58,6 +58,7 @@ def _setup_project( loader: Union[str, JClass] = None ) -> Tuple["GhidraProject", "Program"]: from ghidra.base.project import GhidraProject + from ghidra.util.exception import NotFoundException from java.lang import ClassLoader from java.io import IOException @@ -70,6 +71,8 @@ def _setup_project( if not project_name: project_name = f"{binary_path.name}_ghidra" project_location /= project_name + + # Ensure the project location directory exists project_location.mkdir(exist_ok=True, parents=True) if isinstance(loader, str): @@ -95,7 +98,7 @@ def _setup_project( program_name = binary_path.name if project.getRootFolder().getFile(program_name): program = project.openProgram("/", program_name, False) - except IOException: + except (IOException, NotFoundException): project = GhidraProject.createProject(project_location, project_name, False) # NOTE: GhidraProject.importProgram behaves differently when a loader is provided diff --git a/libbs/decompilers/ghidra/compat/imports.py b/libbs/decompilers/ghidra/compat/imports.py index 705aec9c..68580dc6 100644 --- a/libbs/decompilers/ghidra/compat/imports.py +++ b/libbs/decompilers/ghidra/compat/imports.py @@ -1,12 +1,7 @@ import logging -from typing import Tuple, Iterable _l = logging.getLogger(__name__) -from ..interface import bridge -bridge = bridge or globals().get("binsync_ghidra_bridge", None) -HEADLESS = bridge is None - def get_private_class(path: str): from java.lang import ClassLoader @@ -15,66 +10,32 @@ def get_private_class(path: str): gcl = ClassLoader.getSystemClassLoader() return JClass(path, loader=gcl) +from ghidra.framework.model import DomainObjectListener +from ghidra.program.model.symbol import SourceType, SymbolType +from ghidra.program.model.pcode import HighFunctionDBUtil +from ghidra.program.model.data import ( + DataTypeConflictHandler, StructureDataType, ByteDataType, EnumDataType, CategoryPath, TypedefDataType +) +from ghidra.program.util import ChangeManager, ProgramChangeRecord, FunctionChangeRecord +from ghidra.program.database.function import VariableDB, FunctionDB +from ghidra.program.database.symbol import CodeSymbol, FunctionSymbol +from ghidra.program.model.listing import CodeUnit +from ghidra.app.cmd.comments import SetCommentCmd +from ghidra.app.cmd.label import RenameLabelCmd +from ghidra.app.context import ProgramLocationContextAction, ProgramLocationActionContext +from ghidra.app.decompiler import DecompInterface +from ghidra.app.plugin.core.analysis import AutoAnalysisManager +from ghidra.app.util.cparser.C import CParserUtils +from ghidra.app.decompiler import PrettyPrinter +from ghidra.util.task import ConsoleTaskMonitor +from ghidra.util.data import DataTypeParser +from ghidra.util.exception import CancelledException +from docking.action import MenuData +from docking.action.builder import ActionBuilder -def import_objs(path: str, objs: Iterable[str]): - module = bridge.remote_import(path) - new_objs = [getattr(module, obj) for obj in objs] - return new_objs if len(new_objs) > 1 else new_objs[0] - - -if HEADLESS: - from ghidra.framework.model import DomainObjectListener - from ghidra.program.model.symbol import SourceType, SymbolType - from ghidra.program.model.pcode import HighFunctionDBUtil - from ghidra.program.model.data import ( - DataTypeConflictHandler, StructureDataType, ByteDataType, EnumDataType, CategoryPath, TypedefDataType - ) - from ghidra.program.util import ChangeManager, ProgramChangeRecord, FunctionChangeRecord - from ghidra.program.database.function import VariableDB, FunctionDB - from ghidra.program.database.symbol import CodeSymbol, FunctionSymbol - from ghidra.program.model.listing import CodeUnit - from ghidra.app.cmd.comments import SetCommentCmd - from ghidra.app.cmd.label import RenameLabelCmd - from ghidra.app.context import ProgramLocationContextAction - from ghidra.app.decompiler import DecompInterface - from ghidra.app.plugin.core.analysis import AutoAnalysisManager - from ghidra.app.util.cparser.C import CParserUtils - from ghidra.app.decompiler import PrettyPrinter - from ghidra.util.task import ConsoleTaskMonitor - from ghidra.util.data import DataTypeParser - from ghidra.util.exception import CancelledException - from docking.action import MenuData - - EnumDB = get_private_class("ghidra.program.database.data.EnumDB") - StructureDB = get_private_class("ghidra.program.database.data.StructureDB") - TypedefDB = get_private_class("ghidra.program.database.data.TypedefDB") -else: - DomainObjectListener = import_objs("ghidra.framework.model", ("DomainObjectListener",)) - SourceType, SymbolType = import_objs("ghidra.program.model.symbol", ("SourceType", "SymbolType")) - HighFunctionDBUtil = import_objs("ghidra.program.model.pcode", ("HighFunctionDBUtil",)) - DataTypeConflictHandler, StructureDataType, ByteDataType, EnumDataType, CategoryPath, TypedefDataType = import_objs( - "ghidra.program.model.data", - ("DataTypeConflictHandler", "StructureDataType", "ByteDataType", "EnumDataType", "CategoryPath", "TypedefDataType") - ) - ChangeManager, ProgramChangeRecord, FunctionChangeRecord = import_objs("ghidra.program.util", ("ChangeManager", "ProgramChangeRecord", "FunctionChangeRecord")) - VariableDB, FunctionDB = import_objs("ghidra.program.database.function", ("VariableDB", "FunctionDB")) - CodeSymbol, FunctionSymbol = import_objs("ghidra.program.database.symbol", ("CodeSymbol", "FunctionSymbol")) - CodeUnit = import_objs("ghidra.program.model.listing", ("CodeUnit",)) - SetCommentCmd = import_objs("ghidra.app.cmd.comments", ("SetCommentCmd",)) - RenameLabelCmd = import_objs("ghidra.app.cmd.label", ("RenameLabelCmd",)) - ProgramLocationContextAction = import_objs("ghidra.app.context", ("ProgramLocationContextAction",)) - DecompInterface = import_objs("ghidra.app.decompiler", ("DecompInterface",)) - AutoAnalysisManager = import_objs("ghidra.app.plugin.core.analysis", ("AutoAnalysisManager",)) - CParserUtils = import_objs("ghidra.app.util.cparser.C", ("CParserUtils",)) - PrettyPrinter = import_objs("ghidra.app.decompiler", ("PrettyPrinter",)) - ConsoleTaskMonitor = import_objs("ghidra.util.task", ("ConsoleTaskMonitor",)) - DataTypeParser = import_objs("ghidra.util.data", ("DataTypeParser",)) - CancelledException = import_objs("ghidra.util.exception", ("CancelledException",)) - MenuData = import_objs("docking.action", ("MenuData",)) - EnumDB = import_objs("ghidra.program.database.data", ("EnumDB",)) - StructureDB = import_objs("ghidra.program.database.data", ("StructureDB",)) - TypedefDB = import_objs("ghidra.program.database.data", ("TypedefDB",)) - +EnumDB = get_private_class("ghidra.program.database.data.EnumDB") +StructureDB = get_private_class("ghidra.program.database.data.StructureDB") +TypedefDB = get_private_class("ghidra.program.database.data.TypedefDB") __all__ = [ # forcefully imported objects @@ -88,7 +49,9 @@ def import_objs(path: str, objs: Iterable[str]): "CodeSymbol", "FunctionSymbol", "ProgramLocationContextAction", + "ProgramLocationActionContext", "MenuData", + "ActionBuilder", "HighFunctionDBUtil", "DataTypeConflictHandler", "StructureDataType", diff --git a/libbs/decompilers/ghidra/compat/state.py b/libbs/decompilers/ghidra/compat/state.py new file mode 100644 index 00000000..cab98bb3 --- /dev/null +++ b/libbs/decompilers/ghidra/compat/state.py @@ -0,0 +1,54 @@ +import logging + +_l = logging.getLogger(__name__) + + +def _get_python_plugin(flat_api=None): + if flat_api is not None: + state = flat_api.getState() + else: + _l.warning("Using internal ghidra functions without a distinct FlatAPI is likely dangerous!") + # assume it must be either in the globals or __this__ object, but this will likley crash if we are here + gvs = dict(globals()) + state = gvs.get("getState", None) or gvs.get("__this__", None).getState + + tool = state.getTool() + api = None + if tool is not None: + for plugin in state.getTool().getManagedPlugins(): + if plugin.name == "PyGhidraPlugin": + api = plugin + break + else: + raise RuntimeError("PyGhidraPlugin not found") + else: + # This is s special case: semi-headless + # we started ghidra with something like pyhidra.run_script, which causes us to run the current instance + # as if it were a script, not a single service inside ghidra + api = state + + return api + + +def _in_headless_mode(flat_api): + return flat_api is not None and not hasattr(flat_api, "getState") + +# +# Public API for interacting with the Ghidra state +# + + +def get_current_program(flat_api=None) -> "ProgramDB": + api = _get_python_plugin(flat_api=flat_api) if not _in_headless_mode(flat_api) else flat_api + return api.getCurrentProgram() + + +def get_current_address(flat_api=None) -> int: + if _in_headless_mode(flat_api): + raise RuntimeError("Cannot get current address in headless mode") + + addr = _get_python_plugin(flat_api=flat_api).getProgramLocation().getAddress().offset + if addr is not None: + addr = int(addr) + + return addr \ No newline at end of file diff --git a/libbs/decompilers/ghidra/hooks.py b/libbs/decompilers/ghidra/hooks.py index 9d040bc1..017504d7 100644 --- a/libbs/decompilers/ghidra/hooks.py +++ b/libbs/decompilers/ghidra/hooks.py @@ -9,168 +9,115 @@ from libbs.decompilers.ghidra.interface import GhidraDecompilerInterface _l = logging.getLogger(__name__) +from .compat.imports import ( + DomainObjectListener, ChangeManager, ProgramChangeRecord, VariableDB, FunctionDB, CodeSymbol, + FunctionSymbol, FunctionChangeRecord +) +from jpype import JImplements, JOverride + + +@JImplements(DomainObjectListener, deferred=False) +class DataMonitor: + @JOverride + def __init__(self, deci: "GhidraDecompilerInterface"): + self._deci = deci + # Init event lists + self.funcEvents = { + ChangeManager.DOCR_FUNCTION_CHANGED, + ChangeManager.DOCR_FUNCTION_BODY_CHANGED, + ChangeManager.DOCR_VARIABLE_REFERENCE_ADDED, + ChangeManager.DOCR_VARIABLE_REFERENCE_REMOVED + } + + self.symDelEvents = { + ChangeManager.DOCR_SYMBOL_REMOVED + } + + self.symChgEvents = { + ChangeManager.DOCR_SYMBOL_ADDED, + ChangeManager.DOCR_SYMBOL_RENAMED, + ChangeManager.DOCR_SYMBOL_DATA_CHANGED + } + + self.typeEvents = { + ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED, + ChangeManager.DOCR_DATA_TYPE_CHANGED, + ChangeManager.DOCR_DATA_TYPE_REPLACED, + ChangeManager.DOCR_DATA_TYPE_RENAMED, + ChangeManager.DOCR_DATA_TYPE_SETTING_CHANGED, + ChangeManager.DOCR_DATA_TYPE_MOVED, + ChangeManager.DOCR_DATA_TYPE_ADDED + } + + self.imageBaseEvents = { + ChangeManager.DOCR_IMAGE_BASE_CHANGED + } + + self.TrackedEvents = ( + self.funcEvents | self.symDelEvents | self.symChgEvents | self.typeEvents | self.imageBaseEvents + ) + + @JOverride + def domainObjectChanged(self, ev): + try: + self.do_change_handler(ev) + except Exception as e: + excep_str = str(e).replace('\n', ' ') + self._deci.error(f"Error in domainObjectChanged: {excep_str}") + + def do_change_handler(self, ev): + for record in ev: + if not isinstance(record, ProgramChangeRecord): + continue + + changeType = record.getEventType() + if changeType not in self.TrackedEvents: + # bail out early if we don't care about this event + continue + + new_value = record.getNewValue() + obj = record.getObject() + if changeType in self.funcEvents: + func_change_type = record.getSpecificChangeType() + if func_change_type == FunctionChangeRecord.FunctionChangeType.RETURN_TYPE_CHANGED: + # Function return type changed + header = FunctionHeader( + name=None, addr=obj.getEntryPoint().getOffset(), type_=str(obj.getReturnType()) + ) + self._deci.function_header_changed(header) + + elif changeType in self.typeEvents: + if changeType == ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED: + # stack variables change address when retyped! + if isinstance(obj, VariableDB): + parent_namespace = obj.getParentNamespace() + storage = obj.getVariableStorage() + if ( + (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) + and (parent_namespace is not None) + ): + sv = StackVariable( + int(storage.stackOffset), + None, + str(obj.getDataType()), + int(storage.size), + int(obj.parentNamespace.entryPoint.offset) + ) + self._deci.stack_variable_changed( + sv + ) - -def create_data_monitor(deci: "GhidraDecompilerInterface"): - from .compat.imports import ( - DomainObjectListener, ChangeManager, ProgramChangeRecord, VariableDB, FunctionDB, CodeSymbol, - FunctionSymbol, FunctionChangeRecord - ) - - class DataMonitor(DomainObjectListener): - def __init__(self, deci: "GhidraDecompilerInterface"): - self._deci = deci - # Init event lists - self.funcEvents = { - ChangeManager.DOCR_FUNCTION_CHANGED, - ChangeManager.DOCR_FUNCTION_BODY_CHANGED, - ChangeManager.DOCR_VARIABLE_REFERENCE_ADDED, - ChangeManager.DOCR_VARIABLE_REFERENCE_REMOVED - } - - self.symDelEvents = { - ChangeManager.DOCR_SYMBOL_REMOVED - } - - self.symChgEvents = { - ChangeManager.DOCR_SYMBOL_ADDED, - ChangeManager.DOCR_SYMBOL_RENAMED, - ChangeManager.DOCR_SYMBOL_DATA_CHANGED - } - - self.typeEvents = { - ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED, - ChangeManager.DOCR_DATA_TYPE_CHANGED, - ChangeManager.DOCR_DATA_TYPE_REPLACED, - ChangeManager.DOCR_DATA_TYPE_RENAMED, - ChangeManager.DOCR_DATA_TYPE_SETTING_CHANGED, - ChangeManager.DOCR_DATA_TYPE_MOVED, - ChangeManager.DOCR_DATA_TYPE_ADDED - } - - self.imageBaseEvents = { - ChangeManager.DOCR_IMAGE_BASE_CHANGED - } - - self.TrackedEvents = ( - self.funcEvents | self.symDelEvents | self.symChgEvents | self.typeEvents | self.imageBaseEvents - ) - - def domainObjectChanged(self, ev): - try: - self.do_change_handler(ev) - except Exception as e: - excep_str = str(e).replace('\n', ' ') - self._deci.error(f"Error in domainObjectChanged: {excep_str}") - - def do_change_handler(self, ev): - for record in ev: - if not self._deci.isinstance(record, ProgramChangeRecord): - continue - - changeType = record.getEventType() - if changeType not in self.TrackedEvents: - # bail out early if we don't care about this event - continue - - new_value = record.getNewValue() - obj = record.getObject() - if changeType in self.funcEvents: - func_change_type = record.getSpecificChangeType() - if func_change_type == FunctionChangeRecord.FunctionChangeType.RETURN_TYPE_CHANGED: - # Function return type changed - header = FunctionHeader( - name=None, addr=obj.getEntryPoint().getOffset(), type_=str(obj.getReturnType()) - ) - self._deci.function_header_changed(header) - - elif changeType in self.typeEvents: - if changeType == ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED: - # stack variables change address when retyped! - if self._deci.isinstance(obj, VariableDB): - parent_namespace = obj.getParentNamespace() - storage = obj.getVariableStorage() - if ( - (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) - and (parent_namespace is not None) - ): - sv = StackVariable( - int(storage.stackOffset), - None, - str(obj.getDataType()), - int(storage.size), - int(obj.parentNamespace.entryPoint.offset) - ) - self._deci.stack_variable_changed( - sv - ) - - else: - try: - struct = self._deci.structs[new_value.name] - # TODO: access old name indicate deletion - # self._deci.struct_changed(Struct(None, None, None), deleted=True) - self._deci.struct_changed(struct) - except KeyError: - pass - if changeType == ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED: - # stack variables change address when retyped! - if self._deci.isinstance(obj, VariableDB): - parent_namespace = obj.getParentNamespace() - storage = obj.getVariableStorage() - if ( - (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) - and (parent_namespace is not None) - ): - self._deci.stack_variable_changed( - StackVariable( - int(storage.stackOffset), - None, - str(obj.getDataType()), - int(storage.size), - int(obj.parentNamespace.entryPoint.offset) - ) - ) - - else: - try: - struct = self._deci.structs[new_value.name] - # TODO: access old name indicate deletion - # self._deci.struct_changed(Struct(None, None, None), deleted=True) - self._deci.struct_changed(struct) - except KeyError: - pass - - try: - enum = self._deci.enums[new_value.name] - # self._deci.enum_changed(Enum(None, None), deleted=True) - self._deci.enum_changed(enum) - except KeyError: - pass - - elif changeType in self.symDelEvents: - # Globals are deleted first then recreated - if self._deci.isinstance(obj, CodeSymbol): - removed = GlobalVariable(obj.getAddress().getOffset(), obj.getName()) - # deleted kwarg not yet handled by global_variable_changed - self._deci.global_variable_changed(removed, deleted=True) - elif changeType in self.symChgEvents: - # For creation events, obj is stored in newValue - if obj is None and new_value is not None: - obj = new_value - - if changeType == ChangeManager.DOCR_SYMBOL_ADDED: - if self._deci.isinstance(obj, CodeSymbol): - gvar = GlobalVariable(obj.getAddress().getOffset(), obj.getName()) - self._deci.global_variable_changed(gvar) - elif changeType == ChangeManager.DOCR_SYMBOL_RENAMED: - if self._deci.isinstance(obj, CodeSymbol): - gvar = GlobalVariable(obj.getAddress().getOffset(), new_value) - self._deci.global_variable_changed(gvar) - if self._deci.isinstance(obj, FunctionSymbol): - header = FunctionHeader(name=new_value, addr=int(obj.getAddress().offset)) - self._deci.function_header_changed(header) - elif self._deci.isinstance(obj, VariableDB): + else: + try: + struct = self._deci.structs[new_value.name] + # TODO: access old name indicate deletion + # self._deci.struct_changed(Struct(None, None, None), deleted=True) + self._deci.struct_changed(struct) + except KeyError: + pass + if changeType == ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED: + # stack variables change address when retyped! + if isinstance(obj, VariableDB): parent_namespace = obj.getParentNamespace() storage = obj.getVariableStorage() if ( @@ -179,41 +126,101 @@ def do_change_handler(self, ev): ): self._deci.stack_variable_changed( StackVariable( - int(obj.variableStorage.stackOffset), - new_value, - None, + int(storage.stackOffset), None, + str(obj.getDataType()), + int(storage.size), int(obj.parentNamespace.entryPoint.offset) ) ) - elif self._deci.isinstance(obj, FunctionDB): - # TODO: Fix argument name support - # changed_arg = FunctionArgument(None, newValue, None, None) - # header = FunctionHeader(None, None, args={None: changed_arg}) - # self._deci.function_header_changed(header) + + else: + try: + struct = self._deci.structs[new_value.name] + # TODO: access old name indicate deletion + # self._deci.struct_changed(Struct(None, None, None), deleted=True) + self._deci.struct_changed(struct) + except KeyError: + pass + + try: + enum = self._deci.enums[new_value.name] + # self._deci.enum_changed(Enum(None, None), deleted=True) + self._deci.enum_changed(enum) + except KeyError: pass - else: - continue - elif changeType in self.imageBaseEvents: - new_base_addr = int(new_value.getOffset()) if new_value is not None else None - if new_base_addr is not None: - self._deci._binary_base_addr = new_base_addr + elif changeType in self.symDelEvents: + # Globals are deleted first then recreated + if isinstance(obj, CodeSymbol): + removed = GlobalVariable(obj.getAddress().getOffset(), obj.getName()) + # deleted kwarg not yet handled by global_variable_changed + self._deci.global_variable_changed(removed, deleted=True) + elif changeType in self.symChgEvents: + # For creation events, obj is stored in newValue + if obj is None and new_value is not None: + obj = new_value + + if changeType == ChangeManager.DOCR_SYMBOL_ADDED: + if isinstance(obj, CodeSymbol): + gvar = GlobalVariable(obj.getAddress().getOffset(), obj.getName()) + self._deci.global_variable_changed(gvar) + elif changeType == ChangeManager.DOCR_SYMBOL_RENAMED: + if isinstance(obj, CodeSymbol): + gvar = GlobalVariable(obj.getAddress().getOffset(), new_value) + self._deci.global_variable_changed(gvar) + if isinstance(obj, FunctionSymbol): + header = FunctionHeader(name=new_value, addr=int(obj.getAddress().offset)) + self._deci.function_header_changed(header) + elif isinstance(obj, VariableDB): + parent_namespace = obj.getParentNamespace() + storage = obj.getVariableStorage() + if ( + (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) + and (parent_namespace is not None) + ): + self._deci.stack_variable_changed( + StackVariable( + int(obj.variableStorage.stackOffset), + new_value, + None, + None, + int(obj.parentNamespace.entryPoint.offset) + ) + ) + elif isinstance(obj, FunctionDB): + # TODO: Fix argument name support + # changed_arg = FunctionArgument(None, newValue, None, None) + # header = FunctionHeader(None, None, args={None: changed_arg}) + # self._deci.function_header_changed(header) + pass + else: + continue + elif changeType in self.imageBaseEvents: + new_base_addr = int(new_value.getOffset()) if new_value is not None else None + if new_base_addr is not None: + self._deci._binary_base_addr = new_base_addr + + +def create_data_monitor(deci: "GhidraDecompilerInterface"): data_monitor = DataMonitor(deci) return data_monitor -def create_context_action(name, action_string, callback_func, category=None): - from .compat.imports import ProgramLocationContextAction, MenuData +def create_context_action(name, action_string, callback_func, category=None, plugin_name="libbs_ghidra", tool=None): + from .compat.imports import ProgramLocationActionContext, ActionBuilder + def _invoke(ctx: ProgramLocationActionContext): + threading.Thread(target=callback_func, daemon=True).start() - # XXX: you can't ever use super().__init__() due to some remote import issues - class GenericDecompilerCtxAction(ProgramLocationContextAction): - def actionPerformed(self, ctx): - threading.Thread(target=callback_func, daemon=True).start() + menu_path = [] + if category is not None and "/" in category: + menu_path.extend(category.split("/")) + menu_path.append(action_string) - action = GenericDecompilerCtxAction(name, category) - category_list = category.split("/") if category else [] - category_start = category_list[0] if category_list else category - action.setPopupMenuData(MenuData(category_list + [action_string], None, category_start)) + b = (ActionBuilder(name, plugin_name) + .popupMenuPath(list(menu_path)) + .withContext(ProgramLocationActionContext) + .validContextWhen(lambda ctx: ctx is not None and ctx.getAddress() is not None) + .onAction(_invoke)) - return action + return b.buildAndInstall(tool) diff --git a/libbs/decompilers/ghidra/interface.py b/libbs/decompilers/ghidra/interface.py index 3b0660bb..fe692e42 100644 --- a/libbs/decompilers/ghidra/interface.py +++ b/libbs/decompilers/ghidra/interface.py @@ -9,9 +9,6 @@ import queue import threading -from jfx_bridge.bridge import BridgedObject -from ghidra_bridge import GhidraBridge - from libbs.api import DecompilerInterface, CType from libbs.api.decompiler_interface import requires_decompilation from libbs.artifacts import ( @@ -20,8 +17,9 @@ ) from .artifact_lifter import GhidraArtifactLifter -from .compat.bridge import FlatAPIWrapper, connect_to_bridge, shutdown_bridge, ui_remote_eval, is_bridge_alive from .compat.transaction import ghidra_transaction +from .compat.headless import close_program, open_program +from .compat.state import get_current_address if typing.TYPE_CHECKING: from ghidra.program.model.listing import Function as GhidraFunction, Program @@ -31,7 +29,6 @@ _l = logging.getLogger(__name__) -bridge: Optional[GhidraBridge] = None class GhidraDecompilerInterface(DecompilerInterface): @@ -67,7 +64,6 @@ def __init__( # ui-only attributes self._data_monitor = None - self._bridge = None # cachable attributes self._active_ctx = None @@ -90,35 +86,18 @@ def __init__( ) def _init_gui_components(self, *args, **kwargs): - global bridge - self._bridge = connect_to_bridge() - if self._bridge is None: - raise RuntimeError("Failed to connect to Ghidra UI bridge.") - - # used for importing elsewhere - bridge = self._bridge - globals()["binsync_ghidra_bridge"] = self._bridge - - self.flat_api = FlatAPIWrapper() # XXX: yeah, this is bad naming! if self._start_headless_watchers: self.start_artifact_watchers() + super()._init_gui_components(*args, **kwargs) def _deinit_headless_components(self): if self._program is not None and self._project is not None: - from .compat.headless import close_program close_program(self._program, self._project) self._project = None self._program = None - if self._bridge is not None: - try: - shutdown_bridge(self._bridge) - except Exception: - pass - self._bridge = None - def _init_headless_components(self, *args, **kwargs): if self._program is not None: # We were already provided a program object as part of the instantiation, so just use it @@ -131,7 +110,6 @@ def _init_headless_components(self, *args, **kwargs): if os.getenv("GHIDRA_INSTALL_DIR", None) is None: raise RuntimeError("GHIDRA_INSTALL_DIR must be set in the environment to use Ghidra headless.") - from .compat.headless import open_program flat_api, project, program = open_program( binary_path=self._binary_path, analyze=self._headless_analyze, @@ -140,12 +118,11 @@ def _init_headless_components(self, *args, **kwargs): program_name=self._program_name, language=self._language, ) - if flat_api is None: - raise RuntimeError("Failed to open program with Pyhidra") - - self.flat_api = flat_api self._program = program self._project = project + self.flat_api = flat_api + if flat_api is None: + raise RuntimeError("Failed to open program with Pyhidra") # # GUI @@ -159,7 +136,7 @@ def start_artifact_watchers(self): from .hooks import create_data_monitor if not self.artifact_watchers_started: if self.flat_api is None: - raise RuntimeError("Cannot start artifact watchers without Ghidra Bridge connection.") + raise RuntimeError("Cannot start artifact watchers without FlatProgramAPI.") self._data_monitor = create_data_monitor(self) self.currentProgram.addListener(self._data_monitor) @@ -171,41 +148,10 @@ def stop_artifact_watchers(self): # TODO: generalize superclass method? super().stop_artifact_watchers() - @property - def gui_plugin(self): - """ - A special property to never exit this function if the remote server is running. - This is used to standardize plugin access across all decompilers. - Additionally, in Ghidra, this will allow us to take requests from other threads to make things created - on the main thread! - - WARNING: If you initialized with init_plugin=True, simply autocompleting (tab) in IPython will - cause this to loop forever. - """ - if self.loop_on_plugin and self._init_plugin: - last_bridge_check = time.time() - bridge_check_delta = 30 - while True: - if not self._main_thread_queue.empty(): - func, args, kwargs = self._main_thread_queue.get() - self._results_queue.put(func(*args, **kwargs)) - - if time.time() - last_bridge_check > bridge_check_delta: - if not is_bridge_alive(self._bridge): - break - last_bridge_check = time.time() - - time.sleep(1) - return None - def gui_run_on_main_thread(self, func, *args, **kwargs): self._main_thread_queue.put((func, args, kwargs)) return self._results_queue.get() - @gui_plugin.setter - def gui_plugin(self, value): - pass - def gui_register_ctx_menu(self, name, action_string, callback_func, category=None) -> bool: from .hooks import create_context_action @@ -215,25 +161,22 @@ def callback_func_wrap(*args, **kwargs): except Exception as e: self.warning(f"Exception in ctx menu callback {name}: {e}") raise - ctx_menu_action = create_context_action(name, action_string, callback_func_wrap, category or "LibBS") - self.flat_api.getState().getTool().addAction(ctx_menu_action) + create_context_action( + name, action_string, callback_func_wrap, category=(category or "LibBS"), + tool=self.flat_api.getState().getTool() + ) return True def gui_ask_for_string(self, question, title="Plugin Question") -> str: - answer = self._bridge.remote_eval( - "askString(title, question)", title=title, question=question, timeout_override=-1 - ) + answer = self.flat_api.askString(title, question) return answer if answer else "" def gui_ask_for_choice(self, question: str, choices: list, title="Plugin Question") -> str: - answer = self._bridge.remote_eval( - "askChoice(title, question, choices, choices[0])", title=title, question=question, choices=choices, - timeout_override=-1 - ) + answer = self.flat_api.askChoice(title, question, choices, choices[0]) return answer if answer else "" def gui_active_context(self) -> Optional[Context]: - active_addr = self.flat_api.currentLocation.getAddress().getOffset() + active_addr = get_current_address(flat_api=self.flat_api) if (self._active_ctx is None) or (active_addr is not None and self._active_ctx.addr != active_addr): gfuncs = self.__fast_function(active_addr) gfunc = gfuncs[0] if gfuncs else None @@ -844,14 +787,11 @@ def _global_vars(self, match_single_offset=None, **kwargs) -> Dict[int, GlobalVa # # Specialized print handlers + # TODO: refactor the below for the new ghidra changes # def print(self, msg, print_local=True, **kwargs): - if print_local: - print(msg) - - if self._bridge: - self._bridge.remote_exec(f'print("{msg}")') + print(msg) def info(self, msg: str, **kwargs): _l.info(msg) @@ -911,7 +851,8 @@ def _to_gaddr(self, addr: int): @property def currentProgram(self): - return self.flat_api.currentProgram + from .compat.state import get_current_program + return get_current_program(self.flat_api) @ghidra_transaction def _update_local_variable_symbols(self, symbols: Dict["HighSymbol", Tuple[str, Optional["DataType"]]]) -> bool: @@ -919,6 +860,15 @@ def _update_local_variable_symbols(self, symbols: Dict["HighSymbol", Tuple[str, r is not None for r in self.__update_local_variable_symbols(symbols) ]) + def _get_struct_by_name(self, name: str) -> Optional["StructureDB"]: + """ + Returns None if the struct does not exist or is not a struct. + """ + from .compat.imports import StructureDB + + struct = self.currentProgram.getDataTypeManager().getDataType("/" + name) + return struct if isinstance(struct, StructureDB) else None + def _struct_members_from_gstruct(self, gstruct: "StructDB") -> Dict[int, StructMember]: gmemb_info = self.__gstruct_members(gstruct) members = {} @@ -1097,38 +1047,27 @@ def _get_gtype_by_bs_name(self, name: str, bs_type: type[Artifact]) -> Optional[ #self.warning(f"Failed to get type by name: {g_scoped_name}") return None - if not self.isinstance(gtype, g_type): + if not isinstance(gtype, g_type): #self.warning(f"Type {g_scoped_name} is not a {g_type.__name__}") return None return gtype - @staticmethod - def isinstance(obj, cls): - """ - A proxy self.isinstance function that can handle BridgedObjects. This is necessary because the `self.isinstance` function - in the remote namespace will not recognize BridgedObjects as instances of classes in the local namespace. - """ - return obj._bridge_isinstance(cls) if isinstance(obj, BridgedObject) else isinstance(obj, cls) - # # Internal functions that are very dangerous # - @ui_remote_eval def __fast_function(self, lowered_addr: int) -> List["GhidraFunction"]: return [ self.currentProgram.getFunctionManager().getFunctionContaining(self.flat_api.toAddr(hex(lowered_addr))) ] - @ui_remote_eval def __functions(self) -> List[Tuple[int, str, int]]: return [ (int(func.getEntryPoint().getOffset()), str(func.getName()), int(func.getBody().getNumAddresses())) for func in self.currentProgram.getFunctionManager().getFunctions(True) ] - @ui_remote_eval def __update_local_variable_symbols(self, symbols: Dict["HighSymbol", Tuple[str, Optional["DataType"]]]) -> List: from .compat.imports import HighFunctionDBUtil, SourceType @@ -1137,25 +1076,24 @@ def __update_local_variable_symbols(self, symbols: Dict["HighSymbol", Tuple[str, for sym, updates in symbols.items() ] - @ui_remote_eval def _get_local_variable_symbols(self, func: Function) -> List[Tuple[str, "HighSymbol"]]: return [ (sym.name, sym) for sym in func.dec_obj.getHighFunction().getLocalSymbolMap().getSymbols() if sym.name ] - @ui_remote_eval + def __get_decless_gstack_vars(self, func: "GhidraFunction") -> List["LocalVariableDB"]: return [var for var in func.getAllVariables() if var.isStackVariable()] - @ui_remote_eval + def __get_gstack_vars(self, high_func: "HighFunction") -> List["LocalVariableDB"]: return [ var for var in high_func.getLocalSymbolMap().getSymbols() if var.storage and var.storage.isStackStorage() ] - @ui_remote_eval + def __enum_names(self) -> List[Tuple[str, "EnumDB"]]: from .compat.imports import EnumDB @@ -1165,7 +1103,7 @@ def __enum_names(self) -> List[Tuple[str, "EnumDB"]]: if isinstance(dType, EnumDB) ] - @ui_remote_eval + def __stack_variables(self, decompilation) -> List[Tuple[int, str, str, int]]: return [ (int(sym.getStorage().getStackOffset()), str(sym.getName()), sym.getDataType().getPathName(), int(sym.getSize())) @@ -1173,32 +1111,32 @@ def __stack_variables(self, decompilation) -> List[Tuple[int, str, str, int]]: if sym.getStorage().isStackStorage() ] - @ui_remote_eval + def __set_sym_names(self, sym_pairs, source_type): return [ sym.setName(new_name, source_type) for sym, new_name in sym_pairs ] - @ui_remote_eval + def __set_sym_types(self, sym_pairs, source_type): return [ sym.setDataType(new_type, False, True, source_type) for sym, new_type in sym_pairs ] - @ui_remote_eval + def __gstruct_members(self, gstruct: "StructureDB") -> List[Tuple[int, str, str, int]]: return [ (int(m.getOffset()), str(m.getFieldName()), str(m.getDataType().getPathName()), int(m.getLength())) for m in gstruct.getComponents() ] - @ui_remote_eval + def __get_enum_members(self, g_enum: "EnumDB") -> List[Tuple[str, int]]: return [ (name, g_enum.getValue(name)) for name in g_enum.getNames() ] - @ui_remote_eval + def __g_global_variables(self): # TODO: this could be optimized more both in use and in implementation # TODO: this just does not work for bigger than 50k syms @@ -1212,14 +1150,14 @@ def __g_global_variables(self): not self.currentProgram.getListing().getDataAt(sym.getAddress()).isStructure() ] - @ui_remote_eval + def __gstructs(self): return [ (struct.getPathName(), struct) for struct in self.currentProgram.getDataTypeManager().getAllStructures() ] - @ui_remote_eval + def __gtypedefs(self): from .compat.imports import TypedefDB @@ -1229,7 +1167,7 @@ def __gtypedefs(self): if isinstance(typedef, TypedefDB) ] - @ui_remote_eval + def __function_code_units(self): """ Returns a list of code units for each function in the program. diff --git a/libbs/decompilers/ghidra/testing.py b/libbs/decompilers/ghidra/testing.py deleted file mode 100644 index 4a3b914c..00000000 --- a/libbs/decompilers/ghidra/testing.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import subprocess -import tempfile -import time -from pathlib import Path - -from libbs.plugin_installer import PluginInstaller - - -class HeadlessGhidraDecompiler: - def __init__( - self, - binary_path: Path, - headless_dec_path: Path = None, - headless_script_path: Path = None, - ): - self._binary_path = Path(binary_path) - if not self._binary_path.exists(): - raise FileNotFoundError(f"Failed to find binary at {self._binary_path}") - - if headless_dec_path is None: - env_val = os.getenv("GHIDRA_HEADLESS_PATH", None) - if env_val is None: - raise ValueError("Must provide headless_dec_path or set GHIDRA_HEADLESS_PATH") - - headless_dec_path = Path(env_val) - if not headless_dec_path.exists(): - raise FileNotFoundError(f"Failed to find ghidra headless at {headless_dec_path}") - self._headless_dec_path = headless_dec_path - - self._headless_script_path = headless_script_path or PluginInstaller.find_pkg_files("libbs") / "decompiler_stubs" / "ghidra_libbs" / "ghidra_libbs_mainthread_server.py" - if not self._headless_script_path.exists(): - raise FileNotFoundError(f"Failed to find headless script at {self._headless_script_path}") - - self._proc = None - - def __enter__(self): - self._headless_g_project = tempfile.TemporaryDirectory() - self._proc = subprocess.Popen([ - str(self._headless_dec_path), - self._headless_g_project.name, - "headless", - "-import", - str(self._binary_path), - "-scriptPath", - str(self._headless_script_path.parent), - "-postScript", - str(self._headless_script_path.name), - ]) - time.sleep(1) - - def __exit__(self, exc_type, exc_val, exc_tb): - time.sleep(2) - # Wait until headless binary gets shutdown - try: - self._proc.kill() - except Exception: - pass - self._headless_g_project.cleanup() diff --git a/libbs/ui/qt_objects.py b/libbs/ui/qt_objects.py index a90d3b55..e158153d 100644 --- a/libbs/ui/qt_objects.py +++ b/libbs/ui/qt_objects.py @@ -4,7 +4,7 @@ from PySide6.QtCore import ( QDir, Qt, Signal, QAbstractTableModel, QModelIndex, QSortFilterProxyModel, QPersistentModelIndex, QEvent, QThread, Slot, QObject, QPropertyAnimation, QAbstractAnimation, QParallelAnimationGroup, - QLineF, QRect + QLineF, QTimer, QRect, ) from PySide6.QtWidgets import ( QAbstractItemView, @@ -66,7 +66,7 @@ from PyQt5.QtCore import ( QDir, Qt, QAbstractTableModel, QModelIndex, QSortFilterProxyModel, QPersistentModelIndex, QEvent, QThread, QObject, QPropertyAnimation, QAbstractAnimation, QParallelAnimationGroup, - QLineF, QRect + QLineF, QTimer, QRect, ) from PyQt5.QtCore import pyqtSignal as Signal from PyQt5.QtCore import pyqtSlot as Slot diff --git a/pyproject.toml b/pyproject.toml index d1c38c39..20668081 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,8 @@ dependencies = [ "setuptools", "prompt_toolkit", "tqdm", - "jfx_bridge", - "ghidra_bridge", "psutil", - "pyhidra", + "pyghidra", "platformdirs", "filelock", "networkx" diff --git a/tests/test_client_server.py b/tests/test_client_server.py new file mode 100644 index 00000000..97f4d13e --- /dev/null +++ b/tests/test_client_server.py @@ -0,0 +1,656 @@ +import os +import tempfile +import threading +import time +import unittest +from pathlib import Path + +from libbs.api.decompiler_server import DecompilerServer +from libbs.api.decompiler_client import DecompilerClient +from libbs.api.decompiler_interface import DecompilerInterface +from libbs.decompilers import GHIDRA_DECOMPILER + +# Test binary path - use the same path as other tests +TEST_BINARIES_DIR = Path(os.getenv("TEST_BINARIES_DIR", Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries")) +if not TEST_BINARIES_DIR.exists(): + # fallback to relative path + TEST_BINARIES_DIR = Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries" + +FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware" + +class TestClientServer(unittest.TestCase): + """Test the new AF_UNIX socket-based DecompilerClient and DecompilerServer""" + + def setUp(self): + """Set up test environment""" + self.server = None + self.client = None + self.temp_dir = None + + def tearDown(self): + """Clean up test environment""" + if self.client: + self.client.shutdown() + if self.server: + self.server.stop() + if self.temp_dir and os.path.exists(self.temp_dir): + try: + os.rmdir(self.temp_dir) + except: + pass + + def test_server_startup_and_client_connection(self): + """Test that server starts and client can connect""" + # Start server with Ghidra headless and fauxware binary + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware" + ) + self.server.start() + + # Give server time to start + time.sleep(1) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Verify connection works + self.assertTrue(self.client.is_connected()) + self.assertTrue(self.client.ping()) + + # Test basic properties + self.assertEqual(self.client.name, "ghidra") + self.assertIsNotNone(self.client.binary_path) + self.assertIsNotNone(self.client.binary_hash) + self.assertTrue(self.client.decompiler_available) + + def test_artifact_collections_match_local(self): + """Test that client artifact collections behave like local interface""" + with tempfile.TemporaryDirectory() as proj_dir: + # Create server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_remote" + ) + self.server.start() + time.sleep(1) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Test that we get functions + remote_func_keys = list(self.client.functions.keys()) + self.assertGreater(len(remote_func_keys), 0, "Should have found functions") + + # Test that we can get light functions + remote_light_funcs = list(self.client.functions.items()) + self.assertGreater(len(remote_light_funcs), 0, "Should have light functions") + + # Verify functions are actual Function objects + if remote_light_funcs: + addr, func = remote_light_funcs[0] + self.assertIsNotNone(func, "Function should not be None") + self.assertEqual(func.addr, addr, "Function address should match key") + self.assertIsInstance(func.name, str, "Function should have a name") + + def test_client_server_method_calls(self): + """Test that client method calls work correctly""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_methods" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Test function size method + func_keys = list(self.client.functions.keys()) + self.assertGreater(len(func_keys), 0, "Should have functions") + + func_addr = func_keys[0] + func_size = self.client.get_func_size(func_addr) + self.assertGreater(func_size, 0, "Function size should be positive") + + # Test fast_get_function + fast_func = self.client.fast_get_function(func_addr) + self.assertIsNotNone(fast_func, "Fast function should not be None") + self.assertEqual(fast_func.addr, func_addr, "Fast function address should match") + + def test_client_discover_auto_detection(self): + """Test client auto-discovery functionality""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_autodiscovery" + ) + self.server.start() + time.sleep(1) + + # Test auto-discovery (should find the server we just started) + try: + self.client = DecompilerClient.discover() + self.assertTrue(self.client.is_connected()) + self.assertEqual(self.client.name, "ghidra") + except ConnectionError: + # Auto-discovery might fail if multiple temp directories exist + # This is acceptable, we can still test manual connection + self.client = DecompilerClient(socket_path=self.server.socket_path) + self.assertTrue(self.client.is_connected()) + + def test_error_handling(self): + """Test error handling in client-server communication""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_errors" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Test KeyError handling for non-existent function + with self.assertRaises(KeyError, msg="Should raise KeyError for non-existent function"): + self.client.functions[0xDEADBEEF] # Non-existent function + + def test_client_context_manager(self): + """Test client context manager functionality""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_context" + ) + self.server.start() + time.sleep(1) + + # Test context manager + with DecompilerClient(socket_path=self.server.socket_path) as client: + self.assertTrue(client.is_connected()) + self.assertEqual(client.name, "ghidra") + + # Client should be disconnected after context manager + # (Note: we can't test this easily since the client object is out of scope) + + def test_server_restart_discovery(self): + """Test that client can discover server after restart""" + with tempfile.TemporaryDirectory() as proj_dir: + # Start first server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_restart" + ) + self.server.start() + time.sleep(1) + + # Get the binary hash from the server + self.client = DecompilerClient(socket_path=self.server.socket_path) + binary_hash = self.client.binary_hash + self.assertIsNotNone(binary_hash, "Binary hash should not be None") + socket_path_1 = self.server.socket_path + self.client.shutdown() + + # Stop the server + self.server.stop() + time.sleep(0.5) + + # Start a new server (will have different socket path) + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_restart2" + ) + self.server.start() + time.sleep(1) + socket_path_2 = self.server.socket_path + + # Socket paths should be different (different temp directories) + self.assertNotEqual(socket_path_1, socket_path_2, + "New server should have different socket path") + + # Client should discover the new server using binary_hash + self.client = DecompilerClient.discover(binary_hash=binary_hash) + self.assertTrue(self.client.is_connected()) + self.assertEqual(self.client.binary_hash, binary_hash) + self.assertEqual(self.client.socket_path, socket_path_2, + "Client should connect to new server, not old socket") + + def test_multiple_servers_binary_hash_matching(self): + """Test client can select correct server when multiple are running""" + # We'll use different binaries to get different hashes + # For this test, we'll create two servers with the same binary + # but simulate different binary_hash by using different project names + + with tempfile.TemporaryDirectory() as proj_dir1: + with tempfile.TemporaryDirectory() as proj_dir2: + # Start first server + server1 = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir1, + project_name="test_server1" + ) + server1.start() + time.sleep(1) + + # Get hash from first server + client1 = DecompilerClient(socket_path=server1.socket_path) + hash1 = client1.binary_hash + socket1 = server1.socket_path + client1.shutdown() + + # Start second server with same binary (will have same hash) + server2 = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir2, + project_name="test_server2" + ) + server2.start() + time.sleep(1) + + socket2 = server2.socket_path + self.assertNotEqual(socket1, socket2, "Servers should have different sockets") + + try: + # Discover with binary hash - should connect to one of the servers + # (since they have the same binary, they'll have the same hash) + discovered_client = DecompilerClient.discover(binary_hash=hash1) + self.assertTrue(discovered_client.is_connected()) + self.assertEqual(discovered_client.binary_hash, hash1) + + # Should connect to one of the two servers + self.assertIn(discovered_client.socket_path, [socket1, socket2], + "Should connect to one of the running servers") + discovered_client.shutdown() + + # Discover without binary hash - should connect to most recent + discovered_client2 = DecompilerClient.discover() + self.assertTrue(discovered_client2.is_connected()) + discovered_client2.shutdown() + + finally: + server1.stop() + server2.stop() + + def test_defunct_socket_handling(self): + """Test that client skips defunct socket files from stopped servers""" + with tempfile.TemporaryDirectory() as proj_dir: + # Start and stop a server to create a defunct socket + server1 = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_defunct" + ) + server1.start() + time.sleep(1) + defunct_socket = server1.socket_path + server1.stop() + time.sleep(0.5) + + # Manually recreate the socket file to simulate a stale socket + # (normally stop() removes it, but crashes might leave it) + import tempfile as tf + temp_dir = tf.mkdtemp(prefix="libbs_server_") + defunct_socket = os.path.join(temp_dir, "decompiler.sock") + # Create an empty file to simulate stale socket + open(defunct_socket, 'w').close() + + # Start a new server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_working" + ) + self.server.start() + time.sleep(1) + + # Discovery should skip the defunct socket and find the working server + self.client = DecompilerClient.discover() + self.assertTrue(self.client.is_connected()) + self.assertEqual(self.client.socket_path, self.server.socket_path, + "Should connect to working server, not defunct socket") + + # Clean up the fake defunct socket + try: + os.unlink(defunct_socket) + os.rmdir(temp_dir) + except: + pass + + def test_discover_with_binary_hash_no_match(self): + """Test that discovery fails when binary_hash doesn't match any server""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_no_match" + ) + self.server.start() + time.sleep(1) + + # Try to discover with a non-matching binary hash + fake_hash = "this_hash_does_not_exist_12345" + with self.assertRaises(ConnectionError) as context: + DecompilerClient.discover(binary_hash=fake_hash) + + # Error message should mention the hash + self.assertIn(fake_hash, str(context.exception)) + self.assertIn("none matched", str(context.exception).lower()) + + def test_server_info_includes_binary_hash(self): + """Test that server_info response includes binary_hash""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_server_info" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Server info is fetched during connection and stored + server_info = self.client._server_info + self.assertIsNotNone(server_info, "Server info should be available") + self.assertIn("binary_hash", server_info, "Server info should include binary_hash") + + # Verify binary_hash matches what we get from the property + self.assertEqual(server_info["binary_hash"], self.client.binary_hash, + "Server info binary_hash should match client property") + + def test_callback_events(self): + """Test that client receives callback events when artifacts change on server""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_callbacks" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track callback invocations + callback_events = [] + + def test_callback(artifact, **kwargs): + callback_events.append({ + "artifact_type": type(artifact).__name__, + "artifact": artifact, + "kwargs": kwargs + }) + + # Register callback for Comment artifacts + from libbs.artifacts import Comment + self.client.artifact_change_callbacks[Comment].append(test_callback) + + # Start artifact watchers (which starts event listener) + self.client.start_artifact_watchers() + time.sleep(0.5) # Give listener time to start + + # Verify event listener is running + self.assertTrue(self.client._event_listener_running, + "Event listener should be running") + self.assertTrue(self.client._subscribed_to_events, + "Client should be subscribed to events") + + # Trigger a callback on the server by creating a comment + # TODO: update this to just sent a comment so we can see the callback trigger naturally + test_comment = Comment(0x1234, "Test comment from callback test") + # Note: comment_changed will lift the artifact, which changes the address + lifted_comment = self.server.deci.comment_changed(test_comment) + + # Wait for event to be received and processed + time.sleep(0.5) + + # Verify callback was triggered + self.assertGreater(len(callback_events), 0, + "Callback should have been triggered") + + # Verify event contents + event = callback_events[0] + self.assertEqual(event["artifact_type"], "Comment", + "Event should be for Comment artifact") + # The address should match the lifted address, not the original + self.assertEqual(event["artifact"].addr, lifted_comment.addr, + "Comment address should match the lifted address") + self.assertIn("Test comment", event["artifact"].comment, + "Comment text should match") + + # Clean up + self.client.stop_artifact_watchers() + self.assertFalse(self.client._event_listener_running, + "Event listener should be stopped") + + def test_multiple_callbacks(self): + """Test that multiple callbacks can be registered and all are triggered""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_multiple_callbacks" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track callbacks + callback1_called = [] + callback2_called = [] + + def callback1(artifact, **kwargs): + callback1_called.append(artifact) + + def callback2(artifact, **kwargs): + callback2_called.append(artifact) + + # Register multiple callbacks + from libbs.artifacts import Struct + self.client.artifact_change_callbacks[Struct].append(callback1) + self.client.artifact_change_callbacks[Struct].append(callback2) + + # Start watchers + self.client.start_artifact_watchers() + time.sleep(0.5) + + # Trigger event + test_struct = Struct("TestStruct", 0x10, members={}) + self.server.deci.struct_changed(test_struct) + + # Wait for processing + time.sleep(0.5) + + # Both callbacks should have been called + self.assertEqual(len(callback1_called), 1, "Callback 1 should be called once") + self.assertEqual(len(callback2_called), 1, "Callback 2 should be called once") + self.assertEqual(callback1_called[0].name, "TestStruct") + self.assertEqual(callback2_called[0].name, "TestStruct") + + def test_callback_with_metadata(self): + """Test that callback metadata (like deleted flag) is passed correctly""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_callback_metadata" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track metadata + received_metadata = [] + + def metadata_callback(artifact, **kwargs): + received_metadata.append(kwargs) + + # Register callback + from libbs.artifacts import Enum + self.client.artifact_change_callbacks[Enum].append(metadata_callback) + + # Start watchers + self.client.start_artifact_watchers() + time.sleep(0.5) + + # Trigger event with metadata + test_enum = Enum("TestEnum", members={}) + self.server.deci.enum_changed(test_enum, deleted=True) + + # Wait for processing + time.sleep(0.5) + + # Verify metadata was passed + self.assertEqual(len(received_metadata), 1, "Callback should be called once") + self.assertIn("deleted", received_metadata[0], "Metadata should include deleted flag") + self.assertTrue(received_metadata[0]["deleted"], "deleted flag should be True") + + def test_artifact_watchers_integration(self): + """ + Test artifact callbacks with client-server architecture (adapted from test_remote_ghidra). + + Note: This test manually triggers callbacks on the server to test the event broadcast system, + since Ghidra's artifact watchers don't function in headless mode. + """ + from libbs.artifacts import FunctionHeader, StackVariable, Struct, GlobalVariable, Enum, Comment + from collections import defaultdict + + with tempfile.TemporaryDirectory() as proj_dir: + # Start server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_artifact_watchers" + ) + self.server.start() + time.sleep(1) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track callback hits + hits = defaultdict(list) + def func_hit(artifact, **kwargs): + hits[artifact.__class__].append(artifact) + + # Register callbacks for different artifact types + for typ in (FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment): + self.client.artifact_change_callbacks[typ].append(func_hit) + + # Start event listener + self.client.start_artifact_watchers() + time.sleep(0.5) + + # Test FunctionHeader callback by manually triggering on server + # (Ghidra headless watchers don't work, so we manually trigger) + func_addr = self.client.art_lifter.lift_addr(0x400664) + main = self.client.functions[func_addr] + + # Trigger callback on server side directly + test_header = FunctionHeader("test_func", func_addr, type_="int") + self.server.deci.function_header_changed(test_header) + time.sleep(0.5) + + # Verify callback was received on client + self.assertGreaterEqual(len(hits[FunctionHeader]), 1, + "FunctionHeader callback should be triggered") + + # Test Comment callback + test_comment = Comment(func_addr, "Test comment for integration test") + self.server.deci.comment_changed(test_comment) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[Comment]), 1, + "Comment callback should be triggered") + + # Test Struct callback + test_struct = Struct("TestStruct", 0x10, members={}) + self.server.deci.struct_changed(test_struct) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[Struct]), 1, + "Struct callback should be triggered") + + # Test Enum callback + test_enum = Enum("TestEnum", members={"VALUE1": 1, "VALUE2": 2}) + self.server.deci.enum_changed(test_enum) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[Enum]), 1, + "Enum callback should be triggered") + + # Test GlobalVariable callback + g_addr = self.client.art_lifter.lift_addr(0x4008e0) + test_gvar = GlobalVariable(g_addr, "test_global", "int", 4) + self.server.deci.global_variable_changed(test_gvar) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[GlobalVariable]), 1, + "GlobalVariable callback should be triggered") + + # Test that client can also modify artifacts through the server + # and they persist correctly + main.name = "modified_main" + self.client.functions[func_addr] = main + time.sleep(0.5) + + # Retrieve and verify the change persisted + modified_main = self.client.functions[func_addr] + self.assertEqual(modified_main.name, "modified_main", + "Function name modification should persist") + + # Clean up + self.client.stop_artifact_watchers() + self.assertFalse(self.client._event_listener_running, + "Event listener should be stopped") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_decompilers.py b/tests/test_decompilers.py index 3ef9cb63..42c4e611 100644 --- a/tests/test_decompilers.py +++ b/tests/test_decompilers.py @@ -12,7 +12,6 @@ from libbs.artifacts import FunctionHeader, StackVariable, Struct, GlobalVariable, Enum, Comment, ArtifactFormat, \ Decompilation, Function, StructMember, Typedef, Segment from libbs.decompilers import IDA_DECOMPILER, ANGR_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER -from libbs.decompilers.ghidra.testing import HeadlessGhidraDecompiler GHIDRA_HEADLESS_PATH = Path(os.environ.get('GHIDRA_INSTALL_DIR', "")) / "support" / "analyzeHeadless" IDA_HEADLESS_PATH = Path(os.environ.get('IDA_HEADLESS_PATH', "")) diff --git a/tests/test_remote_ghidra.py b/tests/test_remote_ghidra.py deleted file mode 100644 index 2338527c..00000000 --- a/tests/test_remote_ghidra.py +++ /dev/null @@ -1,137 +0,0 @@ -import logging -import time -import unittest -from pathlib import Path -from collections import defaultdict -import os - -from libbs.api import DecompilerInterface -from libbs.artifacts import FunctionHeader, StackVariable, Struct, GlobalVariable, Enum, Comment -from libbs.decompilers import GHIDRA_DECOMPILER -from libbs.decompilers.ghidra.testing import HeadlessGhidraDecompiler -from libbs.decompilers.ghidra.compat.transaction import Transaction -from libbs.decompilers.ghidra.interface import GhidraDecompilerInterface - -GHIDRA_HEADLESS_PATH = Path(os.environ.get('GHIDRA_INSTALL_DIR', "")) / "support" / "analyzeHeadless" -if os.getenv("TEST_BINARIES_DIR"): - TEST_BINARIES_DIR = Path(os.getenv("TEST_BINARIES_DIR")) -else: - # default assumes its a git repo that is above this one - TEST_BINARIES_DIR = Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries" - -assert TEST_BINARIES_DIR.exists(), f"Test binaries dir {TEST_BINARIES_DIR} does not exist" -_l = logging.getLogger(__name__) - - -class TestRemoteGhidra(unittest.TestCase): - FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware" - - def setUp(self): - self.deci = None - - def tearDown(self): - if self.deci is not None: - self.deci.shutdown() - - def test_ghidra_artifact_watchers(self): - with HeadlessGhidraDecompiler(self.FAUXWARE_PATH, headless_dec_path=GHIDRA_HEADLESS_PATH): - deci: GhidraDecompilerInterface = DecompilerInterface.discover( - force_decompiler=GHIDRA_DECOMPILER, - binary_path=self.FAUXWARE_PATH, - start_headless_watchers=True - ) - self.deci = deci - - # - # Test Artifact Watchers - # - - hits = defaultdict(list) - def func_hit(*args, **kwargs): hits[args[0].__class__].append(args[0]) - - deci.artifact_change_callbacks = { - typ: [func_hit] for typ in (FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment) - } - - # Exact number of hits is not consistent, so we instead check for the minimum increment expected - old_header_hits = len(hits[FunctionHeader]) - - # function names - func_addr = deci.art_lifter.lift_addr(0x400664) - main = deci.functions[func_addr] - main.name = "changed" - deci.functions[func_addr] = main - - main.name = "main" - deci.functions[func_addr] = main - - assert len(hits[FunctionHeader]) >= old_header_hits + 2 - old_header_hits = len(hits[FunctionHeader]) - - # function return type - main.header.type = 'long' - deci.functions[func_addr] = main - time.sleep(5) - - main.header.type = 'double' - deci.functions[func_addr] = main - time.sleep(5) - - # confirm the final type is correct - new_main = deci.functions[func_addr] - assert new_main.header.type == main.header.type - - assert len(hits[FunctionHeader]) >= old_header_hits + 2 - - # global var names - old_global_hits = len(hits[GlobalVariable]) - g1_addr = deci.art_lifter.lift_addr(0x4008e0) - g2_addr = deci.art_lifter.lift_addr(0x601048) - g1 = deci.global_vars[g1_addr] - g2 = deci.global_vars[g2_addr] - g1.name = "gvar1" - g2.name = "gvar2" - deci.global_vars[g1_addr] = g1 - deci.global_vars[g2_addr] = g2 - # TODO: re-enable this once we have a better way to track global variable changes - #assert len(hits[GlobalVariable]) == old_global_hits + 2 - - main.stack_vars[-24].name = "named_char_array" - main.stack_vars[-12].name = "named_int" - deci.functions[func_addr] = main - # TODO: fixme: stack variable changes are not being tracked - # first_changed_sv = hits[StackVariable][0] - # assert first_changed_sv.name == main.stack_vars[-24].name - # assert len(hits[StackVariable]) == 2 - - # struct = deci.structs['eh_frame_hdr'] - # struct.name = "my_struct_name" - # deci.structs['eh_frame_hdr'] = struct - - # TODO: add argument naming - # func_args = main.header.args - # func_args[0].name = "changed_name" - # func_args[1].name = "changed_name2" - # deci.functions[func_addr] = main - - # assert hits[Struct] == 2 # One change results in 2 hits because the struct is first removed and then added again. - - # - # Test Image Base Watcher - # - - original_base_addr = deci.binary_base_addr - new_base_addr = 0x1000000 - # NOTE: if this code is continuously flaky, we can remove it - with Transaction(deci.flat_api, msg="BS::test_ghidra_artifact_watchers"): - deci.flat_api.currentProgram.setImageBase(deci.flat_api.toAddr(new_base_addr), True) - - time.sleep(0.5) - assert deci.binary_base_addr != original_base_addr - assert deci.binary_base_addr == new_base_addr - - deci.shutdown() - - -if __name__ == "__main__": - unittest.main()