From 38de8e69361a6a95512346ffc6ca72722b2a723e Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sun, 28 Jan 2024 08:46:15 -0800 Subject: [PATCH] grid: Simpler reconnect logic (#18889) Do not rely on `connChange` to do reconnects. Instead, you can block while the connection is running and reconnect when handleMessages returns. Add fully async monitoring instead of monitoring on the main goroutine and keep this to avoid full network lockup. --- internal/grid/connection.go | 65 ++++++++++++++++++++----------------- internal/grid/handlers.go | 3 ++ internal/grid/muxclient.go | 6 ++-- internal/grid/muxserver.go | 10 ++++-- 4 files changed, 50 insertions(+), 34 deletions(-) diff --git a/internal/grid/connection.go b/internal/grid/connection.go index 73875c6b8..352eba4e2 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -707,26 +707,15 @@ func (c *Connection) connect() { if debugPrint { fmt.Println(c.Local, "Connected Waiting for Messages") } - c.updateState(StateConnected) - go c.handleMessages(c.ctx, conn) - // Monitor state changes and reconnect if needed. - c.connChange.L.Lock() - for { - newState := c.State() - if newState != StateConnected { - c.connChange.L.Unlock() - if newState == StateShutdown { - conn.Close() - return - } - if debugPrint { - fmt.Println(c.Local, "Disconnected") - } - // Reconnect - break - } - // Unlock and wait for state change. - c.connChange.Wait() + // Handle messages... + c.handleMessages(c.ctx, conn) + // Reconnect unless we are shutting down (debug only). + if c.State() == StateShutdown { + conn.Close() + return + } + if debugPrint { + fmt.Println(c.Local, "Disconnected. Attempting to reconnect.") } } } @@ -818,7 +807,7 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn rid := uuid.UUID(req.ID) c.remoteID = &rid - c.updateState(StateConnected) + // Handle incoming messages until disconnect. c.handleMessages(ctx, conn) return nil } @@ -867,12 +856,36 @@ func (c *Connection) updateState(s State) { c.connChange.Broadcast() } +// monitorState will monitor the state of the connection and close the net.Conn if it changes. +func (c *Connection) monitorState(conn net.Conn, cancel context.CancelCauseFunc) { + c.connChange.L.Lock() + defer c.connChange.L.Unlock() + for { + newState := c.State() + if newState != StateConnected { + conn.Close() + cancel(ErrDisconnected) + return + } + // Unlock and wait for state change. + c.connChange.Wait() + } +} + // handleMessages will handle incoming messages on conn. // caller *must* hold reconnectMu. func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { + c.updateState(StateConnected) + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(ErrDisconnected) + + // This will ensure that is something asks to disconnect and we are blocked on reads/writes + // the connection will be closed and readers/writers will unblock. + go c.monitorState(conn, cancel) + c.handleMsgWg.Add(2) c.reconnectMu.Unlock() - ctx, cancel := context.WithCancelCause(ctx) + // Read goroutine go func() { defer func() { @@ -1034,7 +1047,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { lastPongTime := time.Unix(lastPong, 0) if d := time.Since(lastPongTime); d > connPingInterval*2 { logger.LogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond))) - cancel(ErrDisconnected) return } } @@ -1084,14 +1096,12 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { err := wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) - cancel(ErrDisconnected) return } PutByteBuffer(toSend) _, err = buf.WriteTo(conn) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) - cancel(ErrDisconnected) return } continue @@ -1109,7 +1119,6 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { toSend, err = m.MarshalMsg(toSend) if err != nil { logger.LogIf(ctx, fmt.Errorf("msg.MarshalMsg: %w", err)) - cancel(ErrDisconnected) return } // Append as byte slices. @@ -1126,14 +1135,12 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { err = wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) - cancel(ErrDisconnected) return } - // Tosend is our local buffer, so we can reuse it. + // buf is our local buffer, so we can reuse it. _, err = buf.WriteTo(conn) if err != nil { logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) - cancel(ErrDisconnected) return } diff --git a/internal/grid/handlers.go b/internal/grid/handlers.go index 82f7efdb2..902f55312 100644 --- a/internal/grid/handlers.go +++ b/internal/grid/handlers.go @@ -691,6 +691,9 @@ func (h *StreamTypeHandler[Payload, Req, Resp]) Call(ctx context.Context, c Stre if h.InCapacity > 0 { reqT = make(chan Req) // Request handler + if stream.Requests == nil { + return nil, fmt.Errorf("internal error: stream request channel nil") + } go func() { defer close(stream.Requests) for req := range reqT { diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index a04122c30..3ea50ceae 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -533,7 +533,9 @@ func (m *muxClient) closeLocked() { if m.closed { return } - close(m.respWait) - m.respWait = nil + if m.respWait != nil { + close(m.respWait) + m.respWait = nil + } m.closed = true } diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go index fd7096f41..907722462 100644 --- a/internal/grid/muxserver.go +++ b/internal/grid/muxserver.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "runtime/debug" "sync" "sync/atomic" "time" @@ -138,7 +139,8 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea } if r := recover(); r != nil { logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r)) - err := RemoteErr(fmt.Sprintf("panic: %v", r)) + debug.PrintStack() + err := RemoteErr(fmt.Sprintf("remote call panic: %v", r)) handlerErr = &err } if debugPrint { @@ -244,8 +246,10 @@ func (m *muxServer) message(msg message) { if len(msg.Payload) > 0 { logger.LogIf(m.ctx, fmt.Errorf("muxServer: EOF message with payload")) } - close(m.inbound) - m.inbound = nil + if m.inbound != nil { + close(m.inbound) + m.inbound = nil + } return }