protect bpool from buffer pollution by invalid buffers (#20342)

This commit is contained in:
Harshavardhana
2024-08-28 18:40:52 -07:00
committed by GitHub
parent 38c0840834
commit 504e52b45e
4 changed files with 77 additions and 47 deletions

View File

@@ -17,7 +17,9 @@
package bpool
import "testing"
import (
"testing"
)
// Tests - bytePool functionality.
func TestBytePool(t *testing.T) {
@@ -25,20 +27,20 @@ func TestBytePool(t *testing.T) {
width := 1024
capWidth := 2048
bufPool := NewBytePoolCap(size, width, capWidth)
bp := NewBytePoolCap(size, width, capWidth)
// Check the width
if bufPool.Width() != width {
t.Fatalf("bytepool width invalid: got %v want %v", bufPool.Width(), width)
if bp.Width() != width {
t.Fatalf("bytepool width invalid: got %v want %v", bp.Width(), width)
}
// Check with width cap
if bufPool.WidthCap() != capWidth {
t.Fatalf("bytepool capWidth invalid: got %v want %v", bufPool.WidthCap(), capWidth)
if bp.WidthCap() != capWidth {
t.Fatalf("bytepool capWidth invalid: got %v want %v", bp.WidthCap(), capWidth)
}
// Check that retrieved buffer are of the expected width
b := bufPool.Get()
b := bp.Get()
if len(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", len(b), width)
}
@@ -46,14 +48,14 @@ func TestBytePool(t *testing.T) {
t.Fatalf("bytepool cap invalid: got %v want %v", cap(b), capWidth)
}
bufPool.Put(b)
bp.Put(b)
// Fill the pool beyond the capped pool size.
for i := uint64(0); i < size*2; i++ {
bufPool.Put(make([]byte, bufPool.w))
bp.Put(make([]byte, bp.w, bp.wcap))
}
b = bufPool.Get()
b = bp.Get()
if len(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", len(b), width)
}
@@ -61,31 +63,37 @@ func TestBytePool(t *testing.T) {
t.Fatalf("bytepool length invalid: got %v want %v", cap(b), capWidth)
}
bufPool.Put(b)
// Close the channel so we can iterate over it.
close(bufPool.c)
bp.Put(b)
// Check the size of the pool.
if uint64(len(bufPool.c)) != size {
t.Fatalf("bytepool size invalid: got %v want %v", len(bufPool.c), size)
if uint64(len(bp.c)) != size {
t.Fatalf("bytepool size invalid: got %v want %v", len(bp.c), size)
}
bufPoolNoCap := NewBytePoolCap(size, width, 0)
// Check the width
if bufPoolNoCap.Width() != width {
t.Fatalf("bytepool width invalid: got %v want %v", bufPool.Width(), width)
// lets drain the buf channel first before we validate invalid buffers.
for i := uint64(0); i < size; i++ {
bp.Get() // discard
}
// Check with width cap
if bufPoolNoCap.WidthCap() != 0 {
t.Fatalf("bytepool capWidth invalid: got %v want %v", bufPool.WidthCap(), 0)
// Try putting some invalid buffers into pool
bp.Put(make([]byte, bp.w, bp.wcap-1)) // wrong capacity is rejected (less)
bp.Put(make([]byte, bp.w, bp.wcap+1)) // wrong capacity is rejected (more)
bp.Put(make([]byte, width)) // wrong capacity is rejected (very less)
if len(bp.c) > 0 {
t.Fatal("bytepool should have rejected invalid packets")
}
b = bufPoolNoCap.Get()
// Try putting a short slice into pool
bp.Put(make([]byte, bp.w, bp.wcap)[:2])
if len(bp.c) != 1 {
t.Fatal("bytepool should have accepted short slice with sufficient capacity")
}
b = bp.Get()
if len(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", len(b), width)
}
if cap(b) != width {
t.Fatalf("bytepool length invalid: got %v want %v", cap(b), width)
}
// Close the channel.
close(bp.c)
}