protect bpool from buffer pollution by invalid buffers (#20342)

This commit is contained in:
Harshavardhana 2024-08-28 18:40:52 -07:00 committed by GitHub
parent 38c0840834
commit 504e52b45e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 47 deletions

View File

@ -48,14 +48,14 @@ func newParallelReader(readers []io.ReaderAt, e Erasure, offset, totalLength int
r2b[i] = i r2b[i] = i
} }
bufs := make([][]byte, len(readers)) bufs := make([][]byte, len(readers))
// Fill buffers
b := globalBytePoolCap.Load().Get()
shardSize := int(e.ShardSize()) shardSize := int(e.ShardSize())
if cap(b) < len(readers)*shardSize { var b []byte
// We should always have enough capacity, but older objects may be bigger.
globalBytePoolCap.Load().Put(b) // We should always have enough capacity, but older objects may be bigger
b = nil // we do not need stashbuffer for them.
} else { if globalBytePoolCap.Load().WidthCap() >= len(readers)*shardSize {
// Fill buffers
b = globalBytePoolCap.Load().Get()
// Seed the buffers. // Seed the buffers.
for i := range bufs { for i := range bufs {
bufs[i] = b[i*shardSize : (i+1)*shardSize] bufs[i] = b[i*shardSize : (i+1)*shardSize]

View File

@ -93,14 +93,18 @@ func testDeleteObject(obj ObjectLayer, instanceType string, t TestErrHandler) {
md5Bytes := md5.Sum([]byte(object.content)) md5Bytes := md5.Sum([]byte(object.content))
oi, err := obj.PutObject(context.Background(), testCase.bucketName, object.name, mustGetPutObjReader(t, strings.NewReader(object.content), oi, err := obj.PutObject(context.Background(), testCase.bucketName, object.name, mustGetPutObjReader(t, strings.NewReader(object.content),
int64(len(object.content)), hex.EncodeToString(md5Bytes[:]), ""), ObjectOptions{}) int64(len(object.content)), hex.EncodeToString(md5Bytes[:]), ""), ObjectOptions{})
t.Log(oi)
if err != nil { if err != nil {
t.Log(oi)
t.Fatalf("%s : %s", instanceType, err.Error()) t.Fatalf("%s : %s", instanceType, err.Error())
} }
} }
oi, err := obj.DeleteObject(context.Background(), testCase.bucketName, testCase.pathToDelete, ObjectOptions{}) oi, err := obj.DeleteObject(context.Background(), testCase.bucketName, testCase.pathToDelete, ObjectOptions{})
t.Log(oi, err) if err != nil && !isErrObjectNotFound(err) {
t.Log(oi)
t.Errorf("Test %d: %s: Expected to pass, but failed with: <ERROR> %s", i+1, instanceType, err)
continue
}
result, err := obj.ListObjects(context.Background(), testCase.bucketName, "", "", "", 1000) result, err := obj.ListObjects(context.Background(), testCase.bucketName, "", "", "", 1000)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
// Copyright (c) 2015-2023 MinIO, Inc. // Copyright (c) 2015-2024 MinIO, Inc.
// //
// This file is part of MinIO Object Storage stack // This file is part of MinIO Object Storage stack
// //
@ -17,7 +17,9 @@
package bpool package bpool
import "github.com/klauspost/reedsolomon" import (
"github.com/klauspost/reedsolomon"
)
// BytePoolCap implements a leaky pool of []byte in the form of a bounded channel. // BytePoolCap implements a leaky pool of []byte in the form of a bounded channel.
type BytePoolCap struct { type BytePoolCap struct {
@ -29,11 +31,14 @@ type BytePoolCap struct {
// NewBytePoolCap creates a new BytePool bounded to the given maxSize, with new // NewBytePoolCap creates a new BytePool bounded to the given maxSize, with new
// byte arrays sized based on width. // byte arrays sized based on width.
func NewBytePoolCap(maxSize uint64, width int, capwidth int) (bp *BytePoolCap) { func NewBytePoolCap(maxSize uint64, width int, capwidth int) (bp *BytePoolCap) {
if capwidth > 0 && capwidth < 64 { if capwidth <= 0 {
panic("total buffer capacity must be provided")
}
if capwidth < 64 {
panic("buffer capped with smaller than 64 bytes is not supported") panic("buffer capped with smaller than 64 bytes is not supported")
} }
if capwidth > 0 && width > capwidth { if width > capwidth {
panic("buffer length cannot be > capacity of the buffer") panic("minimum buffer length cannot be > capacity of the buffer")
} }
return &BytePoolCap{ return &BytePoolCap{
c: make(chan []byte, maxSize), c: make(chan []byte, maxSize),
@ -60,11 +65,7 @@ func (bp *BytePoolCap) Get() (b []byte) {
// reuse existing buffer // reuse existing buffer
default: default:
// create new aligned buffer // create new aligned buffer
if bp.wcap > 0 {
b = reedsolomon.AllocAligned(1, bp.wcap)[0][:bp.w] b = reedsolomon.AllocAligned(1, bp.wcap)[0][:bp.w]
} else {
b = reedsolomon.AllocAligned(1, bp.w)[0]
}
} }
return return
} }
@ -74,8 +75,17 @@ func (bp *BytePoolCap) Put(b []byte) {
if bp == nil { if bp == nil {
return return
} }
if cap(b) != bp.wcap {
// someone tried to put back buffer which is not part of this buffer pool
// we simply don't put this back into pool, a modified buffer provided
// by this package is no more usable, callers make sure to not modify
// the capacity of the buffer.
return
}
select { select {
case bp.c <- b: case bp.c <- b[:bp.w]:
// buffer went back into pool // buffer went back into pool
default: default:
// buffer didn't go back into pool, just discard // buffer didn't go back into pool, just discard
@ -97,3 +107,11 @@ func (bp *BytePoolCap) WidthCap() (n int) {
} }
return bp.wcap return bp.wcap
} }
// CurrentSize returns current size of buffer pool
func (bp *BytePoolCap) CurrentSize() int {
if bp == nil {
return 0
}
return len(bp.c) * bp.w
}

View File

@ -17,7 +17,9 @@
package bpool package bpool
import "testing" import (
"testing"
)
// Tests - bytePool functionality. // Tests - bytePool functionality.
func TestBytePool(t *testing.T) { func TestBytePool(t *testing.T) {
@ -25,20 +27,20 @@ func TestBytePool(t *testing.T) {
width := 1024 width := 1024
capWidth := 2048 capWidth := 2048
bufPool := NewBytePoolCap(size, width, capWidth) bp := NewBytePoolCap(size, width, capWidth)
// Check the width // Check the width
if bufPool.Width() != width { if bp.Width() != width {
t.Fatalf("bytepool width invalid: got %v want %v", bufPool.Width(), width) t.Fatalf("bytepool width invalid: got %v want %v", bp.Width(), width)
} }
// Check with width cap // Check with width cap
if bufPool.WidthCap() != capWidth { if bp.WidthCap() != capWidth {
t.Fatalf("bytepool capWidth invalid: got %v want %v", bufPool.WidthCap(), capWidth) t.Fatalf("bytepool capWidth invalid: got %v want %v", bp.WidthCap(), capWidth)
} }
// Check that retrieved buffer are of the expected width // Check that retrieved buffer are of the expected width
b := bufPool.Get() b := bp.Get()
if len(b) != width { if len(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", len(b), width) t.Fatalf("bytepool length invalid: got %v want %v", len(b), width)
} }
@ -46,14 +48,14 @@ func TestBytePool(t *testing.T) {
t.Fatalf("bytepool cap invalid: got %v want %v", cap(b), capWidth) t.Fatalf("bytepool cap invalid: got %v want %v", cap(b), capWidth)
} }
bufPool.Put(b) bp.Put(b)
// Fill the pool beyond the capped pool size. // Fill the pool beyond the capped pool size.
for i := uint64(0); i < size*2; i++ { for i := uint64(0); i < size*2; i++ {
bufPool.Put(make([]byte, bufPool.w)) bp.Put(make([]byte, bp.w, bp.wcap))
} }
b = bufPool.Get() b = bp.Get()
if len(b) != width { if len(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", len(b), width) t.Fatalf("bytepool length invalid: got %v want %v", len(b), width)
} }
@ -61,31 +63,37 @@ func TestBytePool(t *testing.T) {
t.Fatalf("bytepool length invalid: got %v want %v", cap(b), capWidth) t.Fatalf("bytepool length invalid: got %v want %v", cap(b), capWidth)
} }
bufPool.Put(b) bp.Put(b)
// Close the channel so we can iterate over it.
close(bufPool.c)
// Check the size of the pool. // Check the size of the pool.
if uint64(len(bufPool.c)) != size { if uint64(len(bp.c)) != size {
t.Fatalf("bytepool size invalid: got %v want %v", len(bufPool.c), size) t.Fatalf("bytepool size invalid: got %v want %v", len(bp.c), size)
} }
bufPoolNoCap := NewBytePoolCap(size, width, 0) // lets drain the buf channel first before we validate invalid buffers.
// Check the width for i := uint64(0); i < size; i++ {
if bufPoolNoCap.Width() != width { bp.Get() // discard
t.Fatalf("bytepool width invalid: got %v want %v", bufPool.Width(), width)
} }
// Check with width cap // Try putting some invalid buffers into pool
if bufPoolNoCap.WidthCap() != 0 { bp.Put(make([]byte, bp.w, bp.wcap-1)) // wrong capacity is rejected (less)
t.Fatalf("bytepool capWidth invalid: got %v want %v", bufPool.WidthCap(), 0) bp.Put(make([]byte, bp.w, bp.wcap+1)) // wrong capacity is rejected (more)
bp.Put(make([]byte, width)) // wrong capacity is rejected (very less)
if len(bp.c) > 0 {
t.Fatal("bytepool should have rejected invalid packets")
} }
b = bufPoolNoCap.Get()
// Try putting a short slice into pool
bp.Put(make([]byte, bp.w, bp.wcap)[:2])
if len(bp.c) != 1 {
t.Fatal("bytepool should have accepted short slice with sufficient capacity")
}
b = bp.Get()
if len(b) != width { if len(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", len(b), width) t.Fatalf("bytepool length invalid: got %v want %v", len(b), width)
} }
if cap(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", cap(b), width) // Close the channel.
} close(bp.c)
} }