diff --git a/image_processing/Dockerfile b/image_processing/Dockerfile new file mode 100644 index 0000000..d48065a --- /dev/null +++ b/image_processing/Dockerfile @@ -0,0 +1,14 @@ +FROM anyscale/ray:2.51.1-slim-py312-cu128 + +# C compiler for Triton’s runtime build step (vLLM V1 engine) +# https://github.com/vllm-project/vllm/issues/2997 +RUN sudo apt-get update && \ + sudo apt-get install -y --no-install-recommends build-essential + +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +RUN uv pip install --system huggingface_hub boto3 + +RUN uv pip install --system vllm==0.11.0 + +RUN uv pip install --system transformers==4.57.1 \ No newline at end of file diff --git a/image_processing/README.md b/image_processing/README.md new file mode 100644 index 0000000..0de2e7a --- /dev/null +++ b/image_processing/README.md @@ -0,0 +1,55 @@ +# Large-Scale Image Processing with Vision Language Models + +This example demonstrates how to build a production-ready image processing pipeline that scales to billions of images using Ray Data and vLLM on Anyscale. We process the [ReLAION-2B dataset](https://huggingface.co/datasets/laion/relaion2B-en-research-safe), which contains over 2 billion image URLs with associated metadata. + +## What This Pipeline Does + +The pipeline performs three main stages on each image: + +1. **Parallel Image Download**: Asynchronously downloads images from URLs using aiohttp with up to 256 concurrent connections per download task, handling timeouts and validation gracefully. + +2. **Image Preprocessing**: Validates, resizes, and standardizes images to 128×128 JPEG format in RGB color space using PIL, filtering out corrupted or invalid images. + +3. **Vision Model Inference**: Runs the Qwen2.5-VL-3B-Instruct vision-language model using vLLM to generate captions or analyze image content, scaling across up to 32 L40S GPU replicas based on workload. 
+ +The entire pipeline is orchestrated by Ray Data, which handles distributed execution, fault tolerance, and resource management across your cluster. + +## Key Features + +- **Massive Scale**: Processes 2B+ images efficiently with automatic resource scaling +- **High Throughput**: Concurrent downloads (up to 256 connections per download task) and batched inference (128 images per batch, up to 128 concurrent batches per GPU replica) +- **Fault Tolerant**: Gracefully handles network failures, invalid images, and transient errors +- **Cost Optimized**: Automatic GPU autoscaling (up to 32 L40S replicas) based on workload demand +- **Production Ready**: Timestamped outputs, configurable memory limits, and structured error handling + +## How to Run + +First, make sure you have the [Anyscale CLI](https://docs.anyscale.com/get-started/install-anyscale-cli) installed. + +You'll need a HuggingFace token to access the ReLAION-2B dataset. Get one at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). + +Submit the job: + +```bash +anyscale job submit -f job.yaml --env HF_TOKEN=$HF_TOKEN +``` + +Or use the convenience script: + +```bash +./run.sh +``` + +Results will be written to `/mnt/shared_storage/process_images_output/{timestamp}/` in Parquet format. + +## Configuration + +The pipeline is configured for high-throughput processing: + +- **Compute**: Up to 384 CPUs and 32 L40S GPUs (g6e.12xlarge workers) with auto-scaling +- **Vision Model**: Qwen2.5-VL-3B-Instruct on NVIDIA L40S GPUs with vLLM +- **Download**: Up to 256 concurrent connections per download task, 5-second timeout per image +- **Batch Processing**: 50 images per download batch, 128 images per inference batch +- **Output**: 100,000 rows per Parquet file for efficient storage + +You can adjust these settings in `process_images.py` and `job.yaml` to match your requirements. 
\ No newline at end of file diff --git a/image_processing/job.yaml b/image_processing/job.yaml new file mode 100644 index 0000000..f9bac5a --- /dev/null +++ b/image_processing/job.yaml @@ -0,0 +1,42 @@ +# View the docs https://docs.anyscale.com/reference/job-api#jobconfig. +name: process-images + +# When empty, use the default image. This can be an Anyscale-provided base image +# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided +# that it meets certain specs), or you can build new images using the Anyscale +# image builder at https://console.anyscale-staging.com/v2/container-images. +# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125 +containerfile: ./Dockerfile + +# When empty, Anyscale will auto-select the instance types. You can also specify +# minimum and maximum resources. +compute_config: + # Pin worker nodes to g6e.12xlarge so the vision workload lands on L40S GPUs. + worker_nodes: + - instance_type: g6e.12xlarge + min_nodes: 0 + max_nodes: 8 + min_resources: + CPU: 0 + GPU: 0 + max_resources: + CPU: 384 + GPU: 32 + +# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that +# will be the working directory for the job. The files in the directory will be +# automatically uploaded to the job environment in Anyscale. +working_dir: . + +# When empty, this uses the default Anyscale Cloud in your organization. +cloud: + +env_vars: + RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5" + +# The script to run in your job. You can also do "uv run main.py" if you have a +# pyproject.toml file in your working_dir. +entrypoint: python process_images.py + +# If there is an error, do not retry. 
# NOTE: SOURCE is a collapsed git patch. The non-Python patch text that shares
# these lines with process_images.py is kept verbatim in the comment lines
# directly below and at the very end of this block.
# +max_retries: 0 \ No newline at end of file
# diff --git a/image_processing/process_images.py b/image_processing/process_images.py new file mode 100644 index 0000000..b531a93 --- /dev/null +++ b/image_processing/process_images.py @@ -0,0 +1,240 @@
"""Download, normalize, and run a vision-language model over ReLAION-2B images.

The script reads the ``url`` column of the ReLAION-2B parquet files from the
Hugging Face Hub, downloads every image concurrently, re-encodes the valid
ones as 128x128 RGB JPEGs, runs Qwen2.5-VL-3B-Instruct on them through the
Ray Data vLLM processor, and writes the results as Parquet under a
timestamped directory.

Requires the ``HF_TOKEN`` environment variable to be set.
"""
import os
import asyncio
import logging
import ray
from huggingface_hub import HfFileSystem
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
from PIL import Image
from io import BytesIO
import aiohttp
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
import time

