diff --git a/internal/ioutil/ioutil.go b/internal/ioutil/ioutil.go index 2cb07b894..09712f4b1 100644 --- a/internal/ioutil/ioutil.go +++ b/internal/ioutil/ioutil.go @@ -183,31 +183,21 @@ func (w *DeadlineWriter) Write(buf []byte) (int, error) { return 0, w.err } - c := make(chan ioret[int], 1) - t := time.NewTimer(w.timeout) - go func() { - n, err := w.WriteCloser.Write(buf) - c <- ioret[int]{val: n, err: err} - close(c) - }() - - select { - case r := <-c: - if !t.Stop() { - <-t.C - } - w.err = r.err - return r.val, r.err - case <-t.C: - w.WriteCloser.Close() - w.err = context.DeadlineExceeded - return 0, context.DeadlineExceeded - } + n, err := WithDeadline[int](context.Background(), w.timeout, func(ctx context.Context) (int, error) { + return w.WriteCloser.Write(buf) + }) + w.err = err + return n, err } // Close closer interface to close the underlying closer func (w *DeadlineWriter) Close() error { - return w.WriteCloser.Close() + err := w.WriteCloser.Close() + w.err = err + if err == nil { + w.err = errors.New("we are closed") // Avoids any reuse on the Write() side. + } + return err } // LimitWriter implements io.WriteCloser. diff --git a/internal/ioutil/ioutil_test.go b/internal/ioutil/ioutil_test.go index c393667a5..6e332b3f5 100644 --- a/internal/ioutil/ioutil_test.go +++ b/internal/ioutil/ioutil_test.go @@ -44,7 +44,6 @@ func (w *sleepWriter) Close() error { func TestDeadlineWriter(t *testing.T) { w := NewDeadlineWriter(&sleepWriter{timeout: 500 * time.Millisecond}, 450*time.Millisecond) _, err := w.Write([]byte("1")) - w.Close() if err != context.DeadlineExceeded { t.Error("DeadlineWriter shouldn't be successful - should return context.DeadlineExceeded") } @@ -52,6 +51,7 @@ func TestDeadlineWriter(t *testing.T) { if err != context.DeadlineExceeded { t.Error("DeadlineWriter shouldn't be successful - should return context.DeadlineExceeded") } + w.Close() w = NewDeadlineWriter(&sleepWriter{timeout: 100 * time.Millisecond}, 600*time.Millisecond) n, err := w.Write([]byte("abcd")) w.Close()