diff --git a/internal/ioutil/ioutil.go b/internal/ioutil/ioutil.go index 9692a09bd..196f19572 100644 --- a/internal/ioutil/ioutil.go +++ b/internal/ioutil/ioutil.go @@ -22,6 +22,7 @@ package ioutil import ( "bytes" "context" + "errors" "io" "os" "sync" @@ -348,6 +349,10 @@ const DirectioAlignSize = 4096 // input writer *os.File not a generic io.Writer. Make sure to have // the file opened for writes with syscall.O_DIRECT flag. func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, file *os.File) (int64, error) { + if totalSize == 0 { + return 0, nil + } + // Writes remaining bytes in the buffer. writeUnaligned := func(w io.Writer, buf []byte) (remainingWritten int64, err error) { // Disable O_DIRECT on fd's on unaligned buffer @@ -364,17 +369,19 @@ func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, f var written int64 for { buf := alignedBuf - if totalSize != -1 { + if totalSize > 0 { remaining := totalSize - written if remaining < int64(len(buf)) { buf = buf[:remaining] } } + nr, err := io.ReadFull(r, buf) - eof := err == io.EOF || err == io.ErrUnexpectedEOF + eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) if err != nil && !eof { return written, err } + buf = buf[:nr] var nw int64 if len(buf)%DirectioAlignSize == 0 { @@ -386,22 +393,30 @@ func CopyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, f // buf is not aligned, hence use writeUnaligned() nw, err = writeUnaligned(w, buf) } + if nw > 0 { written += nw } + if err != nil { return written, err } + if nw != int64(len(buf)) { return written, io.ErrShortWrite } - if totalSize != -1 { - if written == totalSize { - return written, nil - } + if totalSize > 0 && written == totalSize { + // we have written the entire stream, return right here. + return written, nil } + if eof { + // We reached EOF prematurely but we did not write everything + // that we promised that we would write. + if totalSize > 0 && written != totalSize { + return written, io.ErrUnexpectedEOF + } return written, nil } } diff --git a/internal/ioutil/ioutil_test.go b/internal/ioutil/ioutil_test.go index 921e78ecb..37efea3cc 100644 --- a/internal/ioutil/ioutil_test.go +++ b/internal/ioutil/ioutil_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2015-2021 MinIO, Inc. +// Copyright (c) 2015-2023 MinIO, Inc. // // This file is part of MinIO Object Storage stack // @@ -20,8 +20,10 @@ package ioutil import ( "bytes" "context" + "errors" "io" "os" + "strings" "testing" "time" ) @@ -205,3 +207,36 @@ func TestSameFile(t *testing.T) { t.Fatal("Expected the files not to be same") } } + +func TestCopyAligned(t *testing.T) { + f, err := os.CreateTemp("", "") + if err != nil { + t.Errorf("Error creating tmp file: %v", err) + } + defer f.Close() + defer os.Remove(f.Name()) + + r := strings.NewReader("hello world") + + bufp := ODirectPoolSmall.Get().(*[]byte) + defer ODirectPoolSmall.Put(bufp) + + written, err := CopyAligned(f, io.LimitReader(r, 5), *bufp, r.Size(), f) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Errorf("Expected io.ErrUnexpectedEOF, but got %v", err) + } + if written != 5 { + t.Errorf("Expected written to be '5', but got %v", written) + } + + f.Seek(0, io.SeekStart) + r.Seek(0, io.SeekStart) + + written, err = CopyAligned(f, r, *bufp, r.Size(), f) + if !errors.Is(err, nil) { + t.Errorf("Expected nil, but got %v", err) + } + if written != r.Size() { + t.Errorf("Expected written to be '%v', but got %v", r.Size(), written) + } +}