return an error in CopyAligned upon premature EOF (#18110)

add a unit-test to capture this corner case
This commit is contained in:
Harshavardhana 2023-09-26 11:20:06 -07:00 committed by GitHub
parent cdeab19673
commit d9f1df01eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 7 deletions

View File

@ -22,6 +22,7 @@ package ioutil
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io" "io"
"os" "os"
"sync" "sync"
@ -348,6 +349,10 @@ const DirectioAlignSize = 4096
// input writer *os.File not a generic io.Writer. Make sure to have // input writer *os.File not a generic io.Writer. Make sure to have
// the file opened for writes with syscall.O_DIRECT flag. // the file opened for writes with syscall.O_DIRECT flag.
func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, file *os.File) (int64, error) { func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, file *os.File) (int64, error) {
if totalSize == 0 {
return 0, nil
}
// Writes remaining bytes in the buffer. // Writes remaining bytes in the buffer.
writeUnaligned := func(w io.Writer, buf []byte) (remainingWritten int64, err error) { writeUnaligned := func(w io.Writer, buf []byte) (remainingWritten int64, err error) {
// Disable O_DIRECT on fd's on unaligned buffer // Disable O_DIRECT on fd's on unaligned buffer
@ -364,17 +369,19 @@ func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, f
var written int64 var written int64
for { for {
buf := alignedBuf buf := alignedBuf
if totalSize != -1 { if totalSize > 0 {
remaining := totalSize - written remaining := totalSize - written
if remaining < int64(len(buf)) { if remaining < int64(len(buf)) {
buf = buf[:remaining] buf = buf[:remaining]
} }
} }
nr, err := io.ReadFull(r, buf) nr, err := io.ReadFull(r, buf)
eof := err == io.EOF || err == io.ErrUnexpectedEOF eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
if err != nil && !eof { if err != nil && !eof {
return written, err return written, err
} }
buf = buf[:nr] buf = buf[:nr]
var nw int64 var nw int64
if len(buf)%DirectioAlignSize == 0 { if len(buf)%DirectioAlignSize == 0 {
@ -386,22 +393,30 @@ func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, f
// buf is not aligned, hence use writeUnaligned() // buf is not aligned, hence use writeUnaligned()
nw, err = writeUnaligned(w, buf) nw, err = writeUnaligned(w, buf)
} }
if nw > 0 { if nw > 0 {
written += nw written += nw
} }
if err != nil { if err != nil {
return written, err return written, err
} }
if nw != int64(len(buf)) { if nw != int64(len(buf)) {
return written, io.ErrShortWrite return written, io.ErrShortWrite
} }
if totalSize != -1 { if totalSize > 0 && written == totalSize {
if written == totalSize { // we have written the entire stream, return right here.
return written, nil return written, nil
} }
}
if eof { if eof {
// We reached EOF prematurely but we did not write everything
// that we promised that we would write.
if totalSize > 0 && written != totalSize {
return written, io.ErrUnexpectedEOF
}
return written, nil return written, nil
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright (c) 2015-2021 MinIO, Inc. // Copyright (c) 2015-2023 MinIO, Inc.
// //
// This file is part of MinIO Object Storage stack // This file is part of MinIO Object Storage stack
// //
@ -20,8 +20,10 @@ package ioutil
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io" "io"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
) )
@ -205,3 +207,36 @@ func TestSameFile(t *testing.T) {
t.Fatal("Expected the files not to be same") t.Fatal("Expected the files not to be same")
} }
} }
func TestCopyAligned(t *testing.T) {
f, err := os.CreateTemp("", "")
if err != nil {
t.Errorf("Error creating tmp file: %v", err)
}
defer f.Close()
defer os.Remove(f.Name())
r := strings.NewReader("hello world")
bufp := ODirectPoolSmall.Get().(*[]byte)
defer ODirectPoolSmall.Put(bufp)
written, err := CopyAligned(f, io.LimitReader(r, 5), *bufp, r.Size(), f)
if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("Expected io.ErrUnexpectedEOF, but got %v", err)
}
if written != 5 {
t.Errorf("Expected written to be '5', but got %v", written)
}
f.Seek(0, io.SeekStart)
r.Seek(0, io.SeekStart)
written, err = CopyAligned(f, r, *bufp, r.Size(), f)
if !errors.Is(err, nil) {
t.Errorf("Expected nil, but got %v", err)
}
if written != r.Size() {
t.Errorf("Expected written to be '%v', but got %v", r.Size(), written)
}
}