mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-09 13:39:39 -05:00
fix: return valid AuthUrl in followup request on expired reg id
- tailscale client gets a new AuthUrl and sets entry in the regcache - regcache entry expires - client doesn't know about that - client always polls followup request а gets error When user clicks "Login" in the app (after cache expiry), they visit invalid URL and get "node not found in registration cache". Some clients on Windows for e.g. can't get a new AuthUrl without restart the app. To fix that we can issue a new reg id and return user a new valid AuthUrl. RegisterNode is refactored to be created with NewRegisterNode() to autocreate channel and other stuff.
This commit is contained in:
@@ -40,7 +40,7 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
if regReq.Followup != "" {
|
||||
return h.waitForFollowup(ctx, regReq)
|
||||
return h.waitForFollowup(ctx, regReq, machineKey)
|
||||
}
|
||||
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
@@ -142,6 +142,7 @@ func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
func (h *Headscale) waitForFollowup(
|
||||
ctx context.Context,
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
fu, err := url.Parse(regReq.Followup)
|
||||
if err != nil {
|
||||
@@ -159,13 +160,49 @@ func (h *Headscale) waitForFollowup(
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
||||
case node := <-reg.Registered:
|
||||
if node == nil {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node not found", nil)
|
||||
// registration is expired in the cache, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||
}
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
|
||||
// if the follow-up registration isn't found anymore, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(regReq, machineKey)
|
||||
}
|
||||
|
||||
// reqToNewRegisterResponse refreshes the registration flow by creating a new
|
||||
// registration ID and returning the corresponding AuthURL so the client can
|
||||
// restart the authentication process.
|
||||
func (h *Headscale) reqToNewRegisterResponse(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
newRegID, err := types.NewRegistrationID()
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||
}
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
types.Node{
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: regReq.NodeKey,
|
||||
Hostinfo: regReq.Hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||
}
|
||||
|
||||
log.Info().Msgf("New followup node registration using key: %s", newRegID)
|
||||
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
AuthURL: h.authProvider.AuthURL(newRegID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegisterWithAuthKey(
|
||||
@@ -244,16 +281,15 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||
}
|
||||
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
types.Node{
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: regReq.NodeKey,
|
||||
Hostinfo: regReq.Hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
},
|
||||
Registered: make(chan *types.Node),
|
||||
}
|
||||
)
|
||||
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Node.Expiry = ®Req.Expiry
|
||||
|
||||
@@ -749,8 +749,8 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newNode := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
newNode := types.NewRegisterNode(
|
||||
types.Node{
|
||||
NodeKey: key.NewNode().Public(),
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
Hostname: request.GetName(),
|
||||
@@ -761,8 +761,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
|
||||
Hostinfo: &hostinfo,
|
||||
},
|
||||
Registered: make(chan *types.Node),
|
||||
}
|
||||
)
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
|
||||
@@ -331,6 +331,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
verb := "Reauthenticated"
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
||||
|
||||
return
|
||||
}
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -89,6 +89,12 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
cacheCleanup,
|
||||
)
|
||||
|
||||
registrationCache.OnEvicted(
|
||||
func(id types.RegistrationID, rn types.RegisterNode) {
|
||||
rn.SendAndClose(nil)
|
||||
},
|
||||
)
|
||||
|
||||
db, err := hsdb.NewHeadscaleDatabase(
|
||||
cfg.Database,
|
||||
cfg.BaseDomain,
|
||||
@@ -1248,16 +1254,12 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
s.nodeStore.PutNode(*savedNode)
|
||||
}
|
||||
|
||||
// Signal to waiting clients
|
||||
regEntry.SendAndClose(savedNode)
|
||||
|
||||
// Delete from registration cache
|
||||
s.registrationCache.Delete(registrationID)
|
||||
|
||||
// Signal to waiting clients
|
||||
select {
|
||||
case regEntry.Registered <- savedNode:
|
||||
default:
|
||||
}
|
||||
close(regEntry.Registered)
|
||||
|
||||
// Update policy manager
|
||||
nodesChange, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@@ -186,6 +187,28 @@ func (r RegistrationID) String() string {
|
||||
type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
closed *atomic.Bool
|
||||
}
|
||||
|
||||
func NewRegisterNode(node Node) RegisterNode {
|
||||
return RegisterNode{
|
||||
Node: node,
|
||||
Registered: make(chan *Node),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rn *RegisterNode) SendAndClose(node *Node) {
|
||||
if rn.closed.Swap(true) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case rn.Registered <- node:
|
||||
default:
|
||||
}
|
||||
|
||||
close(rn.Registered)
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||
|
||||
Reference in New Issue
Block a user