logger = logging.getLogger(__name__)

# ============================================================================
# SCALABILITY CONFIGURATION FOR 2B+ IMAGES
# ============================================================================
# num_images = 100
# Target 32 concurrent L40S replicas; these numbers mirror the max_resources
# (CPU: 384, GPU: 32) declared for the g6e.12xlarge workers in job.yaml.

num_gpu = 32
num_cpu = 384
tensor_parallelism = 1
# Maximum number of in-flight HTTP requests inside one download batch.
download_concurrency = 256
# Total per-request budget, in seconds (connect + read).
download_timeout = 5

# One output directory per run, named after the UTC start time.
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
output_path = f"/mnt/shared_storage/process_images_output/{timestamp}"


def is_valid_url(url):
    """Return True when *url* is a non-empty string with an http(s) scheme.

    The scheme check ignores case and surrounding whitespace.
    """
    if not url or not isinstance(url, str):
        return False
    return url.strip().lower().startswith(("http://", "https://"))


async def download_single_image(session, url, semaphore):
    """Fetch one image and return its raw bytes, or None on any failure.

    Args:
        session: An open ``aiohttp.ClientSession``. Its configured timeout is
            used as-is; no per-request timeout is passed, because a
            per-request ``ClientTimeout`` replaces the session one entirely
            and would silently drop the session's connect timeout.
        url: Candidate URL; anything that is not an http(s) string yields None.
        semaphore: ``asyncio.Semaphore`` bounding concurrent requests.

    Returns:
        The response body for an HTTP 200 response, otherwise None.
    """
    async with semaphore:
        if not is_valid_url(url):
            return None

        try:
            async with session.get(url) as response:
                if response.status == 200:
                    return await response.read()
                return None
        except Exception:
            # Best effort by design: a bad URL, timeout, or network error
            # must only drop this image, not fail the whole batch.
            return None


async def download_images_async(urls):
    """Download all *urls* concurrently.

    Returns:
        A list aligned with *urls*: raw bytes for each successful download
        and None for every failed, invalid, or cancelled one.
    """
    semaphore = asyncio.Semaphore(download_concurrency)

    connector = aiohttp.TCPConnector(
        limit=download_concurrency,
        limit_per_host=100,
        ttl_dns_cache=300,
        enable_cleanup_closed=True,
    )

    # Single source of truth for request timeouts: at most 3 s to connect and
    # download_timeout seconds for the whole request.
    timeout_config = aiohttp.ClientTimeout(total=download_timeout, connect=3)

    async with aiohttp.ClientSession(
        connector=connector, timeout=timeout_config
    ) as session:
        tasks = [download_single_image(session, url, semaphore) for url in urls]
        results = await asyncio.gather(*tasks, return_exceptions=True)

    # BaseException (not Exception) so that a CancelledError returned by
    # gather() is mapped to None instead of being stored as image "bytes".
    return [
        None if isinstance(result, BaseException) else result
        for result in results
    ]


def image_download(batch):
    """Ray Data batch UDF: add a ``bytes`` column with the downloaded images.

    Reads ``batch["url"]`` and stores one entry per URL in ``batch["bytes"]``
    (raw bytes, or None when the download failed).
    """
    urls = batch["url"]

    # Use a dedicated event loop per batch to avoid interfering with any
    # existing asyncio loops that Ray or vLLM may be running in this process.
    loop = asyncio.new_event_loop()
    try:
        results = loop.run_until_complete(download_images_async(urls))
    finally:
        # Comprehensive cleanup to prevent resource leaks.
        try:
            # Cancel whatever is still scheduled on this specific loop and
            # wait for the cancellations to finish.
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )

            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.run_until_complete(loop.shutdown_default_executor())
        except Exception:
            # Best-effort cleanup; failures here should not impact the
            # worker, but leave a trace for debugging.
            logger.debug("Event loop cleanup failed; continuing", exc_info=True)
        finally:
            # Close the loop without affecting the global event loop.
            loop.close()

    batch["bytes"] = results
    return batch


