simplify deadlineWriter, re-use WithDeadline (#18948)

This commit is contained in:
Harshavardhana 2024-02-02 03:02:31 -08:00 committed by GitHub
parent 31743789dc
commit d99d16e8c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 22 deletions

View File

@ -183,31 +183,21 @@ func (w *DeadlineWriter) Write(buf []byte) (int, error) {
return 0, w.err return 0, w.err
} }
c := make(chan ioret[int], 1) n, err := WithDeadline[int](context.Background(), w.timeout, func(ctx context.Context) (int, error) {
t := time.NewTimer(w.timeout) return w.WriteCloser.Write(buf)
go func() { })
n, err := w.WriteCloser.Write(buf) w.err = err
c <- ioret[int]{val: n, err: err} return n, err
close(c)
}()
select {
case r := <-c:
if !t.Stop() {
<-t.C
}
w.err = r.err
return r.val, r.err
case <-t.C:
w.WriteCloser.Close()
w.err = context.DeadlineExceeded
return 0, context.DeadlineExceeded
}
} }
// Close closer interface to close the underlying closer // Close closer interface to close the underlying closer
func (w *DeadlineWriter) Close() error { func (w *DeadlineWriter) Close() error {
return w.WriteCloser.Close() err := w.WriteCloser.Close()
w.err = err
if err == nil {
w.err = errors.New("we are closed") // Avoids any reuse on the Write() side.
}
return err
} }
// LimitWriter implements io.WriteCloser. // LimitWriter implements io.WriteCloser.

View File

@ -44,7 +44,6 @@ func (w *sleepWriter) Close() error {
func TestDeadlineWriter(t *testing.T) { func TestDeadlineWriter(t *testing.T) {
w := NewDeadlineWriter(&sleepWriter{timeout: 500 * time.Millisecond}, 450*time.Millisecond) w := NewDeadlineWriter(&sleepWriter{timeout: 500 * time.Millisecond}, 450*time.Millisecond)
_, err := w.Write([]byte("1")) _, err := w.Write([]byte("1"))
w.Close()
if err != context.DeadlineExceeded { if err != context.DeadlineExceeded {
t.Error("DeadlineWriter shouldn't be successful - should return context.DeadlineExceeded") t.Error("DeadlineWriter shouldn't be successful - should return context.DeadlineExceeded")
} }
@ -52,6 +51,7 @@ func TestDeadlineWriter(t *testing.T) {
if err != context.DeadlineExceeded { if err != context.DeadlineExceeded {
t.Error("DeadlineWriter shouldn't be successful - should return context.DeadlineExceeded") t.Error("DeadlineWriter shouldn't be successful - should return context.DeadlineExceeded")
} }
w.Close()
w = NewDeadlineWriter(&sleepWriter{timeout: 100 * time.Millisecond}, 600*time.Millisecond) w = NewDeadlineWriter(&sleepWriter{timeout: 100 * time.Millisecond}, 600*time.Millisecond)
n, err := w.Write([]byte("abcd")) n, err := w.Write([]byte("abcd"))
w.Close() w.Close()