diff --git a/image_processing/Dockerfile b/image_processing/Dockerfile
new file mode 100644
index 0000000..9ea160a
--- /dev/null
+++ b/image_processing/Dockerfile
@@ -0,0 +1,14 @@
+FROM anyscale/ray:2.51.1-slim-py312-cu128
+
+# C compiler for Triton's runtime build step (vLLM V1 engine)
+# https://github.com/vllm-project/vllm/issues/2997
+RUN sudo apt-get update && \
+    sudo apt-get install -y --no-install-recommends build-essential
+
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+
+RUN uv pip install --system huggingface_hub boto3
+
+RUN uv pip install --system vllm==0.11.0
+
+RUN uv pip install --system transformers==4.57.1
diff --git a/image_processing/README.md b/image_processing/README.md
new file mode 100644
index 0000000..0de2e7a
--- /dev/null
+++ b/image_processing/README.md
@@ -0,0 +1,55 @@
+# Large-Scale Image Processing with Vision Language Models
+
+This example demonstrates how to build a production-ready image processing pipeline that scales to billions of images using Ray Data and vLLM on Anyscale. We process the [ReLAION-2B dataset](https://huggingface.co/datasets/laion/relaion2B-en-research-safe), which contains over 2 billion image URLs with associated metadata.
+
+## What This Pipeline Does
+
+The pipeline performs three main stages on each image:
+
+1. **Parallel Image Download**: Downloads images in parallel using a pooled `requests` session and a thread pool (20 threads per batch), handling timeouts, invalid URLs, and transient failures gracefully.
+
+2. **Image Preprocessing**: Validates, resizes, and standardizes images to 128×128 JPEG format in RGB color space using PIL, filtering out corrupted or invalid images.
+
+3. **Vision Model Inference**: Runs the Qwen2.5-VL-3B-Instruct vision-language model using vLLM to generate captions for image content, scaling across up to 64 A10G GPU replicas based on workload.
+
+The entire pipeline is orchestrated by Ray Data, which handles distributed execution, fault tolerance, and resource management across your cluster.
+
+## Key Features
+
+- **Massive Scale**: Designed for the full 2B+ image dataset; the sample job caps input at 10M images via `num_images_to_process`
+- **High Throughput**: Parallel downloads (20 threads per batch across hundreds of CPU tasks) and batched inference (128 images per batch, up to 128 concurrent batches per replica)
+- **Fault Tolerant**: Gracefully handles network failures, invalid images, and transient errors
+- **Cost Optimized**: Automatic GPU autoscaling (up to 64 A10G replicas) based on workload demand
+- **Production Ready**: Timestamped outputs, configurable memory limits, and structured error handling
+
+## How to Run
+
+First, make sure you have the [Anyscale CLI](https://docs.anyscale.com/get-started/install-anyscale-cli) installed.
+
+You'll need a HuggingFace token to access the ReLAION-2B dataset. Get one at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
+
+Submit the job:
+
+```bash
+anyscale job submit -f job.yaml --env HF_TOKEN=$HF_TOKEN
+```
+
+Or use the convenience script:
+
+```bash
+./run.sh
+```
+
+Results will be written to `/mnt/shared_storage/process_images_output/{timestamp}/` in Parquet format.
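+
+Once a run finishes, you can spot-check the results directly. A minimal sketch, assuming the run's timestamp directory exists under the shared storage path:
+
+```python
+import ray
+
+# Replace <timestamp> with the actual run directory, e.g. 20250101T000000Z.
+ds = ray.data.read_parquet(
+    "/mnt/shared_storage/process_images_output/<timestamp>/"
+)
+print(ds.count())  # Number of successfully captioned images
+ds.show(3)         # Inspect a few output rows
+```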
+
+## Configuration
+
+The pipeline is configured for high-throughput processing:
+
+- **Compute**: Up to 768 CPUs and 64 A10G GPUs (g5.12xlarge workers) with auto-scaling
+- **Vision Model**: Qwen2.5-VL-3B-Instruct on NVIDIA A10G GPUs with vLLM
+- **Download**: 20 threads per download batch over pooled HTTP connections, 10-second timeout per image
+- **Batch Processing**: 50 images per download batch, 128 images per inference batch
+- **Output**: 100,000 rows per Parquet file for efficient storage
+
+You can adjust these settings in `process_images.py` and `job.yaml` to match your requirements.
\ No newline at end of file
diff --git a/image_processing/job.yaml b/image_processing/job.yaml
new file mode 100644
index 0000000..9524d17
--- /dev/null
+++ b/image_processing/job.yaml
@@ -0,0 +1,42 @@
+# View the docs at https://docs.anyscale.com/reference/job-api#jobconfig.
+name: process-images
+
+# When empty, use the default image. This can be an Anyscale-provided base image
+# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
+# that it meets certain specs), or you can build new images using the Anyscale
+# image builder at https://console.anyscale-staging.com/v2/container-images.
+# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125
+containerfile: ./Dockerfile
+
+# When empty, Anyscale will auto-select the instance types. You can also specify
+# minimum and maximum resources.
+compute_config:
+  # Pin worker nodes to g5.12xlarge so the vision workload lands on A10G GPUs.
+  worker_nodes:
+    - instance_type: g5.12xlarge
+      min_nodes: 0
+      max_nodes: 16
+  max_resources:
+    CPU: 768
+    GPU: 64
+
+# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that
+# will be the working directory for the job. The files in the directory will be
+# automatically uploaded to the job environment in Anyscale.
+working_dir: .
+
+# When empty, this uses the default Anyscale Cloud in your organization.
+cloud:
+
+env_vars:
+  RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5"
+
+# The script to run in your job. You can also do "uv run main.py" if you have a
+# pyproject.toml file in your working_dir.
+entrypoint: python process_images.py
+
+# If there is an error, do not retry.
+max_retries: 0
+
+# Kill the job after 2 hours to control costs.
+timeout_s: 7200
\ No newline at end of file
diff --git a/image_processing/process_images.py b/image_processing/process_images.py
new file mode 100644
index 0000000..4a87f37
--- /dev/null
+++ b/image_processing/process_images.py
@@ -0,0 +1,299 @@
+import os
+import ray
+from huggingface_hub import HfFileSystem
+from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
+from PIL import Image
+from io import BytesIO
+import requests
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime, timezone
+import time
+import logging
+from typing import Optional, Dict, Any, List
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# SCALABILITY CONFIGURATION FOR 2B+ IMAGES
+# ============================================================================
+num_images_to_process = 10**7
+# Target 64 concurrent A10G replicas on g5.12xlarge workers.
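+# Resource math (assumes the g5.12xlarge shape pinned in job.yaml):
+# 16 nodes x 4 A10G GPUs = 64 GPUs; 16 nodes x 48 vCPUs = 768 CPUs.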
+
+num_gpu = 64
+num_cpu = 384 * 2
+tensor_parallelism = 1
+
+# Download configuration - tuned for the thread-pool downloader
+download_threads_per_batch = 20  # Threads per download batch
+download_timeout = 10  # Per-request timeout in seconds
+max_retries = 1  # Total attempts per URL; 1 means no retries
+batch_size = 50  # URLs per download batch, kept small for memory headroom
+
+# Create a session for connection pooling
+session = requests.Session()
+adapter = requests.adapters.HTTPAdapter(
+    pool_connections=100,
+    pool_maxsize=100,
+    max_retries=0,  # We handle retries manually
+)
+session.mount("http://", adapter)
+session.mount("https://", adapter)
+
+timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
+output_path = f"/mnt/shared_storage/process_images_output/{timestamp}"
+
+
+def is_valid_url(url: str) -> bool:
+    """Check if URL is valid and properly formatted."""
+    if not url or not isinstance(url, str):
+        return False
+    url_lower = url.lower().strip()
+    return url_lower.startswith("http://") or url_lower.startswith("https://")
+
+
+def download_single_image_with_retry(url: str) -> Dict[str, Any]:
+    """
+    Download a single image with retry logic and detailed error tracking.
+
+    Returns:
+        Dict with keys:
+        - 'content': Image bytes or None
+        - 'status': Success/failure status
+        - 'url': Original URL for tracking
+    """
+    if not is_valid_url(url):
+        return {"content": None, "status": "invalid_url", "url": url}
+
+    last_error = None
+    for attempt in range(max_retries):
+        try:
+            # Use the global session for connection pooling
+            response = session.get(url, timeout=download_timeout, stream=True)
+
+            if response.status_code == 200:
+                # stream=True defers the body; .content reads it in full here
+                content = response.content
+                return {"content": content, "status": "success", "url": url}
+            elif response.status_code == 404:
+                # Don't retry 404s - they're permanent failures
+                return {
+                    "content": None,
+                    "status": f"http_{response.status_code}",
+                    "url": url,
+                }
+            else:
+                last_error = f"http_{response.status_code}"
+
+        except requests.Timeout:
+            last_error = "timeout"
+            if attempt < max_retries - 1:
+                # Exponential backoff for retries
+                time.sleep(2**attempt)
+                continue
+
+        except requests.ConnectionError:
+            last_error = "connection_error"
+            if attempt < max_retries - 1:
+                time.sleep(2**attempt)
+                continue
+
+        except Exception as e:
+            last_error = f"error_{type(e).__name__}"
+            logger.debug(f"Unexpected error downloading {url}: {e}")
+            break
+
+    # All attempts exhausted
+    return {"content": None, "status": last_error, "url": url}
+
+
+def image_download(batch: Dict[str, List]) -> Dict[str, List]:
+    """
+    Download images in batch using a thread pool for parallelism.
+
+    A thread pool over a pooled requests session is simpler and more robust
+    than an async implementation, and equally performant for I/O-bound work.
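+
+    Args:
+        batch: Column-oriented batch with a "url" column (list of URLs).
+
+    Returns:
+        The same batch with added "bytes" (image bytes or None) and
+        "download_status" columns.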
+ """ + urls = batch["url"] + + # Use ThreadPoolExecutor for parallel downloads + with ThreadPoolExecutor(max_workers=download_threads_per_batch) as executor: + # Map download function over all URLs + results = list(executor.map(download_single_image_with_retry, urls)) + + # Extract content and status for downstream processing + batch["bytes"] = [r["content"] for r in results] + batch["download_status"] = [r["status"] for r in results] + + # Log statistics for monitoring + success_count = sum(1 for r in results if r["status"] == "success") + total_count = len(results) + failure_types = {} + for r in results: + if r["status"] != "success": + failure_types[r["status"]] = failure_types.get(r["status"], 0) + 1 + + if total_count > 0: + success_rate = (success_count / total_count) * 100 + logger.info( + f"Batch download: {success_count}/{total_count} succeeded ({success_rate:.1f}%)" + ) + if failure_types: + logger.info(f"Failure breakdown: {failure_types}") + + return batch + + +def process_single_image(image_bytes: Optional[bytes]) -> Optional[bytes]: + """ + Process a single image: validate, convert to RGB, and resize. + + Args: + image_bytes: Raw image bytes or None + + Returns: + Processed image bytes as JPEG or None if processing fails + """ + if image_bytes is None: + return None + + try: + img = Image.open(BytesIO(image_bytes)) + img.load() # Force load to detect corrupt images early + + # Convert to RGB if needed + if img.mode != "RGB": + img = img.convert("RGB") + + # Resize to target dimensions + img = img.resize((128, 128), Image.Resampling.LANCZOS) + + # Save as JPEG with high quality + output_buffer = BytesIO() + img.save(output_buffer, format="JPEG", quality=95) + return output_buffer.getvalue() + + except Exception as e: + logger.debug(f"Failed to process image: {type(e).__name__}") + return None + + +def process_image_bytes(batch): + image_bytes_list = batch["bytes"] + + with ThreadPoolExecutor(max_workers=50) as executor: + results = list(executor.map(process_single_image, image_bytes_list)) + + batch["bytes"] = results + return batch + + +vision_processor_config = vLLMEngineProcessorConfig( + model_source="Qwen/Qwen2.5-VL-3B-Instruct", + engine_kwargs=dict( + tensor_parallel_size=tensor_parallelism, + pipeline_parallel_size=1, + max_model_len=32768, + enable_chunked_prefill=True, + max_num_batched_tokens=2048, + distributed_executor_backend="mp", + ), + runtime_env=dict( + env_vars=dict( + VLLM_USE_V1="1", + VLLM_DISABLE_COMPILE_CACHE="1", + ), + ), + batch_size=128, + max_concurrent_batches=128, + accelerator_type="A10G", + concurrency=num_gpu, + has_image=True, +) + + +def vision_preprocess(row): + image_bytes = row["bytes"] + return dict( + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "image": Image.open(BytesIO(image_bytes)), + }, + ], + }, + ], + sampling_params=dict( + temperature=0.3, + max_tokens=150, + detokenize=False, + ), + ) + + +def vision_postprocess(row): + row.pop("bytes") + return row + + +vision_processor = build_llm_processor( + vision_processor_config, + preprocess=vision_preprocess, + postprocess=vision_postprocess, +) + +tasks_per_cpu = 1 +concurrency = num_cpu * tasks_per_cpu + +ctx = ray.data.DataContext.get_current() +target_block_size_mb = 128 +ctx.target_max_block_size = target_block_size_mb * 1024 * 1024 +ctx.use_push_based_shuffle = False +ctx.retried_io_errors.append("429 Client Error: Too Many Requests for url") + + +# The data processing pipeline includes the following steps: +# 1. 
+
+# The data processing pipeline includes the following steps:
+# 1. Read the ReLAION-2B dataset, keeping only the url column
+# 2. Download images using a thread pool with retry logic
+# 3. Validate and resize images to 128x128
+# 4. Caption images with the vision model
+# 5. Write results to the output path
+dataset = (
+    ray.data.read_parquet(
+        "hf://datasets/laion/relaion2B-en-research-safe/",
+        file_extensions=["parquet"],
+        columns=["url"],
+        filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]),
+        concurrency=concurrency,
+        memory=int(4 * 1024**3),
+    )  # Reserve memory for read tasks to avoid OOM errors
+    .limit(num_images_to_process)
+    .repartition(num_cpu)
+    .map_batches(
+        image_download,
+        batch_size=batch_size,  # URLs per download batch
+        concurrency=num_cpu,
+    )
+    .drop_columns(["url"])  # Drop URL after download to save memory
+    .map_batches(
+        process_image_bytes,
+        batch_size=batch_size,  # Same batch size as the download stage
+        concurrency=num_cpu,
+    )
+    .filter(
+        lambda row: row["bytes"] is not None
+    )  # Filter out failed downloads/processing
+    .drop_columns(["download_status"])  # Drop status column after filtering
+)
+
+dataset = vision_processor(dataset)
+
+dataset.write_parquet(
+    output_path,
+    max_rows_per_file=100000,
+)
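+
+# Sanity check after the job completes (run separately, not part of this job):
+# ray.data.read_parquet(output_path).count() should be close to
+# num_images_to_process minus download/decode failures. With 10M inputs and
+# max_rows_per_file=100000, expect on the order of 100 Parquet files, though
+# the exact count depends on block sizes and upstream failures.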