def process_single_image(image_bytes):
    """Decode, validate, and normalize one image.

    Returns:
        The image re-encoded as a 128x128 RGB JPEG (quality 95), or None if
        *image_bytes* is None or cannot be decoded or converted.
    """
    if image_bytes is None:
        return None

    try:
        # The context manager releases the decoder once the pixels have been
        # copied into the resized image.
        with Image.open(BytesIO(image_bytes)) as img:
            img.load()  # Force a full decode so truncated files fail here.
            rgb = img if img.mode == "RGB" else img.convert("RGB")
            resized = rgb.resize((128, 128), Image.Resampling.LANCZOS)

        output_buffer = BytesIO()
        resized.save(output_buffer, format="JPEG", quality=95)
        return output_buffer.getvalue()
    except Exception:
        # Corrupt or unsupported images are dropped by the later filter step.
        return None


def process_image_bytes(batch):
    """Ray Data batch UDF: normalize every entry of ``batch["bytes"]`` in place.

    Entries that fail validation become None.
    """
    image_bytes_list = batch["bytes"]

    with ThreadPoolExecutor(max_workers=50) as executor:
        results = list(executor.map(process_single_image, image_bytes_list))

    batch["bytes"] = results
    return batch


vision_processor_config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-VL-3B-Instruct",
    engine_kwargs=dict(
        tensor_parallel_size=tensor_parallelism,
        pipeline_parallel_size=1,
        max_model_len=32768,
        enable_chunked_prefill=True,
        max_num_batched_tokens=2048,
        distributed_executor_backend="mp",
    ),
    runtime_env=dict(
        env_vars=dict(
            VLLM_USE_V1="1",
            VLLM_DISABLE_COMPILE_CACHE="1",
        ),
    ),
    batch_size=128,
    max_concurrent_batches=128,
    accelerator_type="L40S",
    concurrency=num_gpu,
    has_image=True,
)


def vision_preprocess(row):
    """Build the chat request (a single user message holding the image)."""
    image_bytes = row["bytes"]
    return dict(
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": Image.open(BytesIO(image_bytes)),
                    },
                ],
            },
        ],
        # NOTE(review): detokenize=False presumably leaves decoding to the
        # processor's own detokenize stage -- confirm against the Ray docs.
        sampling_params=dict(
            temperature=0.3,
            max_tokens=150,
            detokenize=False,
        ),
    )


def vision_postprocess(row):
    """Drop the JPEG payload so only the remaining columns are written."""
    # Default avoids a KeyError should the column not be carried through.
    row.pop("bytes", None)
    return row


vision_processor = build_llm_processor(
    vision_processor_config,
    preprocess=vision_preprocess,
    postprocess=vision_postprocess,
)

tasks_per_cpu = 1
concurrency = num_cpu * tasks_per_cpu

ctx = ray.data.DataContext.get_current()
target_block_size_mb = 128
ctx.target_max_block_size = target_block_size_mb * 1024 * 1024
ctx.use_push_based_shuffle = False

# The data processing includes the following steps:
# 1. Read the url column of the 2B-image dataset from the Hugging Face Hub
# 2. Download the images from the urls in parallel async jobs
# 3. Check that each image is valid and resize it to 128x128
# 4. Process the images with the vision model
# 5. Write the results to the output path
dataset = (
    ray.data.read_parquet(
        "hf://datasets/laion/relaion2B-en-research-safe/",
        file_extensions=["parquet"],
        columns=["url"],
        # Fails fast with KeyError when HF_TOKEN is not set.
        filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]),
        concurrency=concurrency,
        num_cpus=2,
        memory=int(4 * 1024**3),
    )  # Read the dataset with an explicit memory allocation to avoid OOM errors
    .map_batches(image_download, batch_size=50, num_cpus=1, concurrency=num_cpu)
    .drop_columns(["url"])
    .map_batches(process_image_bytes, batch_size=50, num_cpus=1, concurrency=num_cpu)
    .filter(lambda row: row["bytes"] is not None)
)

dataset = vision_processor(dataset)

dataset.write_parquet(
    output_path,
    max_rows_per_file=100000,
)
# \ No newline at end of file
# diff --git a/image_processing/run.sh b/image_processing/run.sh new file mode 100755 index 0000000..8758376 --- /dev/null +++ b/image_processing/run.sh @@ -0,0 +1,2 @@
# +anyscale job submit -f job.yaml \
# + --env HF_TOKEN=$HF_TOKEN
# \ No newline at end of file