grid: Return rejection reason (#18834)

When rejecting incoming grid requests fill out the rejection reason and log it once.

This will give more context when startup is failing. Already logged after a retry on caller.
This commit is contained in:
Klaus Post 2024-01-19 10:35:24 -08:00 committed by GitHub
parent cc960adbee
commit 83bf15a703
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 23 deletions

View File

@ -786,7 +786,7 @@ func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req conn
if debugPrint {
fmt.Println("expected to be client side, not server side")
}
return errors.New("expected to be client side, not server side")
return errors.New("grid: expected to be client side, not server side")
}
msg := message{
Op: OpConnectResponse,

View File

@ -161,6 +161,27 @@ func (m *Manager) Handler() http.HandlerFunc {
w.WriteHeader(http.StatusUpgradeRequired)
return
}
// will write an OpConnectResponse message to the remote and log it once locally.
writeErr := func(err error) {
if err == nil {
return
}
logger.LogOnceIf(ctx, err, err.Error())
resp := connectResp{
ID: m.ID,
Accepted: false,
RejectedReason: err.Error(),
}
if b, err := resp.MarshalMsg(nil); err == nil {
msg := message{
Op: OpConnectResponse,
Payload: b,
}
if b, err := msg.MarshalMsg(nil); err == nil {
wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpBinary, b)
}
}
}
defer conn.Close()
if debugPrint {
fmt.Printf("grid: Upgraded request: %v\n", req.URL)
@ -168,7 +189,7 @@ func (m *Manager) Handler() http.HandlerFunc {
msg, _, err := wsutil.ReadClientData(conn)
if err != nil {
logger.LogIf(ctx, fmt.Errorf("grid: reading connect: %w", err))
writeErr(fmt.Errorf("reading connect: %w", err))
w.WriteHeader(http.StatusForbidden)
return
}
@ -179,44 +200,28 @@ func (m *Manager) Handler() http.HandlerFunc {
var message message
_, _, err = message.parse(msg)
if err != nil {
if debugPrint {
fmt.Println("parse err:", err)
}
logger.LogIf(ctx, fmt.Errorf("handleMessages: parsing connect: %w", err))
w.WriteHeader(http.StatusForbidden)
writeErr(fmt.Errorf("error parsing grid connect: %w", err))
return
}
if message.Op != OpConnect {
if debugPrint {
fmt.Println("op err:", message.Op)
}
logger.LogIf(ctx, fmt.Errorf("handler: unexpected op: %v", message.Op))
w.WriteHeader(http.StatusForbidden)
writeErr(fmt.Errorf("unexpected connect op: %v", message.Op))
return
}
var cReq connectReq
_, err = cReq.UnmarshalMsg(message.Payload)
if err != nil {
if debugPrint {
fmt.Println("handler: creq err:", err)
}
logger.LogIf(ctx, fmt.Errorf("handleMessages: parsing ConnectReq: %w", err))
w.WriteHeader(http.StatusForbidden)
writeErr(fmt.Errorf("error parsing connectReq: %w", err))
return
}
remote := m.targets[cReq.Host]
if remote == nil {
if debugPrint {
fmt.Printf("%s: handler: unknown host: %v. Have %v\n", m.local, cReq.Host, m.targets)
}
w.WriteHeader(http.StatusForbidden)
writeErr(fmt.Errorf("unknown incoming host: %v", cReq.Host))
return
}
if debugPrint {
fmt.Printf("handler: Got Connect Req %+v\n", cReq)
}
logger.LogIf(ctx, remote.handleIncoming(ctx, conn, cReq))
writeErr(remote.handleIncoming(ctx, conn, cReq))
}
}