diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py index c6ceab181ff5..9fe409edc3ca 100644 --- a/tests/distributed/test_shm_buffer.py +++ b/tests/distributed/test_shm_buffer.py @@ -4,6 +4,8 @@ import traceback import unittest +import numpy as np + from vllm.distributed.device_communicators.shm_object_storage import ( SingleWriterShmRingBuffer, ) @@ -113,6 +115,69 @@ def test_clear_buffer(self): self.assertEqual(self.ring_buffer.data_buffer_start, 0) self.assertEqual(self.ring_buffer.data_buffer_end, 0) + def test_allocation_cycles(self): + buffer_size = 100 + ring = SingleWriterShmRingBuffer(data_buffer_size=buffer_size, create=True) + + # tracking allocations for assertions + allocated_bitmap = np.zeros( + (buffer_size,), dtype=np.bool_ + ) # addr -> is_allocated + allocation_map = dict() # monotonic_id -> (addr, size) + + def count_allocated(bitmap) -> int: + return np.sum(bitmap).item() + + def is_free_fn(a, b) -> bool: + return True + + def mark_allocated_with_assertion(id, addr, size): + addr = addr % buffer_size + self.assertEqual(count_allocated(allocated_bitmap[addr : addr + size]), 0) + + allocated_bitmap[addr : addr + size] = True + allocation_map[id] = (addr, size) + + def mark_freed_with_assertion(id): + self.assertTrue(id in allocation_map) + + addr, size = allocation_map.pop(id) + addr = addr % buffer_size + self.assertEqual( + count_allocated(allocated_bitmap[addr : addr + size]), size + ) + + allocated_bitmap[addr : addr + size] = False + + def ring_free(free_size=None): + freed_ids = ring.free_buf(is_free_fn, free_size) + for freed_id in freed_ids: + mark_freed_with_assertion(freed_id) + + def ring_allocate(allocate_size): + allocate_size_with_md = allocate_size + ring.MD_SIZE + try: + addr, monotonic_id = ring.allocate_buf(allocate_size) + mark_allocated_with_assertion(monotonic_id, addr, allocate_size_with_md) + except MemoryError: + # free 2x size for enough space if wrapping happened + ring_free(allocate_size_with_md * 2) + + # retry allocating + addr, monotonic_id = ring.allocate_buf(allocate_size) + mark_allocated_with_assertion(monotonic_id, addr, allocate_size_with_md) + + # 1. allocation & free cycles + for _ in range(33): + # will consume 2 + 8 = 10 bytes per allocation + ring_allocate(2) + + # 2. free all allocations + ring_free() + + # 3. try allocate the largest possible buffer + ring_allocate(buffer_size - ring.MD_SIZE) + def main(): """Main function demonstrating usage and running tests""" diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py index 080bc03e3913..2ec33afb8783 100644 --- a/vllm/distributed/device_communicators/shm_object_storage.py +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -127,9 +127,7 @@ def __init__( if create: # we are creating a buffer - self.metadata = { - self.monotonic_id_end: self.data_buffer_end - } # monotonic_id -> start address + self.metadata: dict[int, int] = {} # monotonic_id -> start address self.shared_memory = shared_memory.SharedMemory( create=True, size=self.data_buffer_size, name=name ) @@ -288,7 +286,15 @@ def free_buf( self.monotonic_id_start = ( self.monotonic_id_start + 1 ) % self.ID_MAX - self.data_buffer_start = address + if self.monotonic_id_start in self.metadata: + # pointing to the start addr of next allocation + self.data_buffer_start += ( + self.metadata[self.monotonic_id_start] + - self.data_buffer_start + ) % self.data_buffer_size + else: + # no remaining allocation, reset to zero + self.data_buffer_start = self.data_buffer_end = 0 freed_bytes += metadata[1] else: # there are still readers, we cannot free the buffer