Closed
14 changes: 14 additions & 0 deletions image_processing/Dockerfile
@@ -0,0 +1,14 @@
FROM anyscale/ray:2.51.1-slim-py312-cu128

# C compiler for Triton’s runtime build step (vLLM V1 engine)
# https://github.com/vllm-project/vllm/issues/2997
RUN sudo apt-get update && \
    sudo apt-get install -y --no-install-recommends build-essential

RUN curl -LsSf https://astral.sh/uv/install.sh | sh

RUN uv pip install --system huggingface_hub boto3

RUN uv pip install --system vllm==0.11.0

RUN uv pip install --system transformers==4.57.1
55 changes: 55 additions & 0 deletions image_processing/README.md
@@ -0,0 +1,55 @@
# Large-Scale Image Processing with Vision Language Models

This example demonstrates how to build a production-ready image processing pipeline that scales to billions of images using Ray Data and vLLM on Anyscale. We process the [ReLAION-2B dataset](https://huggingface.co/datasets/laion/relaion2B-en-research-safe), which contains over 2 billion image URLs with associated metadata.

## What This Pipeline Does

The pipeline performs three main stages on each image:

1. **Parallel Image Download**: Asynchronously downloads images from URLs using aiohttp with up to 256 concurrent connections per download task and a 5-second timeout. Invalid URLs, timeouts, and non-200 responses yield `None` instead of failing the batch.

2. **Image Preprocessing**: Validates, resizes, and standardizes images to 128×128 JPEG format in RGB color space using PIL, filtering out corrupted or invalid images.

3. **Vision Model Inference**: Runs the Qwen2.5-VL-3B-Instruct vision-language model using vLLM to generate captions or analyze image content, scaling across up to 32 L40S GPU replicas based on workload.

The entire pipeline is orchestrated by Ray Data, which handles distributed execution, fault tolerance, and resource management across your cluster.
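
The download stage relies on a bounded-concurrency pattern: a semaphore caps the number of in-flight requests, and any per-URL failure becomes `None` so one bad URL never fails the batch. The sketch below illustrates that pattern with the standard library only. `fake_fetch` is a hypothetical stand-in for the aiohttp request in `process_images.py`, and the default of 256 mirrors `download_concurrency` in that script. It also uses `asyncio.run` for brevity, whereas the script manages a dedicated event loop per batch.

```python
import asyncio


async def fake_fetch(url: str) -> bytes:
    """Hypothetical stand-in for an HTTP GET (the real pipeline uses aiohttp)."""
    await asyncio.sleep(0)  # yield to the event loop, as a network call would
    if "bad" in url:
        raise TimeoutError(f"timed out: {url}")
    return url.encode()


def is_valid_url(url) -> bool:
    if not url or not isinstance(url, str):
        return False
    lowered = url.lower().strip()
    return lowered.startswith("http://") or lowered.startswith("https://")


async def download_one(url, semaphore):
    # The semaphore limits how many downloads run at the same time.
    async with semaphore:
        if not is_valid_url(url):
            return None
        try:
            return await fake_fetch(url)
        except Exception:
            # Any failure maps to None; a later stage filters these rows out.
            return None


async def download_all(urls, max_concurrency=256):
    semaphore = asyncio.Semaphore(max_concurrency)
    # gather() preserves input order, so results line up with urls.
    return await asyncio.gather(*(download_one(u, semaphore) for u in urls))


def download_batch(urls, max_concurrency=256):
    return asyncio.run(download_all(urls, max_concurrency))
```

Calling `download_batch(["https://a/x.jpg", "ftp://nope"], max_concurrency=2)` returns one bytes object and one `None`, in input order.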

## Key Features

- **Massive Scale**: Processes 2B+ images efficiently with automatic resource scaling
- **High Throughput**: Concurrent downloads (up to 256 connections per download task) and batched inference (128 images per batch, up to 128 concurrent batches per replica)
- **Fault Tolerant**: Gracefully handles network failures, invalid images, and transient errors
- **Cost Optimized**: Automatic GPU autoscaling (up to 32 L40S replicas) based on workload demand
- **Production Ready**: Timestamped outputs, configurable memory limits, and structured error handling

## How to Run

First, make sure you have the [Anyscale CLI](https://docs.anyscale.com/get-started/install-anyscale-cli) installed.

You'll need a Hugging Face token to access the ReLAION-2B dataset. Get one at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens), then export it as `HF_TOKEN` in your shell.

Submit the job:

```bash
anyscale job submit -f job.yaml --env HF_TOKEN=$HF_TOKEN
```

Or use the convenience script:

```bash
./run.sh
```

Results will be written to `/mnt/shared_storage/process_images_output/{timestamp}/` in Parquet format.
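
The `{timestamp}` component is a UTC timestamp generated once when the script starts. The sketch below reproduces the format string used in `process_images.py`. The helper name `output_path_for` is hypothetical and exists only for illustration.

```python
from datetime import datetime, timezone

OUTPUT_ROOT = "/mnt/shared_storage/process_images_output"


def output_path_for(now: datetime) -> str:
    """Build the per-run output directory from a timezone-aware datetime."""
    # Same format string as process_images.py: compact ISO 8601 in UTC.
    timestamp = now.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    return f"{OUTPUT_ROOT}/{timestamp}"


example = output_path_for(datetime(2025, 11, 19, 8, 30, 5, tzinfo=timezone.utc))
print(example)  # /mnt/shared_storage/process_images_output/20251119T083005Z
```

Because each run writes to its own directory, repeated submissions do not overwrite earlier results.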

## Configuration

The pipeline is configured for high-throughput processing:

- **Compute**: Up to 384 CPUs and 32 L40S GPUs (up to 8 g6e.12xlarge workers) with autoscaling
- **Vision Model**: Qwen2.5-VL-3B-Instruct on NVIDIA L40S GPUs with vLLM
- **Download**: Up to 256 concurrent connections per download task, 5-second timeout per image
- **Batch Processing**: 50 images per download batch, 128 images per inference batch
- **Output**: 100,000 rows per Parquet file for efficient storage

You can adjust these settings in `process_images.py` and `job.yaml` to match your requirements.
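
When changing the instance type or node count, the resource limits in `job.yaml` and the replica count in `process_images.py` need to stay in step. The sketch below shows the arithmetic. It assumes 4 GPUs per g6e.12xlarge node (as noted in the review comment on `job.yaml`) and 48 vCPUs per node (inferred from the 384-CPU limit across 8 nodes). The class and function names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class InstanceShape:
    name: str
    cpus: int
    gpus: int


# Assumed per-node shape; verify against your cloud provider before relying on it.
G6E_12XLARGE = InstanceShape("g6e.12xlarge", cpus=48, gpus=4)


def cluster_limits(shape: InstanceShape, max_nodes: int) -> dict:
    """Upper bounds to use for max_resources in job.yaml."""
    return {"CPU": shape.cpus * max_nodes, "GPU": shape.gpus * max_nodes}


def max_replicas(total_gpus: int, tensor_parallel_size: int = 1,
                 pipeline_parallel_size: int = 1) -> int:
    """Each vLLM replica occupies tensor_parallel_size * pipeline_parallel_size GPUs."""
    return total_gpus // (tensor_parallel_size * pipeline_parallel_size)


limits = cluster_limits(G6E_12XLARGE, max_nodes=8)
print(limits)                       # {'CPU': 384, 'GPU': 32}
print(max_replicas(limits["GPU"]))  # 32
```

With `tensor_parallelism = 1`, every GPU hosts one replica, which is why `num_gpu` in the script equals the GPU limit in `job.yaml`.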
42 changes: 42 additions & 0 deletions image_processing/job.yaml
@@ -0,0 +1,42 @@
# View the docs https://docs.anyscale.com/reference/job-api#jobconfig.
name: process-images

# When empty, use the default image. This can be an Anyscale-provided base image
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
# that it meets certain specs), or you can build new images using the Anyscale
# image builder at https://console.anyscale-staging.com/v2/container-images.
# image_uri: # anyscale/ray:2.43.0-slim-py312-cu125
containerfile: ./Dockerfile

# When empty, Anyscale will auto-select the instance types. You can also specify
# minimum and maximum resources.
compute_config:
  # Pin worker nodes to g6e.12xlarge so the vision workload lands on L40S GPUs.
  worker_nodes:
    - instance_type: g6e.12xlarge
Author

This instance type has 4 GPUs per node. Using a single-GPU instance type produced errors after the job had been running for some time.

      min_nodes: 0
      max_nodes: 8
  min_resources:
    CPU: 0
    GPU: 0
  max_resources:
    CPU: 384
    GPU: 32

# Path to a local directory or a remote URI to a .zip file (S3, GS, HTTP) that
# will be the working directory for the job. The files in the directory will be
# automatically uploaded to the job environment in Anyscale.
working_dir: .

# When empty, this uses the default Anyscale Cloud in your organization.
cloud:

env_vars:
  RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION: "0.5"

# The script to run in your job. You can also do "uv run main.py" if you have a
# pyproject.toml file in your working_dir.
entrypoint: python process_images.py

# If there is an error, do not retry.
max_retries: 0
240 changes: 240 additions & 0 deletions image_processing/process_images.py
@@ -0,0 +1,240 @@
import os
import asyncio
import ray
from huggingface_hub import HfFileSystem
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor
from PIL import Image
from io import BytesIO
import aiohttp
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
import time


# ============================================================================
# SCALABILITY CONFIGURATION FOR 2B+ IMAGES
# ============================================================================
# num_images = 100
# Target 32 concurrent L40S replicas (8 g6e.12xlarge workers x 4 GPUs each).

num_gpu = 32
num_cpu = 384
tensor_parallelism = 1
download_concurrency = 256
download_timeout = 5

timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
output_path = f"/mnt/shared_storage/process_images_output/{timestamp}"


def is_valid_url(url):
    if not url or not isinstance(url, str):
        return False
    url_lower = url.lower().strip()
    return url_lower.startswith("http://") or url_lower.startswith("https://")


async def download_single_image(session, url, semaphore):
    async with semaphore:
        if not is_valid_url(url):
            return None

        try:
            async with session.get(
                url, timeout=aiohttp.ClientTimeout(total=download_timeout)
            ) as response:
                if response.status == 200:
                    content = await response.read()
                    return content
                return None
        except Exception:
            return None


async def download_images_async(urls):
    semaphore = asyncio.Semaphore(download_concurrency)

    connector = aiohttp.TCPConnector(
        limit=download_concurrency,
        limit_per_host=100,
        ttl_dns_cache=300,
        enable_cleanup_closed=True,
    )

    timeout_config = aiohttp.ClientTimeout(total=download_timeout, connect=3)

    async with aiohttp.ClientSession(
        connector=connector, timeout=timeout_config
    ) as session:
        tasks = [download_single_image(session, url, semaphore) for url in urls]
        results = await asyncio.gather(*tasks, return_exceptions=True)

    processed_results = []
    for result in results:
        if isinstance(result, Exception):
            processed_results.append(None)
        else:
            processed_results.append(result)

    return processed_results


def image_download(batch):
    urls = batch["url"]

    # Use a dedicated event loop per batch to avoid interfering with any
    # existing asyncio loops that Ray or vLLM may be running in this process.
    loop = asyncio.new_event_loop()
    try:
        results = loop.run_until_complete(download_images_async(urls))
    finally:
Author

@xyuzh xyuzh Nov 19, 2025

I am using a verbose approach for the async download logic, but I found that map_batches has async support (https://github.com/ray-project/ray/pull/46129/files).
@robertnishihara do you recommend the async approach from that PR?

        # Comprehensive cleanup to prevent resource leaks
        try:
            # Cancel all pending tasks in this specific loop
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()

            # Wait for all task cancellations to complete
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

            # Shut down async generators
            loop.run_until_complete(loop.shutdown_asyncgens())

            # Shut down the default executor (thread pool)
            loop.run_until_complete(loop.shutdown_default_executor())
        except Exception:
            # Best-effort cleanup; failures here should not impact the worker
            pass
        finally:
            # Close the loop without affecting the global event loop
            loop.close()

    batch["bytes"] = results
    return batch


def process_single_image(image_bytes):
    if image_bytes is None:
        return None

    try:
        img = Image.open(BytesIO(image_bytes))
        img.load()

        if img.mode != "RGB":
            img = img.convert("RGB")

        img = img.resize((128, 128), Image.Resampling.LANCZOS)

        output_buffer = BytesIO()
        img.save(output_buffer, format="JPEG", quality=95)
        return output_buffer.getvalue()
    except Exception:
        return None


def process_image_bytes(batch):
    image_bytes_list = batch["bytes"]

    with ThreadPoolExecutor(max_workers=50) as executor:
        results = list(executor.map(process_single_image, image_bytes_list))

    batch["bytes"] = results
    return batch


vision_processor_config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-VL-3B-Instruct",
    engine_kwargs=dict(
        tensor_parallel_size=tensor_parallelism,
        pipeline_parallel_size=1,
        max_model_len=32768,
        enable_chunked_prefill=True,
        max_num_batched_tokens=2048,
        distributed_executor_backend="mp",
    ),
    runtime_env=dict(
        env_vars=dict(
            VLLM_USE_V1="1",
            VLLM_DISABLE_COMPILE_CACHE="1",
        ),
    ),
    batch_size=128,
    max_concurrent_batches=128,
Author

Changing batch_size and max_concurrent_batches didn't change GPU memory usage or VLM throughput.

    accelerator_type="L40S",
    concurrency=num_gpu,
    has_image=True,
)


def vision_preprocess(row):
    image_bytes = row["bytes"]
    return dict(
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": Image.open(BytesIO(image_bytes)),
                    },
                ],
            },
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=150,
            detokenize=False,
        ),
    )


def vision_postprocess(row):
    row.pop("bytes")
    return row


vision_processor = build_llm_processor(
    vision_processor_config,
    preprocess=vision_preprocess,
    postprocess=vision_postprocess,
)

tasks_per_cpu = 1
concurrency = num_cpu * tasks_per_cpu

ctx = ray.data.DataContext.get_current()
target_block_size_mb = 128
ctx.target_max_block_size = target_block_size_mb * 1024 * 1024
ctx.use_push_based_shuffle = False

# The data processing pipeline includes the following steps:
# 1. Read the url column of the 2B-image dataset from Hugging Face
# 2. Download the images from those URLs with parallel async requests
# 3. Validate each image and resize it to 128x128
# 4. Run the vision model on the images
# 5. Write the results (without the image bytes) to the output path
dataset = (
    ray.data.read_parquet(
        "hf://datasets/laion/relaion2B-en-research-safe/",
        file_extensions=["parquet"],
        columns=["url"],
        filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]),
        concurrency=concurrency,
        num_cpus=2,
        memory=int(4 * 1024**3),
    )  # Read the dataset with an explicit memory reservation to avoid OOM errors
    .map_batches(image_download, batch_size=50, num_cpus=1, concurrency=num_cpu)
    .drop_columns(["url"])
    .map_batches(process_image_bytes, batch_size=50, num_cpus=1, concurrency=num_cpu)
    .filter(lambda row: row["bytes"] is not None)
)

dataset = vision_processor(dataset)

dataset.write_parquet(
    output_path,
    max_rows_per_file=100000,
)
2 changes: 2 additions & 0 deletions image_processing/run.sh
@@ -0,0 +1,2 @@
anyscale job submit -f job.yaml \
  --env HF_TOKEN=$HF_TOKEN