diff --git a/internal/grid/connection.go b/internal/grid/connection.go index 2203c9580..40dc000ca 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -786,7 +786,7 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn if debugPrint { fmt.Println("expected to be client side, not server side") } - return errors.New("expected to be client side, not server side") + return errors.New("grid: expected to be client side, not server side") } msg := message{ Op: OpConnectResponse, diff --git a/internal/grid/manager.go b/internal/grid/manager.go index 06658d847..d291f29e5 100644 --- a/internal/grid/manager.go +++ b/internal/grid/manager.go @@ -161,6 +161,27 @@ func (m *Manager) Handler() http.HandlerFunc { w.WriteHeader(http.StatusUpgradeRequired) return } + // will write an OpConnectResponse message to the remote and log it once locally. + writeErr := func(err error) { + if err == nil { + return + } + logger.LogOnceIf(ctx, err, err.Error()) + resp := connectResp{ + ID: m.ID, + Accepted: false, + RejectedReason: err.Error(), + } + if b, err := resp.MarshalMsg(nil); err == nil { + msg := message{ + Op: OpConnectResponse, + Payload: b, + } + if b, err := msg.MarshalMsg(nil); err == nil { + wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b) + } + } + } defer conn.Close() if debugPrint { fmt.Printf("grid: Upgraded request: %v\n", req.URL) @@ -168,7 +189,7 @@ func (m *Manager) Handler() http.HandlerFunc { msg, _, err := wsutil.ReadClientData(conn) if err != nil { - logger.LogIf(ctx, fmt.Errorf("grid: reading connect: %w", err)) + writeErr(fmt.Errorf("reading connect: %w", err)) w.WriteHeader(http.StatusForbidden) return } @@ -179,44 +200,28 @@ func (m *Manager) Handler() http.HandlerFunc { var message message _, _, err = message.parse(msg) if err != nil { - if debugPrint { - fmt.Println("parse err:", err) - } - logger.LogIf(ctx, fmt.Errorf("handleMessages: parsing connect: %w", err)) - w.WriteHeader(http.StatusForbidden) + writeErr(fmt.Errorf("error parsing grid connect: %w", err)) return } if message.Op != OpConnect { - if debugPrint { - fmt.Println("op err:", message.Op) - } - logger.LogIf(ctx, fmt.Errorf("handler: unexpected op: %v", message.Op)) - w.WriteHeader(http.StatusForbidden) + writeErr(fmt.Errorf("unexpected connect op: %v", message.Op)) return } var cReq connectReq _, err = cReq.UnmarshalMsg(message.Payload) if err != nil { - if debugPrint { - fmt.Println("handler: creq err:", err) - } - logger.LogIf(ctx, fmt.Errorf("handleMessages: parsing ConnectReq: %w", err)) - w.WriteHeader(http.StatusForbidden) + writeErr(fmt.Errorf("error parsing connectReq: %w", err)) return } remote := m.targets[cReq.Host] if remote == nil { - if debugPrint { - fmt.Printf("%s: handler: unknown host: %v. Have %v\n", m.local, cReq.Host, m.targets) - } - w.WriteHeader(http.StatusForbidden) + writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host)) return } if debugPrint { fmt.Printf("handler: Got Connect Req %+v\n", cReq) } - - logger.LogIf(ctx, remote.handleIncoming(ctx, conn, cReq)) + writeErr(remote.handleIncoming(ctx, conn, cReq)) } }