Merge pull request #668 from GrigoriyMikhalkin/graceful-shutdown

graceful shutdown fix
This commit is contained in:
Juan Font 2022-07-22 09:35:40 +02:00 committed by GitHub
commit 581d1f3bfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 19 deletions

45
app.go
View File

@ -94,7 +94,8 @@ type Headscale struct {
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
shutdownChan chan struct{} shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup
} }
// Look up the TLS constant relative to user-supplied TLS client // Look up the TLS constant relative to user-supplied TLS client
@ -147,12 +148,13 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
) )
app := Headscale{ app := Headscale{
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
privateKey: privKey, privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache, registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
} }
err = app.initDB() err = app.initDB()
@ -565,6 +567,8 @@ func (h *Headscale) Serve() error {
// https://github.com/soheilhy/cmux/issues/68 // https://github.com/soheilhy/cmux/issues/68
// https://github.com/soheilhy/cmux/issues/91 // https://github.com/soheilhy/cmux/issues/91
var grpcServer *grpc.Server
var grpcListener net.Listener
if tlsConfig != nil || h.cfg.GRPCAllowInsecure { if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr) log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
@ -585,12 +589,12 @@ func (h *Headscale) Serve() error {
log.Warn().Msg("gRPC is running without security") log.Warn().Msg("gRPC is running without security")
} }
grpcServer := grpc.NewServer(grpcOptions...) grpcServer = grpc.NewServer(grpcOptions...)
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
reflection.Register(grpcServer) reflection.Register(grpcServer)
grpcListener, err := net.Listen("tcp", h.cfg.GRPCAddr) grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err) return fmt.Errorf("failed to bind to TCP address: %w", err)
} }
@ -666,7 +670,7 @@ func (h *Headscale) Serve() error {
syscall.SIGTERM, syscall.SIGTERM,
syscall.SIGQUIT, syscall.SIGQUIT,
syscall.SIGHUP) syscall.SIGHUP)
go func(c chan os.Signal) { sigFunc := func(c chan os.Signal) {
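// Block until one of the registered signals arrives (SIGKILL cannot be caught, so it is never seen here): // Block until one of the registered signals arrives (SIGKILL cannot be caught, so it is never seen here):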
for { for {
sig := <-c sig := <-c
@ -676,7 +680,7 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config") Msg("Received SIGHUP, reloading ACL and Config")
// TODO(kradalby): Reload config on SIGHUP // TODO(kradalby): Reload config on SIGHUP
if h.cfg.ACL.PolicyPath != "" { if h.cfg.ACL.PolicyPath != "" {
aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
@ -696,7 +700,8 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully") Msg("Received signal to stop, shutting down gracefully")
h.shutdownChan <- struct{}{} close(h.shutdownChan)
h.pollNetMapStreamWG.Wait()
// Gracefully shut down servers // Gracefully shut down servers
ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout)
@ -708,6 +713,11 @@ func (h *Headscale) Serve() error {
} }
grpcSocket.GracefulStop() grpcSocket.GracefulStop()
if grpcServer != nil {
grpcServer.GracefulStop()
grpcListener.Close()
}
// Close network listeners // Close network listeners
promHTTPListener.Close() promHTTPListener.Close()
httpListener.Close() httpListener.Close()
@ -734,7 +744,12 @@ func (h *Headscale) Serve() error {
os.Exit(0) os.Exit(0)
} }
} }
}(sigc) }
errorGroup.Go(func() error {
sigFunc(sigc)
return nil
})
return errorGroup.Wait() return errorGroup.Wait()
} }
@ -758,13 +773,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
} }
switch h.cfg.TLS.LetsEncrypt.ChallengeType { switch h.cfg.TLS.LetsEncrypt.ChallengeType {
case "TLS-ALPN-01": case tlsALPN01ChallengeType:
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale // The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443. // must be reachable on port 443.
return certManager.TLSConfig(), nil return certManager.TLSConfig(), nil
case "HTTP-01": case http01ChallengeType:
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.

View File

@ -18,6 +18,11 @@ import (
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
) )
// Supported values for the tls_letsencrypt_challenge_type setting.
// LoadConfig rejects any other value as a fatal config error, and
// getTLSSettings switches on these to pick the autocert setup.
const (
// tlsALPN01ChallengeType selects the ACME TLS-ALPN-01 challenge (RFC 8737);
// headscale must be reachable on port 443 for validation to succeed.
tlsALPN01ChallengeType = "TLS-ALPN-01"
// http01ChallengeType selects the ACME HTTP-01 challenge, which needs
// port 80 for certificate validation; it is the configured default.
http01ChallengeType = "HTTP-01"
)
// Config contains the initial Headscale configuration. // Config contains the initial Headscale configuration.
type Config struct { type Config struct {
ServerURL string ServerURL string
@ -136,7 +141,7 @@ func LoadConfig(path string, isFile bool) error {
viper.AutomaticEnv() viper.AutomaticEnv()
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") viper.SetDefault("tls_letsencrypt_challenge_type", http01ChallengeType)
viper.SetDefault("tls_client_auth_mode", "relaxed") viper.SetDefault("tls_client_auth_mode", "relaxed")
viper.SetDefault("log_level", "info") viper.SetDefault("log_level", "info")
@ -179,15 +184,15 @@ func LoadConfig(path string, isFile bool) error {
} }
if (viper.GetString("tls_letsencrypt_hostname") != "") && if (viper.GetString("tls_letsencrypt_hostname") != "") &&
(viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (viper.GetString("tls_letsencrypt_challenge_type") == tlsALPN01ChallengeType) &&
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
log.Warn(). log.Warn().
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
} }
if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && if (viper.GetString("tls_letsencrypt_challenge_type") != http01ChallengeType) &&
(viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { (viper.GetString("tls_letsencrypt_challenge_type") != tlsALPN01ChallengeType) {
errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
} }

View File

@ -290,6 +290,9 @@ func (h *Headscale) PollNetMapStream(
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
) { ) {
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)