diff --git a/connection.go b/connection.go index ba65e67b..39b8e4fa 100644 --- a/connection.go +++ b/connection.go @@ -16,6 +16,7 @@ package fuse import ( "context" + "errors" "fmt" "io" "log" @@ -80,6 +81,8 @@ type Connection struct { // Freelists, serviced by freelists.go. inMessages freelist.Freelist // GUARDED_BY(mu) outMessages freelist.Freelist // GUARDED_BY(mu) + + inMessageSize int } // State that is maintained for each in-flight op. This is stuffed into the @@ -121,6 +124,12 @@ func newConnection( cancelFuncs: make(map[uint64]func()), } + maxPayload := max(buffer.MaxReadSize, buffer.MaxWriteSize) + if cfg.MaxMessageSize > 0 { + maxPayload = max(maxPayload, int(cfg.MaxMessageSize)) + } + c.inMessageSize = maxPayload + buffer.GetPageSize() + // Initialize. if err := c.Init(); err != nil { c.close() @@ -172,7 +181,9 @@ func (c *Connection) Init() error { // Respond to the init op. initOp.Library = c.protocol initOp.MaxReadahead = maxReadahead - initOp.MaxWrite = buffer.MaxWriteSize + + maxPayload := c.inMessageSize - buffer.GetPageSize() + initOp.MaxWrite = uint32(maxPayload) initOp.Flags = 0 @@ -190,7 +201,6 @@ func (c *Connection) Init() error { // payload. It applies to both requests and replies, and does not include // the extra 1 page for the FUSE header and the "args" struct. We set it to // the max of our message in/out payload sizes. - maxPayload := max(buffer.MaxReadSize, buffer.MaxWriteSize) initOp.MaxPages = uint16(maxPayload / buffer.GetPageSize()) // Enable writeback caching if the user hasn't asked us not to. @@ -376,6 +386,7 @@ func (c *Connection) handleInterrupt(fuseID uint64) { func (c *Connection) readMessage() (*buffer.InMessage, error) { // Allocate a message. m := c.getInMessage() + m.AllocBlocks(c.inMessageSize) // Loop past transient errors. for { @@ -389,15 +400,11 @@ func (c *Connection) readMessage() (*buffer.InMessage, error) { // * EINTR means we should try again. (This seems to happen often on // OS X, cf. http://golang.org/issue/11180) // - if pe, ok := err.(*os.PathError); ok { - switch pe.Err { - case syscall.ENODEV: - err = io.EOF - - case syscall.EINTR: - err = nil - continue - } + if errors.Is(err, syscall.ENODEV) { + err = io.EOF + } else if errors.Is(err, syscall.EINTR) { + err = nil + continue } if err != nil { diff --git a/conversions.go b/conversions.go index a7c5a4fd..dd2f990e 100644 --- a/conversions.go +++ b/conversions.go @@ -145,13 +145,14 @@ func convertInMessage( } entries := make([]fuseops.BatchForgetEntry, 0, in.Count) - for i := uint32(0); i < in.Count; i++ { - type entry fusekernel.BatchForgetEntryIn - ein := (*entry)(inMsg.Consume(unsafe.Sizeof(entry{}))) - if ein == nil { - return nil, errors.New("Corrupt OpBatchForget") - } + entrySize := unsafe.Sizeof(fusekernel.BatchForgetEntryIn{}) + buf := inMsg.ConsumeBytes(uintptr(in.Count) * entrySize) + if len(buf) < int(in.Count*uint32(entrySize)) { + return nil, errors.New("Corrupt OpBatchForget") + } + for i := uint32(0); i < in.Count; i++ { + ein := (*fusekernel.BatchForgetEntryIn)(unsafe.Pointer(&buf[uintptr(i)*entrySize])) entries = append(entries, fuseops.BatchForgetEntry{ Inode: fuseops.InodeID(ein.Inode), N: ein.Nlookup, @@ -396,7 +397,11 @@ func convertInMessage( }, } // Use part of the incoming message storage as the read buffer. - to.Dst = inMsg.GetFree(int(in.Size)) + if config.EnableVectoredReads && int(in.Size) > buffer.MiBPlusPageSize { + to.DstBufs = inMsg.GetFreeVector(int(in.Size)) + } else { + to.Dst = inMsg.GetFree(int(in.Size)) + } o = to case fusekernel.OpReaddir: @@ -498,16 +503,31 @@ func convertInMessage( return nil, errors.New("Corrupt OpWrite") } - buf := inMsg.ConsumeBytes(inMsg.Len()) - if len(buf) < int(in.Size) { - return nil, errors.New("Corrupt OpWrite") + var buf []byte + var dataBlocks [][]byte + + if config.EnableVectoredWrites && inMsg.Len() > uintptr(buffer.MiBPlusPageSize) { + dataBlocks = inMsg.ConsumeVector(inMsg.Len()) + var totalLen int + for _, b := range dataBlocks { + totalLen += len(b) + } + if totalLen < int(in.Size) { + return nil, errors.New("Corrupt OpWrite") + } + } else { + buf = inMsg.ConsumeBytes(inMsg.Len()) + if len(buf) < int(in.Size) { + return nil, errors.New("Corrupt OpWrite") + } } o = &fuseops.WriteFileOp{ - Inode: fuseops.InodeID(inMsg.Header().Nodeid), - Handle: fuseops.HandleID(in.Fh), - Data: buf, - Offset: int64(in.Offset), + Inode: fuseops.InodeID(inMsg.Header().Nodeid), + Handle: fuseops.HandleID(in.Fh), + Data: buf, + DataBlocks: dataBlocks, + Offset: int64(in.Offset), OpContext: fuseops.OpContext{ FuseID: inMsg.Header().Unique, Pid: inMsg.Header().Pid, @@ -933,6 +953,8 @@ func (c *Connection) kernelResponseForOp( case *fuseops.ReadFileOp: if o.Data != nil { m.Append(o.Data...) + } else if o.DstBufs != nil { + m.Append(o.DstBufs...) } else { m.Append(o.Dst) } @@ -940,7 +962,7 @@ func (c *Connection) kernelResponseForOp( case *fuseops.WriteFileOp: out := (*fusekernel.WriteOut)(m.Grow(int(unsafe.Sizeof(fusekernel.WriteOut{})))) - out.Size = uint32(len(o.Data)) + out.Size = uint32(o.TotalSize()) case *fuseops.SyncFileOp: // Empty response diff --git a/debug.go b/debug.go index f486a84d..5f6ec30e 100644 --- a/debug.go +++ b/debug.go @@ -100,7 +100,7 @@ func describeRequest(op interface{}) (s string) { case *fuseops.WriteFileOp: addComponent("handle %d", typed.Handle) addComponent("offset %d", typed.Offset) - addComponent("%d bytes", len(typed.Data)) + addComponent("%d bytes", typed.TotalSize()) case *fuseops.RemoveXattrOp: addComponent("name %s", typed.Name) diff --git a/freelists.go b/freelists.go index 8489e1f5..e1ead2e4 100644 --- a/freelists.go +++ b/freelists.go @@ -31,7 +31,7 @@ func (c *Connection) getInMessage() *buffer.InMessage { c.mu.Unlock() if x == nil { - x = buffer.NewInMessage() + x = buffer.NewInMessage(c.inMessageSize) } return x @@ -39,6 +39,7 @@ func (c *Connection) getInMessage() *buffer.InMessage { // LOCKS_EXCLUDED(c.mu) func (c *Connection) putInMessage(x *buffer.InMessage) { + x.FreeBlocks() c.mu.Lock() c.inMessages.Put(unsafe.Pointer(x)) c.mu.Unlock() diff --git a/fuseops/ops.go b/fuseops/ops.go index 83e1e775..cebb9ccd 100644 --- a/fuseops/ops.go +++ b/fuseops/ops.go @@ -711,11 +711,24 @@ type ReadFileOp struct { // The destination buffer, whose length gives the size of the read. // The file system can write to this buffer for non-vectored reads. + // Note: Dst will be nil if EnableVectoredReads is enabled and the read size + // is larger than the pre-allocated buffer size. In this case, the file system + // must use DstBufs (if populated) or allocate its own buffers and return them + // via the Data field. Dst []byte - // Set by the file system: + // The destination buffers for vectored reads, whose total length gives the + // size of the read. + // The file system can write directly to these buffers for vectored reads. + // Note: DstBufs will be nil if EnableVectoredReads is disabled, or if the + // read size is smaller than the pre-allocated incoming message block (in + // which case Dst will be populated instead). + DstBufs [][]byte + + // Set by the file system (optional): // A list of slices of data to send back to the client. - // If this field is populated, the contents of `Dst` will be ignored. + // If this field is populated, the contents of `Dst` and `DstBufs` will be ignored. + // If both `Dst` and `DstBufs` are nil, this field MUST be populated to return any data. Data [][]byte // Set by the file system: the number of bytes read. @@ -796,6 +809,7 @@ type WriteFileOp struct { // to be because it uses file mmapping machinery // (https://tinyurl.com/avxy3dvm) to write a page at a time. Data []byte + DataBlocks [][]byte OpContext OpContext // If set, this function will be invoked after the operation response has been @@ -804,6 +818,18 @@ type WriteFileOp struct { Callback func() } +// TotalSize returns the total size of the write payload. +func (o *WriteFileOp) TotalSize() int { + dataLen := len(o.Data) + if dataLen == 0 && len(o.DataBlocks) > 0 { + for _, b := range o.DataBlocks { + dataLen += len(b) + } + } + return dataLen +} + + // Synchronize the current contents of an open file to storage. // // vfs.txt documents this as being called for by the fsync(2) system call diff --git a/internal/buffer/in_message.go b/internal/buffer/in_message.go index a9728833..aac2f453 100644 --- a/internal/buffer/in_message.go +++ b/internal/buffer/in_message.go @@ -22,19 +22,26 @@ import ( "unsafe" "github.com/jacobsa/fuse/internal/fusekernel" + "golang.org/x/sys/unix" ) // All requests read from the kernel, without data, are shorter than // this. var pageSize int -// We size the buffer to have enough room for a fuse request plus data -// associated with a write request. -var bufSize int +// Constants for buffer/message sizes. +const ( + // MiB is 1 MiB in bytes. + MiB = 1024 * 1024 +) + +// MiBPlusPageSize is 1 MiB + hardware page size. Since pageSize is determined +// at runtime, this is a variable rather than a compile-time constant. +var MiBPlusPageSize int func init() { - pageSize = syscall.Getpagesize() - bufSize = pageSize + MaxWriteSize + pageSize = unix.Getpagesize() + MiBPlusPageSize = MiB + pageSize } // Return the hardware page size. Note that this is not always 4KiB! Notably @@ -43,52 +50,213 @@ func GetPageSize() int { return pageSize } +type blockPool struct { + mu sync.Mutex + list [][]byte + limit int + alloc func() []byte + overflow sync.Pool +} + +func newBlockPool(limit int, alloc func() []byte) *blockPool { + p := &blockPool{ + limit: limit, + alloc: alloc, + } + p.overflow.New = func() interface{} { + return p.alloc() + } + return p +} + +func (p *blockPool) Get() []byte { + p.mu.Lock() + l := len(p.list) + if l > 0 { + buf := p.list[l-1] + p.list = p.list[:l-1] + p.mu.Unlock() + return buf + } + p.mu.Unlock() + return p.overflow.Get().([]byte) +} + +func (p *blockPool) Put(buf []byte) { + buf = buf[:cap(buf)] + p.mu.Lock() + if len(p.list) < p.limit { + p.list = append(p.list, buf) + p.mu.Unlock() + return + } + p.mu.Unlock() + p.overflow.Put(buf) +} + +var BlockPool1M = newBlockPool(48, func() []byte { + return make([]byte, MiB) +}) + +var BlockPool1MPlusPage = newBlockPool(8, func() []byte { + return make([]byte, MiBPlusPageSize) +}) + // An incoming message from the kernel, including leading fusekernel.InHeader // struct. Provides storage for messages and convenient access to their // contents. type InMessage struct { - remaining []byte - storage []byte - size int + blocks [][]byte + size int + consumed int + iovecs []unix.Iovec + borrowedBlocks [][]byte +} + +// NewInMessage creates a new InMessage. +func NewInMessage(size int) *InMessage { + return &InMessage{} +} + +func (m *InMessage) AllocBlocks(totalSize int) { + m.FreeBlocks() + + // Always allocate a 1MB+pageSize block first for header & metadata & payload + m.blocks = append(m.blocks, BlockPool1MPlusPage.Get()) + + if totalSize > len(m.blocks[0]) { + remaining := totalSize - len(m.blocks[0]) + num1MBlocks := (remaining + MiB - 1) / MiB + for i := 0; i < num1MBlocks; i++ { + m.blocks = append(m.blocks, BlockPool1M.Get()) + } + } + m.consumed = 0 + m.size = 0 +} + +func (m *InMessage) FreeBlocks() { + if len(m.blocks) > 0 { + BlockPool1MPlusPage.Put(m.blocks[0]) + for i := 1; i < len(m.blocks); i++ { + BlockPool1M.Put(m.blocks[i]) + } + } + m.blocks = nil + for _, b := range m.borrowedBlocks { + BlockPool1M.Put(b) + } + m.borrowedBlocks = nil + m.size = 0 + m.consumed = 0 } -// NewInMessage creates a new InMessage with its storage initialized. -func NewInMessage() *InMessage { - return &InMessage{ - storage: make([]byte, bufSize), +func (m *InMessage) ShrinkToFit(n int) { + m.size = n + + var bytesNeeded = n + var usedBlocks = 0 + for _, block := range m.blocks { + usedBlocks++ + if bytesNeeded <= len(block) { + break + } + bytesNeeded -= len(block) + } + + for i := usedBlocks; i < len(m.blocks); i++ { + if i == 0 { + BlockPool1MPlusPage.Put(m.blocks[i]) + } else { + BlockPool1M.Put(m.blocks[i]) + } } + m.blocks = m.blocks[:usedBlocks] } var readLock sync.Mutex +var fuseTContiguousPool sync.Pool -func (m *InMessage) ReadSingle(r io.Reader) (int, error) { + +func (m *InMessage) ReadSingleContiguous(r io.Reader, storage []byte) (int, error) { readLock.Lock() defer readLock.Unlock() // read request length - if _, err := io.ReadFull(r, m.storage[0:4]); err != nil { + if _, err := io.ReadFull(r, storage[0:4]); err != nil { return 0, err } - l := m.Header().Len + header := (*fusekernel.InHeader)(unsafe.Pointer(&storage[0])) + l := header.Len // read remaining request - if n, err := io.ReadFull(r, m.storage[4:l]); err != nil { + if n, err := io.ReadFull(r, storage[4:l]); err != nil { return n, err } return int(l), nil } -// Initialize with the data read by a single call to r.Read. The first call to + +// Initialize with the data read by a single call to r.Read or readv. The first call to // Consume will consume the bytes directly after the fusekernel.InHeader // struct. func (m *InMessage) Init(r io.Reader) error { - var n int var err error if fusekernel.IsPlatformFuseT { - n, err = m.ReadSingle(r) + if len(m.blocks) == 1 { + n, err = m.ReadSingleContiguous(r, m.blocks[0]) + } else { + var cap int + for _, b := range m.blocks { + cap += len(b) + } + var storage []byte + if v := fuseTContiguousPool.Get(); v != nil { + buf := v.([]byte) + if len(buf) >= cap { + storage = buf[:cap] + } + } + if storage == nil { + storage = make([]byte, cap) + } + defer func() { + fuseTContiguousPool.Put(storage) + }() + + n, err = m.ReadSingleContiguous(r, storage) + if err == nil { + var copied int + for _, b := range m.blocks { + if copied >= n { + break + } + toCopy := len(b) + if copied+toCopy > n { + toCopy = n - copied + } + copy(b, storage[copied:copied+toCopy]) + copied += toCopy + } + } + } } else { - n, err = r.Read(m.storage[:]) + if sc, ok := r.(syscall.Conn); ok { + var rawConn syscall.RawConn + rawConn, err = sc.SyscallConn() + if err == nil { + var readvErr error + err = rawConn.Control(func(fd uintptr) { + n, readvErr = unix.Readv(int(fd), m.blocks) + }) + if err == nil { + err = readvErr + } + } + } else { + return fmt.Errorf("Reader does not support SyscallConn") + } } if err != nil { @@ -101,8 +269,8 @@ func (m *InMessage) Init(r io.Reader) error { return fmt.Errorf("Unexpectedly read only %d bytes.", n) } - m.size = n - m.remaining = m.storage[headerSize:n] + m.ShrinkToFit(n) + m.consumed = int(headerSize) // Check the header's length. if int(m.Header().Len) != n { @@ -117,12 +285,27 @@ func (m *InMessage) Init(r io.Reader) error { // Return a reference to the header read in the most recent call to Init. func (m *InMessage) Header() *fusekernel.InHeader { - return (*fusekernel.InHeader)(unsafe.Pointer(&m.storage[0])) + return (*fusekernel.InHeader)(unsafe.Pointer(&m.blocks[0][0])) } // Return the number of bytes left to consume. func (m *InMessage) Len() uintptr { - return uintptr(len(m.remaining)) + return uintptr(m.size - m.consumed) +} + +// getBlockAndOffset returns the block index and the local offset within that block +// corresponding to the currently consumed logical bytes. +func (m *InMessage) getBlockAndOffset() (blockIdx int, localOffset int) { + localOffset = m.consumed + for blockIdx < len(m.blocks) { + bLen := len(m.blocks[blockIdx]) + if localOffset < bLen { + break + } + localOffset -= bLen + blockIdx++ + } + return blockIdx, localOffset } // Consume the next n bytes from the message, returning a nil pointer if there @@ -132,8 +315,15 @@ func (m *InMessage) Consume(n uintptr) unsafe.Pointer { return nil } - p := unsafe.Pointer(&m.remaining[0]) - m.remaining = m.remaining[n:] + blockIdx, offset := m.getBlockAndOffset() + + if offset+int(n) > len(m.blocks[blockIdx]) { + m.consumed += int(n) + return nil + } + + p := unsafe.Pointer(&m.blocks[blockIdx][offset]) + m.consumed += int(n) return p } @@ -145,16 +335,110 @@ func (m *InMessage) ConsumeBytes(n uintptr) []byte { return nil } - b := m.remaining[:n] - m.remaining = m.remaining[n:] + blockIdx, offset := m.getBlockAndOffset() + + if offset+int(n) <= len(m.blocks[blockIdx]) { + b := m.blocks[blockIdx][offset : offset+int(n)] + m.consumed += int(n) + return b + } + + // In production, any spanning allocation is larger than 1MB (since block 0 + // is 1MB + pageSize and fits all normal headers/payloads). Thus we always + // allocate directly from the heap. + res := make([]byte, n) + var bytesCopied = 0 + var remainingToCopy = int(n) + + for remainingToCopy > 0 && blockIdx < len(m.blocks) { + bLen := len(m.blocks[blockIdx]) + availableInBlock := bLen - offset + copyLen := availableInBlock + if copyLen > remainingToCopy { + copyLen = remainingToCopy + } + + copy(res[bytesCopied:bytesCopied+copyLen], m.blocks[blockIdx][offset:offset+copyLen]) + + bytesCopied += copyLen + remainingToCopy -= copyLen + + offset = 0 + blockIdx++ + } + + m.consumed += int(n) + return res +} + +// Equivalent to ConsumeBytes, except returns a slice of slices referencing the +// underlying blocks, without allocations/copies. +func (m *InMessage) ConsumeVector(n uintptr) [][]byte { + if n > m.Len() { + return nil + } + + blockIdx, offset := m.getBlockAndOffset() + + var res [][]byte + var remainingToCopy = int(n) + + for remainingToCopy > 0 && blockIdx < len(m.blocks) { + bLen := len(m.blocks[blockIdx]) + availableInBlock := bLen - offset + copyLen := availableInBlock + if copyLen > remainingToCopy { + copyLen = remainingToCopy + } + + res = append(res, m.blocks[blockIdx][offset:offset+copyLen]) + + remainingToCopy -= copyLen + offset = 0 + blockIdx++ + } - return b + m.consumed += int(n) + return res } -// Get the next n bytes after the message to use them as a temporary buffer +// Get a temporary buffer of n bytes. If it fits in the first block, we slice it +// directly. If it does not fit, we allocate a separate buffer. func (m *InMessage) GetFree(n int) []byte { - if n <= 0 || n > len(m.storage)-m.size { + if n <= 0 { return nil } - return m.storage[m.size : m.size+n] + if len(m.blocks) > 0 && m.size+n <= len(m.blocks[0]) { + return m.blocks[0][m.size : m.size+n] + } + // Since n doesn't fit in block 0, and block 0 has size 1MB + pageSize, + // n is necessarily larger than 1MB (assuming typical small offset like + // sizeof(ReadIn)). Thus we always allocate directly on the heap. + return make([]byte, n) +} + +// GetFreeVector returns a temporary set of buffers summing to n bytes. If it fits +// in the first block, we return a slice of the first block in a single-element slice. +// If it does not fit, we allocate 1MB buffers from BlockPool1M. +func (m *InMessage) GetFreeVector(n int) [][]byte { + if n <= 0 { + return nil + } + if len(m.blocks) > 0 && m.size+n <= len(m.blocks[0]) { + return [][]byte{m.blocks[0][m.size : m.size+n]} + } + + var res [][]byte + remaining := n + for remaining > 0 { + block := BlockPool1M.Get() + m.borrowedBlocks = append(m.borrowedBlocks, block) + allocSize := MiB + if remaining < allocSize { + allocSize = remaining + } + res = append(res, block[:allocSize]) + remaining -= allocSize + } + return res } diff --git a/internal/buffer/in_message_test.go b/internal/buffer/in_message_test.go new file mode 100644 index 00000000..5db82e32 --- /dev/null +++ b/internal/buffer/in_message_test.go @@ -0,0 +1,766 @@ +// Copyright 2026 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "os" + "syscall" + "testing" + "unsafe" + + "github.com/jacobsa/fuse/internal/fusekernel" +) + + +func TestInMessageAllocAndFree(t *testing.T) { + m := NewInMessage(0) + m.AllocBlocks(17 * MiB) // 17 MiB total size + + // 1 (1MB + pageSize) block + 16 1 MiB blocks = 17 blocks + if len(m.blocks) != 17 { + t.Errorf("expected 17 blocks, got %d", len(m.blocks)) + } + + // Block 0: 1 MiB + pageSize + expectedBlock0Size := MiBPlusPageSize + if len(m.blocks[0]) != expectedBlock0Size { + t.Errorf("expected block 0 to be %d, got %d", expectedBlock0Size, len(m.blocks[0])) + } + + // Blocks 1-16: 1 MiB + for i := 1; i < 17; i++ { + if len(m.blocks[i]) != MiB { + t.Errorf("expected block %d to be 1 MiB, got %d", i, len(m.blocks[i])) + } + } + + // Shrink to fit for small message (fits within the first block) + m.ShrinkToFit(100) + if len(m.blocks) != 1 { + t.Errorf("expected 1 block after shrinking to 100 bytes, got %d", len(m.blocks)) + } + + m.FreeBlocks() + if len(m.blocks) != 0 { + t.Errorf("expected 0 blocks after FreeBlocks, got %d", len(m.blocks)) + } +} + +func TestInMessageConsumeAndBytes(t *testing.T) { + m := NewInMessage(0) + // Manually allocate small blocks to verify spanning without requiring large pipe transfers. + m.blocks = [][]byte{ + make([]byte, 50), + make([]byte, 50), + } + + msgLen := 100 + + // Build a dummy input stream + data := make([]byte, msgLen) + // Write InHeader + header := (*fusekernel.InHeader)(unsafe.Pointer(&data[0])) + header.Len = uint32(msgLen) + header.Opcode = 123 + header.Unique = 456 + + // Write some bytes spanning across the boundary (offset 50) + data[49] = 'Y' + data[50] = 'Z' + + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Pipe failed: %v", err) + } + defer r.Close() + + _, err = w.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + w.Close() + + err = m.Init(r) + if err != nil { + t.Fatalf("Init failed: %v", err) + } + + if m.Header().Unique != 456 { + t.Errorf("expected Unique = 456, got %d", m.Header().Unique) + } + + // Consume a dummy struct from Block 0 (offset 40 to 48, size 8) + p := m.Consume(8) + if p == nil { + t.Fatalf("Consume returned nil") + } + + // Consume remaining bytes of Block 0 up to block boundary (so consumed is 49) + skip := 49 - 40 - 8 + m.Consume(uintptr(skip)) + + // Now we are at the end of block 0 (offset 49). The next bytes are 'Y' and 'Z'. + // This spans across the boundary. + yz := m.ConsumeBytes(2) + if string(yz) != "YZ" { + t.Errorf("expected 'YZ', got %q", string(yz)) + } + + // Clear m.blocks to nil so FreeBlocks doesn't put our small manual slices into BlockPool1MPlusPage/BlockPool1M. + m.blocks = nil + m.FreeBlocks() +} + +func TestInMessageInitFuseT(t *testing.T) { + fusekernel.IsPlatformFuseT = true + defer func() { + fusekernel.IsPlatformFuseT = false + }() + + runTest := func(t *testing.T, blockSizes []int) { + m := NewInMessage(0) + for _, sz := range blockSizes { + m.blocks = append(m.blocks, make([]byte, sz)) + } + + var totalBlockCap int + for _, b := range m.blocks { + totalBlockCap += len(b) + } + + // Prepare dummy FUSE header + message + data := make([]byte, totalBlockCap) + header := (*fusekernel.InHeader)(unsafe.Pointer(&data[0])) + header.Len = uint32(totalBlockCap) + header.Opcode = 999 + header.Unique = 888 + + // Write some test markers + data[40] = 0xAA + data[totalBlockCap-1] = 0xBB + + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Pipe failed: %v", err) + } + defer r.Close() + + go func() { + _, _ = w.Write(data) + w.Close() + }() + + err = m.Init(r) + if err != nil { + t.Fatalf("Init failed: %v", err) + } + + if m.Header().Opcode != 999 { + t.Errorf("expected Opcode = 999, got %d", m.Header().Opcode) + } + if m.Header().Unique != 888 { + t.Errorf("expected Unique = 888, got %d", m.Header().Unique) + } + + // Verify first block byte and last block byte + if m.blocks[0][40] != 0xAA { + t.Errorf("expected blocks[0][40] = 0xAA, got %x", m.blocks[0][40]) + } + lastBlock := m.blocks[len(m.blocks)-1] + if lastBlock[len(lastBlock)-1] != 0xBB { + t.Errorf("expected last byte of last block = 0xBB, got %x", lastBlock[len(lastBlock)-1]) + } + } + + t.Run("single_block", func(t *testing.T) { + runTest(t, []int{110}) + }) + + t.Run("multiple_blocks", func(t *testing.T) { + runTest(t, []int{50, 30, 30}) + }) +} + +func TestInMessageGetFree(t *testing.T) { + m := NewInMessage(0) + + // Case 1: len(m.blocks) == 0 + // should return valid buffer via fallback allocation + if buf := m.GetFree(10); len(buf) != 10 { + t.Errorf("expected buffer of size 10 when no blocks allocated, got %v", buf) + } + m.FreeBlocks() + + firstBlockSize := MiBPlusPageSize + m.AllocBlocks(firstBlockSize) + m.size = 100 // Set message size to 100 bytes + + // Case 2: n <= 0 -> should return nil + if buf := m.GetFree(0); buf != nil { + t.Errorf("expected nil for n=0, got %v", buf) + } + if buf := m.GetFree(-5); buf != nil { + t.Errorf("expected nil for n=-5, got %v", buf) + } + + // Case 3: n is larger than remaining space in blocks[0] + // remaining is: firstBlockSize - 100 + tooLarge := firstBlockSize - 100 + 1 + + // should return fallback buffer + bufTooLarge := m.GetFree(tooLarge) + if len(bufTooLarge) != tooLarge { + t.Errorf("expected buffer of size %d for too large request, got %d", tooLarge, len(bufTooLarge)) + } + if &bufTooLarge[0] == &m.blocks[0][100] { + t.Errorf("expected fallback buffer to not be part of block 0") + } + m.FreeBlocks() + + // Re-allocate blocks + m.AllocBlocks(firstBlockSize) + m.size = 100 + + // Case 4: normal allocation within remaining space + buf1 := m.GetFree(500) + if len(buf1) != 500 { + t.Errorf("expected buffer of size 500, got %d", len(buf1)) + } + expectedStart := &m.blocks[0][100] + if &buf1[0] != expectedStart { + t.Errorf("expected buffer to start at index 100 of block 0") + } + + // Slicing again (m.size is still 100, we don't advance m.size on GetFree) + buf2 := m.GetFree(500) + if len(buf2) != 500 { + t.Errorf("expected buffer of size 500, got %d", len(buf2)) + } + if &buf2[0] != expectedStart { + t.Errorf("expected buffer to start at index 100 of block 0") + } + + m.FreeBlocks() +} + +func TestInMessageGetFreeVector(t *testing.T) { + m := NewInMessage(0) + + // Case 1: len(m.blocks) == 0 -> should return allocated blocks from pool + bufs := m.GetFreeVector(2*MiB + 100) // 2MB + 100 bytes + if len(bufs) != 3 { + t.Errorf("expected 3 buffers, got %d", len(bufs)) + } else { + if len(bufs[0]) != MiB || len(bufs[1]) != MiB || len(bufs[2]) != 100 { + t.Errorf("unexpected buffer sizes: %d, %d, %d", len(bufs[0]), len(bufs[1]), len(bufs[2])) + } + } + // Verify that freeing returning buffers to the pool works + m.FreeBlocks() + + firstBlockSize := MiBPlusPageSize + m.AllocBlocks(firstBlockSize) + m.size = 100 // Set message size to 100 bytes + + // Case 2: n <= 0 -> should return nil + if bufs := m.GetFreeVector(0); bufs != nil { + t.Errorf("expected nil for n=0, got %v", bufs) + } + if bufs := m.GetFreeVector(-5); bufs != nil { + t.Errorf("expected nil for n=-5, got %v", bufs) + } + + // Case 3: n fits in blocks[0] + fitSize := 500 + bufsFit := m.GetFreeVector(fitSize) + if len(bufsFit) != 1 { + t.Errorf("expected 1 buffer, got %d", len(bufsFit)) + } else if len(bufsFit[0]) != fitSize { + t.Errorf("expected buffer of size %d, got %d", fitSize, len(bufsFit[0])) + } else if &bufsFit[0][0] != &m.blocks[0][100] { + t.Errorf("expected buffer to start at index 100 of block 0") + } + + // Case 4: n is larger than remaining space in blocks[0] + // remaining is: firstBlockSize - 100 + tooLarge := firstBlockSize - 100 + 1 + bufsTooLarge := m.GetFreeVector(tooLarge) + if len(bufsTooLarge) != 2 { + t.Errorf("expected 2 buffers, got %d", len(bufsTooLarge)) + } else { + if len(bufsTooLarge[0]) != MiB || len(bufsTooLarge[1]) != tooLarge-MiB { + t.Errorf("unexpected sizes: %d, %d", len(bufsTooLarge[0]), len(bufsTooLarge[1])) + } + } + + m.FreeBlocks() +} + +var benchmarkSink []byte + +func BenchmarkConsumeBytesSpanning(b *testing.B) { + m := NewInMessage(0) + firstBlockSize := MiBPlusPageSize + totalSize := firstBlockSize + 2000 + m.AllocBlocks(totalSize) + + m.blocks[0][firstBlockSize-1] = 'Y' + m.blocks[1][0] = 'Z' + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + m.consumed = firstBlockSize - 1 + m.size = totalSize + + res := m.ConsumeBytes(2) + if len(res) != 2 || res[0] != 'Y' || res[1] != 'Z' { + b.Fatalf("unexpected result: %v", res) + } + benchmarkSink = res + m.FreeBlocks() + m.AllocBlocks(totalSize) + m.blocks[0][firstBlockSize-1] = 'Y' + m.blocks[1][0] = 'Z' + } + m.FreeBlocks() +} + +func BenchmarkGetFree(b *testing.B) { + m := NewInMessage(0) + m.AllocBlocks(20000) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + res := m.GetFree(10000) + if len(res) != 10000 { + b.Fatalf("expected 10000, got %d", len(res)) + } + benchmarkSink = res + } + m.FreeBlocks() +} + +func BenchmarkAllocBlocks(b *testing.B) { + m := NewInMessage(0) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m.AllocBlocks(4096 + 2000) + m.FreeBlocks() + } +} + +type fakeRawConn struct { + fd uintptr +} + +func (c fakeRawConn) Control(f func(fd uintptr)) error { + f(c.fd) + return nil +} + +func (c fakeRawConn) Read(f func(fd uintptr) bool) error { + return syscall.ENOTSUP +} + +func (c fakeRawConn) Write(f func(fd uintptr) bool) error { + return syscall.ENOTSUP +} + +type fakeFdReader struct { + fd uintptr +} + +func (r fakeFdReader) SyscallConn() (syscall.RawConn, error) { + return fakeRawConn{fd: r.fd}, nil +} + +func (r fakeFdReader) Read(p []byte) (int, error) { + return 0, nil +} + +func BenchmarkInMessageInitWithReadv(b *testing.B) { + m := NewInMessage(0) + m.AllocBlocks(5 * MiB) + r := fakeFdReader{fd: ^uintptr(0)} // -1 + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = m.Init(r) + } + m.FreeBlocks() +} + +type fakeFuseTReader struct { + data []byte +} + +func (r *fakeFuseTReader) Read(p []byte) (int, error) { + return copy(p, r.data), nil +} + +func BenchmarkInMessageInitFuseT(b *testing.B) { + fusekernel.IsPlatformFuseT = true + defer func() { + fusekernel.IsPlatformFuseT = false + }() + + m := NewInMessage(0) + totalSize := MiBPlusPageSize + m.AllocBlocks(totalSize) + defer m.FreeBlocks() + + data := make([]byte, totalSize) + header := (*fusekernel.InHeader)(unsafe.Pointer(&data[0])) + header.Len = uint32(totalSize) + + r := &fakeFuseTReader{data: data} + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + err := m.Init(r) + if err != nil { + b.Fatalf("Init failed: %v", err) + } + } +} + +func BenchmarkInMessageInitFuseTMultiBlock(b *testing.B) { + fusekernel.IsPlatformFuseT = true + defer func() { + fusekernel.IsPlatformFuseT = false + }() + + m := NewInMessage(0) + firstBlockSize := MiBPlusPageSize + totalSize := firstBlockSize + 2*MiB + m.AllocBlocks(totalSize) + defer m.FreeBlocks() + + data := make([]byte, totalSize) + header := (*fusekernel.InHeader)(unsafe.Pointer(&data[0])) + header.Len = uint32(totalSize) + + r := &fakeFuseTReader{data: data} + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + err := m.Init(r) + if err != nil { + b.Fatalf("Init failed: %v", err) + } + } +} + +var ( + benchmarkConsumeSink unsafe.Pointer + benchmarkConsumeVectorSink [][]byte + benchmarkGetFreeVectorSink [][]byte +) + +func BenchmarkInMessageConsume(b *testing.B) { + m := NewInMessage(0) + m.AllocBlocks(4096) + defer m.FreeBlocks() + m.size = 4096 + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m.consumed = 0 + benchmarkConsumeSink = m.Consume(16) + } +} + +func BenchmarkInMessageConsumeVector_SingleBlock(b *testing.B) { + m := NewInMessage(0) + m.AllocBlocks(4096) + defer m.FreeBlocks() + m.size = 4096 + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m.consumed = 0 + benchmarkConsumeVectorSink = m.ConsumeVector(128) + } +} + +func BenchmarkInMessageConsumeVector_Spanning(b *testing.B) { + m := NewInMessage(0) + firstBlockSize := MiBPlusPageSize + totalSize := firstBlockSize + 2*MiB + m.AllocBlocks(totalSize) + defer m.FreeBlocks() + m.size = totalSize + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m.consumed = firstBlockSize - 128 + benchmarkConsumeVectorSink = m.ConsumeVector(256) + } +} + +func BenchmarkInMessageGetFreeVector_SingleBlock(b *testing.B) { + m := NewInMessage(0) + m.AllocBlocks(4096) + defer m.FreeBlocks() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m.size = 0 + benchmarkGetFreeVectorSink = m.GetFreeVector(128) + } +} + +func BenchmarkInMessageGetFreeVector_MultiBlock(b *testing.B) { + m := NewInMessage(0) + m.AllocBlocks(4096) + defer m.FreeBlocks() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m.size = 0 + benchmarkGetFreeVectorSink = m.GetFreeVector(2 * MiB) + + // Return borrowed blocks to the pool and clear m.borrowedBlocks + for _, block := range m.borrowedBlocks { + BlockPool1M.Put(block) + } + m.borrowedBlocks = m.borrowedBlocks[:0] + } +} + +func TestInMessageConsumeVector(t *testing.T) { + m := NewInMessage(0) + // Manually set up blocks for precise spanning test + m.blocks = [][]byte{ + make([]byte, 50), + make([]byte, 50), + make([]byte, 50), + } + m.size = 150 + m.consumed = 0 + + // Populate data + for i := 0; i < 150; i++ { + m.blocks[i/50][i%50] = byte(i) + } + + // 1. Request more than available -> should return nil + if res := m.ConsumeVector(151); res != nil { + t.Errorf("expected nil when consuming more than size, got %v", res) + } + + // 2. Consume within the first block (offset 0 to 30) + v1 := m.ConsumeVector(30) + if len(v1) != 1 || len(v1[0]) != 30 { + t.Fatalf("expected 1 slice of length 30, got %v", v1) + } + if v1[0][0] != 0 || v1[0][29] != 29 { + t.Errorf("unexpected content in v1: %v", v1[0]) + } + if m.consumed != 30 { + t.Errorf("expected consumed to be 30, got %d", m.consumed) + } + + // 3. Consume spanning block 0 and block 1 (offset 30 to 70, spanning boundary at 50) + // Remaining in block 0: 20 bytes (30 to 49) + // Needed from block 1: 20 bytes (50 to 69) + v2 := m.ConsumeVector(40) + if len(v2) != 2 { + t.Fatalf("expected 2 slices, got %d", len(v2)) + } + if len(v2[0]) != 20 || len(v2[1]) != 20 { + t.Errorf("expected slices of size 20 and 20, got %d and %d", len(v2[0]), len(v2[1])) + } + if v2[0][0] != 30 || v2[0][19] != 49 || v2[1][0] != 50 || v2[1][19] != 69 { + t.Errorf("unexpected content in v2: %v, %v", v2[0], v2[1]) + } + if m.consumed != 70 { + t.Errorf("expected consumed to be 70, got %d", m.consumed) + } + + // 4. Consume spanning block 1 and block 2 (offset 70 to 120) + // Remaining in block 1: 30 bytes (70 to 99) + // Needed from block 2: 20 bytes (100 to 119) + v3 := m.ConsumeVector(50) + if len(v3) != 2 { + t.Fatalf("expected 2 slices, got %d", len(v3)) + } + if len(v3[0]) != 30 || len(v3[1]) != 20 { + t.Errorf("expected slices of size 30 and 20, got %d and %d", len(v3[0]), len(v3[1])) + } + if v3[0][0] != 70 || v3[0][29] != 99 || v3[1][0] != 100 || v3[1][19] != 119 { + t.Errorf("unexpected content in v3: %v, %v", v3[0], v3[1]) + } + + // Clear m.blocks to prevent FreeBlocks from putting them in the global pool + m.blocks = nil + m.FreeBlocks() +} + +type nonSyscallConnReader struct{} + +func (r nonSyscallConnReader) Read(p []byte) (int, error) { + return 0, nil +} + +func TestInMessageInitNonSyscallConnReader(t *testing.T) { + // Temporarily disable IsPlatformFuseT to trigger the non-FuseT branch + fusekernel.IsPlatformFuseT = false + + m := NewInMessage(0) + m.AllocBlocks(4096) + defer m.FreeBlocks() + + err := m.Init(nonSyscallConnReader{}) + if err == nil { + t.Fatal("expected error when using a reader that does not implement SyscallConn, got nil") + } + expectedErr := "Reader does not support SyscallConn" + if err.Error() != expectedErr { + t.Errorf("expected error %q, got %q", expectedErr, err.Error()) + } +} + +func TestInMessageInitInvalidHeader(t *testing.T) { + m := NewInMessage(0) + m.AllocBlocks(4096) + defer m.FreeBlocks() + + // 1. Short read (fewer than headerSize = 40 bytes) + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Pipe failed: %v", err) + } + _, _ = w.Write([]byte("too short")) + w.Close() + + err = m.Init(r) + r.Close() + if err == nil { + t.Fatal("expected error for short read, got nil") + } + + // 2. Header length mismatch + r, w, err = os.Pipe() + if err != nil { + t.Fatalf("Pipe failed: %v", err) + } + header := fusekernel.InHeader{ + Len: 100, // header says 100 bytes, but we only write 40 bytes + } + _, _ = w.Write((*[unsafe.Sizeof(header)]byte)(unsafe.Pointer(&header))[:]) + w.Close() + + err = m.Init(r) + r.Close() + if err == nil { + t.Fatal("expected error for header length mismatch, got nil") + } +} + +func TestInMessageShrinkToFitMultiBlock(t *testing.T) { + m := NewInMessage(0) + + m.AllocBlocks(3 * MiB) // Allocates 1 block 0 (1MB+pageSize) + 2 blocks of 1MB = 3 blocks + if len(m.blocks) != 3 { + t.Fatalf("expected 3 blocks, got %d", len(m.blocks)) + } + + // Shrink to fit 1.5 MB (should keep block 0 and block 1, release block 2) + m.ShrinkToFit(1500000) + if len(m.blocks) != 2 { + t.Errorf("expected 2 blocks after shrinking to 1.5MB, got %d", len(m.blocks)) + } + + m.FreeBlocks() +} + +func TestInMessageBlockPoolRecycling(t *testing.T) { + m := NewInMessage(0) + + // Measure initial pool size + BlockPool1M.mu.Lock() + initialPoolSize := len(BlockPool1M.list) + BlockPool1M.mu.Unlock() + + // Request a large vector that borrows blocks from BlockPool1M + bufs := m.GetFreeVector(2*MiB + 100) + if len(bufs) != 3 { + t.Fatalf("expected 3 buffers, got %d", len(bufs)) + } + + m.FreeBlocks() + + // Measure pool size after freeing + BlockPool1M.mu.Lock() + finalPoolSize := len(BlockPool1M.list) + BlockPool1M.mu.Unlock() + + if finalPoolSize < initialPoolSize { + t.Errorf("expected pool size to be at least %d, got %d; blocks were not recycled!", initialPoolSize, finalPoolSize) + } +} + +func TestInMessageConsumeEdgeCases(t *testing.T) { + m := NewInMessage(0) + m.blocks = [][]byte{ + make([]byte, 10), + } + m.size = 10 + m.consumed = 0 + + // 1. Consume 0 bytes -> should return non-nil (start of block) + p := m.Consume(0) + if p == nil { + t.Errorf("expected non-nil pointer for 0-byte Consume") + } + + // 2. ConsumeBytes 0 bytes -> should return empty slice + b := m.ConsumeBytes(0) + if len(b) != 0 { + t.Errorf("expected empty slice for 0-byte ConsumeBytes, got %v", b) + } + + // 3. ConsumeVector 0 bytes -> should return empty slice + v := m.ConsumeVector(0) + if len(v) != 0 { + t.Errorf("expected empty slice/vector for 0-byte ConsumeVector, got %v", v) + } + + // 4. Request more than available + if p2 := m.Consume(11); p2 != nil { + t.Errorf("expected nil when consuming more than available, got %v", p2) + } + if b2 := m.ConsumeBytes(11); b2 != nil { + t.Errorf("expected nil when consuming more than available, got %v", b2) + } + if v2 := m.ConsumeVector(11); v2 != nil { + t.Errorf("expected nil when consuming more than available, got %v", v2) + } + + m.blocks = nil + m.FreeBlocks() +} diff --git a/internal/buffer/out_message_test.go b/internal/buffer/out_message_test.go index c2cd9b17..dcc577a5 100644 --- a/internal/buffer/out_message_test.go +++ b/internal/buffer/out_message_test.go @@ -276,13 +276,14 @@ func BenchmarkOutMessageReset(b *testing.B) { // Many megabytes worth of buffers, which should defeat the CPU cache. b.Run("Many buffers", func(b *testing.B) { // The number of messages; intentionally a power of two. - const numMessages = 128 + const numMessages = 2097152 // 2^21 * 40 bytes ≈ 80 MiB - var oms [numMessages]OutMessage - if s := unsafe.Sizeof(oms); s < 128<<20 { + oms := make([]OutMessage, numMessages) + if s := uintptr(len(oms)) * unsafe.Sizeof(OutMessage{}); s < 80<<20 { panic(fmt.Sprintf("Array is too small; total size: %d", s)) } + b.ResetTimer() for i := 0; i < b.N; i++ { oms[i%numMessages].Reset() } @@ -306,13 +307,14 @@ func BenchmarkOutMessageGrowShrink(b *testing.B) { // Many megabytes worth of buffers, which should defeat the CPU cache. b.Run("Many buffers", func(b *testing.B) { // The number of messages; intentionally a power of two. - const numMessages = 128 + const numMessages = 2097152 // 2^21 * 40 bytes ≈ 80 MiB - var oms [numMessages]OutMessage - if s := unsafe.Sizeof(oms); s < 128<<20 { + oms := make([]OutMessage, numMessages) + if s := uintptr(len(oms)) * unsafe.Sizeof(OutMessage{}); s < 80<<20 { panic(fmt.Sprintf("Array is too small; total size: %d", s)) } + b.ResetTimer() for i := 0; i < b.N; i++ { oms[i%numMessages].Grow(MaxReadSize) oms[i%numMessages].ShrinkTo(OutMessageHeaderSize) diff --git a/mount_config.go b/mount_config.go index f95895ad..a235ebae 100644 --- a/mount_config.go +++ b/mount_config.go @@ -238,9 +238,34 @@ type MountConfig struct { // in ReadFileOp.Data. // // Currently, both the read mechanisms can coexist. The library's behavior is - // to always provide ReadFileOp.Dst. If the file system populates ReadFileOp.Data, - // that data will be used for a vectored read, irrespective of this flag's value. + // to always provide ReadFileOp.Dst (except when EnableVectoredReads is true + // and the read size is larger than the pre-allocated incoming message block). + // If the file system populates ReadFileOp.Data, that data will be used for a + // vectored read, irrespective of this flag's value. UseVectoredRead bool + + // EnableVectoredReads bypasses allocating a large contiguous buffer in + // ReadFileOp.Dst when the read size is larger than the pre-allocated incoming + // message block. Instead, ReadFileOp.Dst will be nil, forcing the filesystem + // to use ReadFileOp.Data to return the read payload. This avoids allocations + // and copy overhead for large reads. + EnableVectoredReads bool + + // EnableVectoredWrites bypasses copying write payload bytes into a single + // contiguous slice in WriteFileOp.Data, instead providing the raw non-contiguous + // blocks in WriteFileOp.DataBlocks. This improves performance by avoiding copies + // and allocations for large writes. + // Note: DataBlocks will be nil if EnableVectoredWrites is disabled, or if the + // write size is smaller than the pre-allocated incoming message block (in + // which case Data will be populated instead). + EnableVectoredWrites bool + + // The maximum size of a FUSE message (in bytes) that the daemon is + // prepared to read or write. If not set, defaults to 1 MiB. + // NOTE: For MaxMessageSize greater than 1MiB, enabling EnableVectoredReads + // and EnableVectoredWrites is highly recommended to avoid significant + // performance regressions due to large heap allocations and copies. + MaxMessageSize uint32 } type FUSEImpl uint8 diff --git a/samples/memfs/inode.go b/samples/memfs/inode.go index cb044844..7d5b20b7 100644 --- a/samples/memfs/inode.go +++ b/samples/memfs/inode.go @@ -375,6 +375,44 @@ func (in *inode) WriteAt(p []byte, off int64) (int, error) { return n, nil } +// WriteBlocksAt writes the given blocks at the specified offset. +// +// REQUIRES: in.isFile() +func (in *inode) WriteBlocksAt(blocks [][]byte, off int64) (int, error) { + if !in.isFile() { + panic("WriteBlocksAt called on non-file.") + } + + // Update the modification time. + in.attrs.Mtime = time.Now() + + // Compute total length. + var totalLen int + for _, b := range blocks { + totalLen += len(b) + } + + // Ensure that the contents slice is long enough. + newLen := int(off) + totalLen + if len(in.contents) < newLen { + padding := make([]byte, newLen-len(in.contents)) + in.contents = append(in.contents, padding...) + in.attrs.Size = uint64(newLen) + } + + // Copy in the data from each block. + var bytesWritten int + for _, b := range blocks { + n := copy(in.contents[off+int64(bytesWritten):], b) + if n != len(b) { + panic(fmt.Sprintf("Unexpected short copy: %v", n)) + } + bytesWritten += n + } + + return bytesWritten, nil +} + // Update attributes from non-nil parameters. func (in *inode) SetAttributes( size *uint64, diff --git a/samples/memfs/memfs.go b/samples/memfs/memfs.go index 8e3b877a..6c76c5ca 100644 --- a/samples/memfs/memfs.go +++ b/samples/memfs/memfs.go @@ -694,7 +694,12 @@ func (fs *memFS) WriteFile( inode := fs.getInodeOrDie(op.Inode) // Serve the request. - _, err := inode.WriteAt(op.Data, op.Offset) + var err error + if len(op.Data) > 0 { + _, err = inode.WriteAt(op.Data, op.Offset) + } else { + _, err = inode.WriteBlocksAt(op.DataBlocks, op.Offset) + } op.Callback = fs.writeFileCallback diff --git a/samples/memfs/memfs_test.go b/samples/memfs/memfs_test.go index 0b021130..0b9320d1 100644 --- a/samples/memfs/memfs_test.go +++ b/samples/memfs/memfs_test.go @@ -2132,3 +2132,27 @@ func (t *AtmoicOTruncDisabledTest) SetUp(ti *TestInfo) { } func init() { RegisterTestSuite(&AtmoicOTruncDisabledTest{}) } + +type VectoredWritesTest struct { + memFSTest +} + +func (t *VectoredWritesTest) SetUp(ti *TestInfo) { + t.MountConfig.EnableVectoredWrites = true + t.memFSTest.SetUp(ti) +} + +func (t *VectoredWritesTest) TestVectoredWriteBasic() { + fileName := path.Join(t.Dir, "vectored_file") + const contents = "This is a test of vectored write payload. It should be parsed as DataBlocks." + + err := os.WriteFile(fileName, []byte(contents), 0600) + AssertEq(nil, err) + + data, err := os.ReadFile(fileName) + AssertEq(nil, err) + AssertEq(contents, string(data)) +} + +func init() { RegisterTestSuite(&VectoredWritesTest{}) } + diff --git a/wirelog.go b/wirelog.go index 2ce7ab12..3fbf09d9 100644 --- a/wirelog.go +++ b/wirelog.go @@ -94,7 +94,7 @@ func formatWireLogEntry(op any, opErr error, wlog *WireLogRecord) ([]byte, error args["BytesRead"] = typed.BytesRead case *fuseops.WriteFileOp: - args["Size"] = len(typed.Data) + args["Size"] = typed.TotalSize() } wlog.Args = args