improve testing of route failover logic
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
bf4fd078fc
commit
1704977e76
|
@ -503,7 +503,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
return router
|
||||
}
|
||||
|
||||
// Serve launches a GIN server with the Headscale API.
|
||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
func (h *Headscale) Serve() error {
|
||||
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
|
||||
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
|
||||
|
@ -532,7 +532,7 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
region, err := h.DERPServer.GenerateRegion()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("generating DERP region for embedded server: %w", err)
|
||||
}
|
||||
|
||||
if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
|
@ -607,14 +607,14 @@ func (h *Headscale) Serve() error {
|
|||
}...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("setting up gRPC gateway via socket: %w", err)
|
||||
}
|
||||
|
||||
// Connect to the gRPC server over localhost to skip
|
||||
// the authentication.
|
||||
err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("registering Headscale API service to gRPC: %w", err)
|
||||
}
|
||||
|
||||
// Start the local gRPC server without TLS and without authentication
|
||||
|
@ -635,9 +635,7 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
tlsConfig, err := h.getTLSSettings()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to set up TLS configuration")
|
||||
|
||||
return err
|
||||
return fmt.Errorf("configuring TLS settings: %w", err)
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -702,15 +700,11 @@ func (h *Headscale) Serve() error {
|
|||
httpServer := &http.Server{
|
||||
Addr: h.cfg.Addr,
|
||||
Handler: router,
|
||||
ReadTimeout: types.HTTPReadTimeout,
|
||||
// Go does not handle timeouts in HTTP very well, and there is
|
||||
// no good way to handle streaming timeouts, therefore we need to
|
||||
// keep this at unlimited and be careful to clean up connections
|
||||
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
|
||||
// TODO(kradalby): this timeout can now be set per handler with http.ResponseController:
|
||||
// https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type
|
||||
// replace this so only the longpoller has no timeout.
|
||||
WriteTimeout: 0,
|
||||
ReadTimeout: types.HTTPTimeout,
|
||||
|
||||
// Long polling should not have any timeout, this is overriden
|
||||
// further down the chain
|
||||
WriteTimeout: types.HTTPTimeout,
|
||||
}
|
||||
|
||||
var httpListener net.Listener
|
||||
|
@ -729,27 +723,46 @@ func (h *Headscale) Serve() error {
|
|||
log.Info().
|
||||
Msgf("listening and serving HTTP on: %s", h.cfg.Addr)
|
||||
|
||||
promMux := http.NewServeMux()
|
||||
promMux.Handle("/metrics", promhttp.Handler())
|
||||
debugMux := http.NewServeMux()
|
||||
debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.nodeNotifier.String()))
|
||||
|
||||
promHTTPServer := &http.Server{
|
||||
return
|
||||
})
|
||||
debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) {
|
||||
h.mapSessionMu.Lock()
|
||||
defer h.mapSessionMu.Unlock()
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("mapresponders:\n")
|
||||
for k, v := range h.mapSessions {
|
||||
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(b.String()))
|
||||
|
||||
return
|
||||
})
|
||||
debugMux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
debugHTTPServer := &http.Server{
|
||||
Addr: h.cfg.MetricsAddr,
|
||||
Handler: promMux,
|
||||
ReadTimeout: types.HTTPReadTimeout,
|
||||
Handler: debugMux,
|
||||
ReadTimeout: types.HTTPTimeout,
|
||||
WriteTimeout: 0,
|
||||
}
|
||||
|
||||
var promHTTPListener net.Listener
|
||||
promHTTPListener, err = net.Listen("tcp", h.cfg.MetricsAddr)
|
||||
|
||||
debugHTTPListener, err := net.Listen("tcp", h.cfg.MetricsAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
||||
}
|
||||
|
||||
errorGroup.Go(func() error { return promHTTPServer.Serve(promHTTPListener) })
|
||||
errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })
|
||||
|
||||
log.Info().
|
||||
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr)
|
||||
Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)
|
||||
|
||||
var tailsqlContext context.Context
|
||||
if tailsqlEnabled {
|
||||
|
@ -815,7 +828,7 @@ func (h *Headscale) Serve() error {
|
|||
context.Background(),
|
||||
types.HTTPShutdownTimeout,
|
||||
)
|
||||
if err := promHTTPServer.Shutdown(ctx); err != nil {
|
||||
if err := debugHTTPServer.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
|
||||
}
|
||||
if err := httpServer.Shutdown(ctx); err != nil {
|
||||
|
@ -833,7 +846,7 @@ func (h *Headscale) Serve() error {
|
|||
}
|
||||
|
||||
// Close network listeners
|
||||
promHTTPListener.Close()
|
||||
debugHTTPListener.Close()
|
||||
httpListener.Close()
|
||||
grpcGatewayConn.Close()
|
||||
|
||||
|
@ -898,7 +911,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
server := &http.Server{
|
||||
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
||||
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
||||
ReadTimeout: types.HTTPReadTimeout,
|
||||
ReadTimeout: types.HTTPTimeout,
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
|
|
@ -4,11 +4,13 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
@ -402,11 +404,10 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
|||
return sendUpdate, nil
|
||||
}
|
||||
|
||||
// FailoverRouteIfAvailable takes a node and checks if the node's route
|
||||
// currently have a functioning host that exposes the network.
|
||||
// If it does not, it is failed over to another suitable route if there
|
||||
// is one.
|
||||
func FailoverRouteIfAvailable(
|
||||
// FailoverNodeRoutesIfNeccessary takes a node and checks if the node's route
|
||||
// need to be failed over to another host.
|
||||
// If needed, the failover will be attempted.
|
||||
func FailoverNodeRoutesIfNeccessary(
|
||||
tx *gorm.DB,
|
||||
isConnected types.NodeConnectedMap,
|
||||
node *types.Node,
|
||||
|
@ -416,8 +417,12 @@ func FailoverRouteIfAvailable(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
var changedNodes []types.NodeID
|
||||
log.Trace().Msgf("NODE ROUTES: %d", len(nodeRoutes))
|
||||
changedNodes := make(set.Set[types.NodeID])
|
||||
|
||||
nodeRouteLoop:
|
||||
for _, nodeRoute := range nodeRoutes {
|
||||
log.Trace().Msgf("NODE ROUTE: %d", nodeRoute.ID)
|
||||
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting routes by prefix: %w", err)
|
||||
|
@ -427,29 +432,37 @@ func FailoverRouteIfAvailable(
|
|||
if route.IsPrimary {
|
||||
// if we have a primary route, and the node is connected
|
||||
// nothing needs to be done.
|
||||
if isConnected[route.Node.ID] {
|
||||
return nil, nil
|
||||
if conn, ok := isConnected[route.Node.ID]; conn && ok {
|
||||
continue nodeRouteLoop
|
||||
}
|
||||
|
||||
// if not, we need to failover the route
|
||||
failover := failoverRoute(isConnected, &route, routes)
|
||||
if failover != nil {
|
||||
failover.save(tx)
|
||||
err := failover.save(tx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving failover routes: %w", err)
|
||||
}
|
||||
|
||||
changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID)
|
||||
changedNodes.Add(failover.old.Node.ID)
|
||||
changedNodes.Add(failover.new.Node.ID)
|
||||
|
||||
continue nodeRouteLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chng := changedNodes.Slice()
|
||||
sort.SliceStable(chng, func(i, j int) bool {
|
||||
return chng[i] < chng[j]
|
||||
})
|
||||
|
||||
if len(changedNodes) != 0 {
|
||||
return &types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: changedNodes,
|
||||
Message: "called from db.FailoverRouteIfAvailable",
|
||||
ChangeNodes: chng,
|
||||
Message: "called from db.FailoverNodeRoutesIfNeccessary",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -7,9 +7,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -270,6 +270,370 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||
}
|
||||
|
||||
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
|
||||
var n = func(nid types.NodeID) types.Node {
|
||||
return types.Node{ID: nid}
|
||||
}
|
||||
var np = func(nid types.NodeID) *types.Node {
|
||||
no := n(nid)
|
||||
return &no
|
||||
}
|
||||
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
|
||||
return types.Route{
|
||||
Model: gorm.Model{
|
||||
ID: id,
|
||||
},
|
||||
Node: n(nid),
|
||||
Prefix: prefix,
|
||||
Enabled: enabled,
|
||||
IsPrimary: primary,
|
||||
}
|
||||
}
|
||||
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
|
||||
ro := r(id, nid, prefix, enabled, primary)
|
||||
return &ro
|
||||
}
|
||||
|
||||
func dbForTest(t *testing.T, testName string) *HSDatabase {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", testName)
|
||||
if err != nil {
|
||||
t.Fatalf("creating tempdir: %s", err)
|
||||
}
|
||||
|
||||
dbPath := tmpDir + "/headscale_test.db"
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: dbPath,
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("setting up database: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("database set up at: %s", dbPath)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||
su := func(nids ...types.NodeID) *types.StateUpdate {
|
||||
return &types.StateUpdate{
|
||||
ChangeNodes: nids,
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
nodes types.Nodes
|
||||
routes types.Routes
|
||||
isConnected []types.NodeConnectedMap
|
||||
want []*types.StateUpdate
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "n1-down-n2-down-n1-up",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(2),
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
},
|
||||
// n2 goes down
|
||||
{
|
||||
1: false,
|
||||
2: false,
|
||||
},
|
||||
// n1 comes up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
// route changes from 1 -> 2
|
||||
su(1, 2),
|
||||
// both down, no change
|
||||
nil,
|
||||
// route changes from 2 -> 1
|
||||
su(1, 2),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n1-recon-n2-down-n1-recon-n2-up",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(2),
|
||||
np(1),
|
||||
np(2),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 up recon = noop
|
||||
{
|
||||
1: true,
|
||||
2: true,
|
||||
},
|
||||
// n2 goes down
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
// n1 up recon = noop
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
// n2 comes back up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n1-recon-n2-down-n1-recon-n2-up",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(1),
|
||||
np(3),
|
||||
np(3),
|
||||
np(2),
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
// n1 comes up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
// n3 goes down
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: false,
|
||||
},
|
||||
// n3 comes up
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
// n2 comes up
|
||||
{
|
||||
1: true,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 3), // n1 -> n3
|
||||
nil,
|
||||
su(1, 3), // n3 -> n1
|
||||
nil,
|
||||
nil,
|
||||
su(1, 2), // n1 -> n2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n1-recon-n2-dis-n3-take",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
np(3),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), false, false),
|
||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
// n3 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 3), // n1 -> n3
|
||||
nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n1-oneforeach-n2-n3",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2, 3), // n1 -> n2,n3
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n1-onefor-n2-disabled-n3",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.1.0.0/24"), false, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2), // n1 -> n2, n3 is not enabled
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n1-onefor-n2-offline-n3",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: false,
|
||||
2: true,
|
||||
3: false,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2), // n1 -> n2, n3 is offline
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-n2-back-to-multi-n1",
|
||||
nodes: types.Nodes{
|
||||
np(1),
|
||||
},
|
||||
routes: types.Routes{
|
||||
r(1, 1, ipp("10.0.0.0/24"), true, false),
|
||||
r(4, 1, ipp("10.1.0.0/24"), true, true),
|
||||
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||
},
|
||||
isConnected: []types.NodeConnectedMap{
|
||||
// n1 goes down
|
||||
{
|
||||
1: true,
|
||||
2: false,
|
||||
3: true,
|
||||
},
|
||||
},
|
||||
want: []*types.StateUpdate{
|
||||
su(1, 2), // n2 -> n1
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if (len(tt.isConnected) != len(tt.want)) && len(tt.want) != len(tt.nodes) {
|
||||
t.Fatalf("nodes (%d), isConnected updates (%d), wants (%d) must be equal", len(tt.nodes), len(tt.isConnected), len(tt.want))
|
||||
}
|
||||
|
||||
db := dbForTest(t, tt.name)
|
||||
|
||||
for _, route := range tt.routes {
|
||||
if err := db.DB.Save(&route).Error; err != nil {
|
||||
t.Fatalf("failed to create route: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
for step := range len(tt.isConnected) {
|
||||
node := tt.nodes[step]
|
||||
isConnected := tt.isConnected[step]
|
||||
want := tt.want[step]
|
||||
|
||||
got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return FailoverNodeRoutesIfNeccessary(tx, isConnected, node)
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(types.StateUpdate{}, "Type", "Message")); diff != "" {
|
||||
t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverRouteTx(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
@ -637,19 +1001,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "failover-db-test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
types.DatabaseConfig{
|
||||
Type: "sqlite3",
|
||||
Sqlite: types.SqliteConfig{
|
||||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
"",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
db := dbForTest(t, tt.name)
|
||||
|
||||
for _, route := range tt.routes {
|
||||
if err := db.DB.Save(&route).Error; err != nil {
|
||||
|
|
|
@ -31,7 +31,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
|
|||
}
|
||||
|
||||
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), types.HTTPReadTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), types.HTTPTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
|
||||
|
@ -40,7 +40,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
|||
}
|
||||
|
||||
client := http.Client{
|
||||
Timeout: types.HTTPReadTimeout,
|
||||
Timeout: types.HTTPTimeout,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
|
|
@ -3,7 +3,6 @@ package hscontrol
|
|||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
|
@ -12,7 +11,6 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/control/controlhttp"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -103,12 +101,12 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)
|
||||
|
||||
server := http.Server{
|
||||
ReadTimeout: types.HTTPReadTimeout,
|
||||
ReadTimeout: types.HTTPTimeout,
|
||||
}
|
||||
|
||||
noiseServer.httpBaseConfig = &http.Server{
|
||||
Handler: router,
|
||||
ReadHeaderTimeout: types.HTTPReadTimeout,
|
||||
ReadHeaderTimeout: types.HTTPTimeout,
|
||||
}
|
||||
noiseServer.http2Server = &http2.Server{}
|
||||
|
||||
|
@ -225,15 +223,6 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||
key.NodePublic{},
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Warn().
|
||||
Str("handler", "NoisePollNetMap").
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
|
||||
http.Error(writer, "Internal error", http.StatusNotFound)
|
||||
|
||||
return
|
||||
}
|
||||
log.Error().
|
||||
Str("handler", "NoisePollNetMap").
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
|
@ -242,58 +231,59 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||
|
||||
return
|
||||
}
|
||||
log.Debug().
|
||||
Str("handler", "NoisePollNetMap").
|
||||
Str("node", node.Hostname).
|
||||
Int("cap_ver", int(mapRequest.Version)).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Msg("A node sending a MapRequest with Noise protocol")
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
|
||||
|
||||
session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
|
||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
||||
|
||||
// If a streaming mapSession exists for this node, close it
|
||||
// and start a new one.
|
||||
if session.isStreaming() {
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Int("cap_ver", int(mapRequest.Version)).
|
||||
Msg("Aquiring lock to check stream")
|
||||
if sess.isStreaming() {
|
||||
sess.tracef("aquiring lock to check stream")
|
||||
|
||||
ns.headscale.mapSessionMu.Lock()
|
||||
if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok {
|
||||
log.Info().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Msg("Node has an open streaming session, replacing")
|
||||
oldSession.close()
|
||||
if _, ok := ns.headscale.mapSessions[node.ID]; ok {
|
||||
// NOTE/TODO(kradalby): From how I understand the protocol, when
|
||||
// a client connects with stream=true, and already has a streaming
|
||||
// connection open, the correct way is to close the current channel
|
||||
// and replace it. However, I cannot manage to get that working with
|
||||
// some sort of lock/block happening on the cancelCh in the streaming
|
||||
// session.
|
||||
// Not closing the channel and replacing it puts us in a weird state
|
||||
// which keeps a ghost stream open, receiving keep alives, but no updates.
|
||||
//
|
||||
// Typically a new connection is opened when one exists as a client which
|
||||
// is already authenticated reconnects (e.g. down, then up). The client will
|
||||
// start auth and streaming at the same time, and then cancel the streaming
|
||||
// when the auth has finished successfully, opening a new connection.
|
||||
//
|
||||
// As a work-around to not replacing, abusing the clients "resilience"
|
||||
// by reject the new connection which will cause the client to immediately
|
||||
// reconnect and "fix" the issue, as the other connection typically has been
|
||||
// closed, meaning there is nothing to replace.
|
||||
//
|
||||
// sess.infof("node has an open stream(%p), replacing with %p", oldSession, sess)
|
||||
// oldSession.close()
|
||||
|
||||
defer ns.headscale.mapSessionMu.Unlock()
|
||||
|
||||
sess.infof("node has an open stream(%p), rejecting new stream", sess)
|
||||
return
|
||||
}
|
||||
|
||||
ns.headscale.mapSessions[node.ID] = session
|
||||
ns.headscale.mapSessions[node.ID] = sess
|
||||
ns.headscale.mapSessionMu.Unlock()
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Int("cap_ver", int(mapRequest.Version)).
|
||||
Msg("Releasing lock to check stream")
|
||||
sess.tracef("releasing lock to check stream")
|
||||
}
|
||||
|
||||
session.serve()
|
||||
sess.serve()
|
||||
|
||||
if session.isStreaming() {
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Int("cap_ver", int(mapRequest.Version)).
|
||||
Msg("Aquiring lock to remove stream")
|
||||
if sess.isStreaming() {
|
||||
sess.tracef("aquiring lock to remove stream")
|
||||
ns.headscale.mapSessionMu.Lock()
|
||||
defer ns.headscale.mapSessionMu.Unlock()
|
||||
|
||||
delete(ns.headscale.mapSessions, node.ID)
|
||||
|
||||
ns.headscale.mapSessionMu.Unlock()
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
Int("cap_ver", int(mapRequest.Version)).
|
||||
Msg("Releasing lock to remove stream")
|
||||
sess.tracef("releasing lock to remove stream")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,11 +172,19 @@ func (n *Notifier) String() string {
|
|||
n.l.RLock()
|
||||
defer n.l.RUnlock()
|
||||
|
||||
str := []string{"Notifier, in map:\n"}
|
||||
var b strings.Builder
|
||||
b.WriteString("chans:\n")
|
||||
|
||||
for k, v := range n.nodes {
|
||||
str = append(str, fmt.Sprintf("\t%d: %v\n", k, v))
|
||||
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
|
||||
}
|
||||
|
||||
return strings.Join(str, "")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("connected:\n")
|
||||
|
||||
for k, v := range n.connected {
|
||||
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
|
|
@ -48,6 +48,8 @@ type mapSession struct {
|
|||
ch chan types.StateUpdate
|
||||
cancelCh chan struct{}
|
||||
|
||||
keepAliveTicker *time.Ticker
|
||||
|
||||
node *types.Node
|
||||
w http.ResponseWriter
|
||||
|
||||
|
@ -85,6 +87,8 @@ func (h *Headscale) newMapSession(
|
|||
ch: updateChan,
|
||||
cancelCh: make(chan struct{}),
|
||||
|
||||
keepAliveTicker: time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)),
|
||||
|
||||
// Loggers
|
||||
warnf: warnf,
|
||||
infof: infof,
|
||||
|
@ -100,10 +104,9 @@ func (m *mapSession) close() {
|
|||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case m.cancelCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
m.tracef("mapSession (%p) sending message on cancel chan")
|
||||
m.cancelCh <- struct{}{}
|
||||
m.tracef("mapSession (%p) sent message on cancel chan")
|
||||
}
|
||||
|
||||
func (m *mapSession) isStreaming() bool {
|
||||
|
@ -118,13 +121,6 @@ func (m *mapSession) isReadOnlyUpdate() bool {
|
|||
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
|
||||
}
|
||||
|
||||
func (m *mapSession) flush200() {
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
if f, ok := m.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// handlePoll ensures the node gets the appropriate updates from either
|
||||
// polling or immediate responses.
|
||||
//
|
||||
|
@ -211,7 +207,12 @@ func (m *mapSession) serve() {
|
|||
|
||||
m.pollFailoverRoutes("node connected", m.node)
|
||||
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond))
|
||||
// Upgrade the writer to a ResponseController
|
||||
rc := http.NewResponseController(m.w)
|
||||
|
||||
// Longpolling will break if there is a write timeout,
|
||||
// so it needs to be disabled.
|
||||
rc.SetWriteDeadline(time.Time{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
|
||||
defer cancel()
|
||||
|
@ -324,18 +325,16 @@ func (m *mapSession) serve() {
|
|||
startWrite := time.Now()
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
m.errf(err, "Could not write the map response, for mapSession: %p, stream: %t", m, m.isStreaming())
|
||||
|
||||
m.errf(err, "Could not write the map response, for mapSession: %p", m)
|
||||
return
|
||||
}
|
||||
|
||||
if flusher, ok := m.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
log.Error().Msg("Failed to create http flusher")
|
||||
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
||||
|
||||
m.infof("update sent")
|
||||
|
@ -402,7 +401,7 @@ func (m *mapSession) serve() {
|
|||
derp = true
|
||||
}
|
||||
|
||||
case <-keepAliveTicker.C:
|
||||
case <-m.keepAliveTicker.C:
|
||||
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Error generating the keep alive msg")
|
||||
|
@ -415,11 +414,9 @@ func (m *mapSession) serve() {
|
|||
|
||||
return
|
||||
}
|
||||
if flusher, ok := m.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
log.Error().Msg("Failed to create http flusher")
|
||||
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -428,7 +425,7 @@ func (m *mapSession) serve() {
|
|||
|
||||
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
|
||||
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||
return db.FailoverRouteIfAvailable(tx, m.h.nodeNotifier.ConnectedMap(), node)
|
||||
return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.ConnectedMap(), node)
|
||||
})
|
||||
if err != nil {
|
||||
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
|
||||
|
@ -565,7 +562,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||
},
|
||||
m.node.ID)
|
||||
|
||||
m.flush200()
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -654,7 +651,9 @@ func (m *mapSession) handleReadOnlyRequest() {
|
|||
m.errf(err, "Failed to write response")
|
||||
}
|
||||
|
||||
m.flush200()
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
|
||||
|
|
|
@ -3,7 +3,7 @@ package types
|
|||
import "time"
|
||||
|
||||
const (
|
||||
HTTPReadTimeout = 30 * time.Second
|
||||
HTTPTimeout = 30 * time.Second
|
||||
HTTPShutdownTimeout = 3 * time.Second
|
||||
TLSALPN01ChallengeType = "TLS-ALPN-01"
|
||||
HTTP01ChallengeType = "HTTP-01"
|
||||
|
|
|
@ -124,7 +124,7 @@ func DefaultConfigEnv() map[string]string {
|
|||
"HEADSCALE_PRIVATE_KEY_PATH": "/tmp/private.key",
|
||||
"HEADSCALE_NOISE_PRIVATE_KEY_PATH": "/tmp/noise_private.key",
|
||||
"HEADSCALE_LISTEN_ADDR": "0.0.0.0:8080",
|
||||
"HEADSCALE_METRICS_LISTEN_ADDR": "127.0.0.1:9090",
|
||||
"HEADSCALE_METRICS_LISTEN_ADDR": "0.0.0.0:9090",
|
||||
"HEADSCALE_SERVER_URL": "http://headscale:8080",
|
||||
"HEADSCALE_DERP_URLS": "https://controlplane.tailscale.com/derpmap/default",
|
||||
"HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "false",
|
||||
|
|
|
@ -260,7 +260,7 @@ func New(
|
|||
|
||||
runOptions := &dockertest.RunOptions{
|
||||
Name: hsic.hostname,
|
||||
ExposedPorts: append([]string{portProto}, hsic.extraPorts...),
|
||||
ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...),
|
||||
Networks: []*dockertest.Network{network},
|
||||
// Cmd: []string{"headscale", "serve"},
|
||||
// TODO(kradalby): Get rid of this hack, we currently need to give us some
|
||||
|
|
|
@ -252,7 +252,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario()
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
// defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
user: 3,
|
||||
|
|
Loading…
Reference in New Issue