Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/distributed/test_shm_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import traceback
import unittest

import numpy as np

from vllm.distributed.device_communicators.shm_object_storage import (
SingleWriterShmRingBuffer,
)
Expand Down Expand Up @@ -113,6 +115,69 @@ def test_clear_buffer(self):
self.assertEqual(self.ring_buffer.data_buffer_start, 0)
self.assertEqual(self.ring_buffer.data_buffer_end, 0)

def test_allocation_cycles(self):
buffer_size = 100
ring = SingleWriterShmRingBuffer(data_buffer_size=buffer_size, create=True)

# tracking allocations for assertions
allocated_bitmap = np.zeros(
(buffer_size,), dtype=np.bool_
) # addr -> is_allocated
allocation_map = dict() # monotonic_id -> (addr, size)

def count_allocated(bitmap) -> int:
return np.sum(bitmap).item()

def is_free_fn(a, b) -> bool:
return True

def mark_allocated_with_assertion(id, addr, size):
addr = addr % buffer_size
self.assertEqual(count_allocated(allocated_bitmap[addr : addr + size]), 0)

allocated_bitmap[addr : addr + size] = True
allocation_map[id] = (addr, size)

def mark_freed_with_assertion(id):
self.assertTrue(id in allocation_map)

addr, size = allocation_map.pop(id)
addr = addr % buffer_size
self.assertEqual(
count_allocated(allocated_bitmap[addr : addr + size]), size
)

allocated_bitmap[addr : addr + size] = False

def ring_free(free_size=None):
freed_ids = ring.free_buf(is_free_fn, free_size)
for freed_id in freed_ids:
mark_freed_with_assertion(freed_id)

def ring_allocate(allocate_size):
allocate_size_with_md = allocate_size + ring.MD_SIZE
try:
addr, monotonic_id = ring.allocate_buf(allocate_size)
mark_allocated_with_assertion(monotonic_id, addr, allocate_size_with_md)
except MemoryError:
# free 2x size for enough space if wrapping happened
ring_free(allocate_size_with_md * 2)

# retry allocating
addr, monotonic_id = ring.allocate_buf(allocate_size)
mark_allocated_with_assertion(monotonic_id, addr, allocate_size_with_md)

# 1. allocation & free cycles
for _ in range(33):
# will consume 2 + 8 = 10 bytes per allocation
ring_allocate(2)

# 2. free all allocations
ring_free()

# 3. try allocate the largest possible buffer
ring_allocate(buffer_size - ring.MD_SIZE)


def main():
"""Main function demonstrating usage and running tests"""
Expand Down
14 changes: 10 additions & 4 deletions vllm/distributed/device_communicators/shm_object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def __init__(

if create:
# we are creating a buffer
self.metadata = {
self.monotonic_id_end: self.data_buffer_end
} # monotonic_id -> start address
self.metadata: dict[int, int] = {} # monotonic_id -> start address
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.data_buffer_size, name=name
)
Expand Down Expand Up @@ -288,7 +286,15 @@ def free_buf(
self.monotonic_id_start = (
self.monotonic_id_start + 1
) % self.ID_MAX
self.data_buffer_start = address
if self.monotonic_id_start in self.metadata:
# pointing to the start addr of next allocation
self.data_buffer_start += (
self.metadata[self.monotonic_id_start]
- self.data_buffer_start
) % self.data_buffer_size
else:
# no remaining allocation, reset to zero
self.data_buffer_start = self.data_buffer_end = 0
freed_bytes += metadata[1]
else:
# there are still readers, we cannot free the buffer
Expand Down