diff --git a/README.md b/README.md index 060824b9..59997d8c 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ headscale implements this coordination server. - [x] Taildrop (File Sharing) - [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10) - [x] DNS (passing DNS servers to nodes) +- [x] Single-Sign-On (via Open ID Connect) - [x] Share nodes between namespaces - [x] MagicDNS (see `docs/`) @@ -49,7 +50,6 @@ headscale implements this coordination server. Suggestions/PRs welcomed! - ## Running headscale Please have a look at the documentation under [`docs/`](docs/). diff --git a/api.go b/api.go index 0c1f3d41..ad87a7e9 100644 --- a/api.go +++ b/api.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "github.com/rs/zerolog/log" @@ -64,7 +65,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Cannot parse machine key") - machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() c.String(http.StatusInternalServerError, "Sad!") return } @@ -75,37 +76,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Cannot decode message") - machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() c.String(http.StatusInternalServerError, "Very sad!") return } now := time.Now().UTC() - var m Machine - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { + m, err := h.GetMachineByMachineKey(mKey.HexString()) + if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") - m = Machine{ - Expiry: &req.Expiry, - MachineKey: mKey.HexString(), - Name: req.Hostinfo.Hostname, - NodeKey: wgkey.Key(req.NodeKey).HexString(), - LastSuccessfulUpdate: &now, + newMachine := Machine{ + Expiry: &time.Time{}, + MachineKey: mKey.HexString(), + Name: req.Hostinfo.Hostname, } - if err := h.db.Create(&m).Error; err != nil { + if err := h.db.Create(&newMachine).Error; err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Could not create row") - machineRegistrations.WithLabelValues("unkown", "web", "error", m.Namespace.Name).Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc() return } + m = &newMachine } if !m.Registered && req.Auth.AuthKey != "" { - h.handleAuthKey(c, h.db, mKey, req, m) + h.handleAuthKey(c, h.db, mKey, req, *m) return } @@ -113,7 +110,36 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We have the updated key! if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { - if m.Registered { + + // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) + // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 + if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { + log.Info(). + Str("handler", "Registration"). + Str("machine", m.Name). + Msg("Client requested logout") + + m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired + h.db.Save(&m) + + resp.AuthURL = "" + resp.MachineAuthorized = false + resp.User = *m.Namespace.toUser() + respBody, err := encode(resp, &mKey, h.privateKey) + if err != nil { + log.Error(). + Str("handler", "Registration"). + Err(err). + Msg("Cannot encode message") + c.String(http.StatusInternalServerError, "") + return + } + c.Data(200, "application/json; charset=utf-8", respBody) + return + } + + if m.Registered && m.Expiry.UTC().After(now) { + // The machine registration is valid, respond with redirect to /map log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -122,6 +148,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { resp.AuthURL = "" resp.MachineAuthorized = true resp.User = *m.Namespace.toUser() + resp.Login = *m.Namespace.toLogin() + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -137,12 +165,30 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } + // The client has registered before, but has expired log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). - Msg("Not registered and not NodeKey rotation. Sending a authurl to register") - resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, mKey.HexString()) + Msg("Machine registration has expired. Sending a authurl to register") + + if h.cfg.OIDC.Issuer != "" { + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + } else { + resp.AuthURL = fmt.Sprintf("%s/register?key=%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + } + + // When a client connects, it may request a specific expiry time in its + // RegisterRequest (https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L634) + // RequestedExpiry is used to store the clients requested expiry time since the authentication flow is broken + // into two steps (which cant pass arbitrary data between them easily) and needs to be + // retrieved again after the user has authenticated. After the authentication flow + // completes, RequestedExpiry is copied into Expiry. + m.RequestedExpiry = &req.Expiry + + h.db.Save(&m) + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -158,8 +204,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - // The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration - if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { + // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration + if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -182,35 +228,23 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - // We arrive here after a client is restarted without finalizing the authentication flow or - // when headscale is stopped in the middle of the auth process. - if m.Registered { - log.Debug(). - Str("handler", "Registration"). - Str("machine", m.Name). - Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map") - resp.AuthURL = "" - resp.MachineAuthorized = true - resp.User = *m.Namespace.toUser() - respBody, err := encode(resp, &mKey, h.privateKey) - if err != nil { - log.Error(). - Str("handler", "Registration"). - Err(err). - Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") - return - } - c.Data(200, "application/json; charset=utf-8", respBody) - return - } - + // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, mKey.HexString()) + if h.cfg.OIDC.Issuer != "" { + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + } else { + resp.AuthURL = fmt.Sprintf("%s/register?key=%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + } + + // save the requested expiry time for retrieval later in the authentication flow + m.RequestedExpiry = &req.Expiry + m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey + h.db.Save(&m) + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index 50343d3c..683d3a72 100644 --- a/app.go +++ b/app.go @@ -15,6 +15,10 @@ import ( "sync" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "github.com/gin-gonic/gin" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" apiV1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -75,6 +79,18 @@ type Config struct { DNSConfig *tailcfg.DNSConfig UnixSocket string + + OIDC OIDCConfig + + MaxMachineRegistrationDuration time.Duration + DefaultMachineRegistrationDuration time.Duration +} + +type OIDCConfig struct { + Issuer string + ClientID string + ClientSecret string + MatchMap map[string]string } type DERPConfig struct { @@ -100,6 +116,10 @@ type Headscale struct { aclRules *[]tailcfg.FilterRule lastStateChange sync.Map + + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config + oidcStateCache *cache.Cache } // NewHeadscale returns the Headscale app. @@ -140,6 +160,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } + if cfg.OIDC.Issuer != "" { + err = h.initOIDC() + if err != nil { + return nil, err + } + } + if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) if err != nil { @@ -379,6 +406,8 @@ func (h *Headscale) Serve() error { r.GET("/register", h.RegisterWebAPI) r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id", h.RegistrationHandler) + r.GET("/oidc/register/:mkey", h.RegisterOIDC) + r.GET("/oidc/callback", h.OIDCCallback) r.GET("/apple", h.AppleMobileConfig) r.GET("/apple/:platform", h.ApplePlatformConfig) r.GET("/swagger", SwaggerUI) diff --git a/cli.go b/cli.go index 9c5b66e5..8610b334 100644 --- a/cli.go +++ b/cli.go @@ -23,6 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err return nil, errors.New("Machine not found") } + h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered + if m.isAlreadyRegistered() { return nil, errors.New("Machine already registered") } @@ -36,5 +38,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err m.Registered = true m.RegisterMethod = "cli" h.db.Save(&m) + return &m, nil } diff --git a/cli_test.go b/cli_test.go index 528a115e..291b5df1 100644 --- a/cli_test.go +++ b/cli_test.go @@ -1,6 +1,8 @@ package headscale import ( + "time" + "gopkg.in/check.v1" ) @@ -8,14 +10,18 @@ func (s *Suite) TestRegisterMachine(c *check.C) { n, err := h.CreateNamespace("test") c.Assert(err, check.IsNil) + now := time.Now().UTC() + m := Machine{ - ID: 0, - MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", - NodeKey: "bar", - DiscoKey: "faa", - Name: "testmachine", - NamespaceID: n.ID, - IPAddress: "10.0.0.1", + ID: 0, + MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + IPAddress: "10.0.0.1", + Expiry: &now, + RequestedExpiry: &now, } h.db.Save(&m) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 8e044bf5..035f508e 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strconv" "strings" "time" @@ -209,6 +210,26 @@ func absPath(path string) string { } func getHeadscaleConfig() headscale.Config { + // maxMachineRegistrationDuration is the maximum time headscale will allow a client to (optionally) request for + // the machine key expiry time. RegisterRequests with Expiry times that are more than + // maxMachineRegistrationDuration in the future will be clamped to (now + maxMachineRegistrationDuration) + maxMachineRegistrationDuration, _ := time.ParseDuration( + "10h", + ) // use 10h here because it is the length of a standard business day plus a small amount of leeway + if viper.GetDuration("max_machine_registration_duration") >= time.Second { + maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") + } + + // defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not + // specified by the tailscale client. It is the default amount of time a machine registration is valid for + // (ie the amount of time before the user has to re-authenticate when requesting a connection) + defaultMachineRegistrationDuration, _ := time.ParseDuration( + "8h", + ) // use 8h here because it's the length of a standard business day + if viper.GetDuration("default_machine_registration_duration") >= time.Second { + defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") + } + dnsConfig, baseDomain := GetDNSConfig() derpConfig := GetDERPConfig() @@ -245,6 +266,15 @@ func getHeadscaleConfig() headscale.Config { ACMEURL: viper.GetString("acme_url"), UnixSocket: viper.GetString("unix_socket"), + + OIDC: headscale.OIDCConfig{ + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + }, + + MaxMachineRegistrationDuration: maxMachineRegistrationDuration, + DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, } } @@ -263,6 +293,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { cfg := getHeadscaleConfig() + cfg.OIDC.MatchMap = loadOIDCMatchMap() + h, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err @@ -402,3 +434,15 @@ func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[st func (tokenAuth) RequireTransportSecurity() bool { return true } + +// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in +// the match map is valid regex strings. +func loadOIDCMatchMap() map[string]string { + strMap := viper.GetStringMapString("oidc.domain_map") + + for oidcMatcher := range strMap { + _ = regexp.MustCompile(oidcMatcher) + } + + return strMap +} diff --git a/config-example.yaml b/config-example.yaml index 0eaf4c2d..60369306 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -69,3 +69,16 @@ dns_config: # Note: for local development, you probably want to change this to: # unix_socket: ./headscale.sock unix_socket: /var/run/headscale.sock +# headscale supports experimental OpenID connect support, +# it is still being tested and might have some bugs, please +# help us test it. +# OpenID Connect +# oidc: +# issuer: "https://your-oidc.issuer.com/path" +# client_id: "your-oidc-client-id" +# client_secret: "your-oidc-client-secret" +# +# # Domain map is used to map incomming users (by their email) to +# # a namespace. The key can be a string, or regex. +# domain_map: +# ".*": default-namespace diff --git a/go.mod b/go.mod index 002e3ede..296def85 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/Microsoft/go-winio v0.5.0 // indirect github.com/cenkalti/backoff/v4 v4.1.1 // indirect github.com/containerd/continuity v0.1.0 // indirect + github.com/coreos/go-oidc/v3 v3.1.0 github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect github.com/efekarakus/termcolor v1.0.1 @@ -23,6 +24,7 @@ require ( github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 // indirect github.com/opencontainers/runc v1.0.2 // indirect github.com/ory/dockertest/v3 v3.7.0 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.11.0 github.com/pterm/pterm v0.12.30 github.com/rs/zerolog v1.25.0 @@ -36,6 +38,7 @@ require ( github.com/zsais/go-gin-prometheus v0.1.0 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect + golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83 diff --git a/go.sum b/go.sum index 9e31f2fe..96eb50bd 100644 --- a/go.sum +++ b/go.sum @@ -153,6 +153,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw= +github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -767,6 +769,8 @@ github.com/ory/dockertest/v3 v3.7.0 h1:Bijzonc69Ont3OU0a3TWKJ1Rzlh3TsDXP1JrTAkSm github.com/ory/dockertest/v3 v3.7.0/go.mod h1:PvCCgnP7AfBZeVrzwiUTjZx/IUXlGLC1zQlUQrLIlUE= github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pborman/getopt v1.1.0/go.mod h1:FxXoW1Re00sQG/+KIkuSqRL/LwQgSkv7uyac+STFsbk= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= @@ -1137,6 +1141,7 @@ golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= @@ -1174,6 +1179,7 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f h1:Qmd2pbz05z7z6lm0DrgQVVPuBm92jqujBKMHMOlOQEw= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1436,6 +1442,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1553,6 +1560,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= +gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= diff --git a/machine.go b/machine.go index 8986ac92..ccd30e3e 100644 --- a/machine.go +++ b/machine.go @@ -36,6 +36,7 @@ type Machine struct { LastSeen *time.Time LastSuccessfulUpdate *time.Time Expiry *time.Time + RequestedExpiry *time.Time HostInfo datatypes.JSON Endpoints datatypes.JSON @@ -56,6 +57,38 @@ func (m Machine) isAlreadyRegistered() bool { return m.Registered } +// isExpired returns whether the machine registration has expired +func (m Machine) isExpired() bool { + return time.Now().UTC().After(*m.Expiry) +} + +// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, +// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause +// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the +// expiry time. +func (h *Headscale) updateMachineExpiry(m *Machine) { + if m.isExpired() { + now := time.Now().UTC() + maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry + defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry + + // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied + if maxExpiry.Before(*m.RequestedExpiry) { + log.Debug(). + Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) + m.Expiry = &maxExpiry + } else if m.RequestedExpiry.IsZero() { + log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) + m.Expiry = &defaultExpiry + } else { + log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) + m.Expiry = m.RequestedExpiry + } + + h.db.Save(&m) + } +} + func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { log.Trace(). Str("func", "getDirectPeers"). @@ -326,7 +359,11 @@ func (ms MachinesP) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (ms Machines) toNodes(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) ([]*tailcfg.Node, error) { +func (ms Machines) toNodes( + baseDomain string, + dnsConfig *tailcfg.DNSConfig, + includeRoutes bool, +) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(ms)) for index, machine := range ms { @@ -446,8 +483,10 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include } n := tailcfg.Node{ - ID: tailcfg.NodeID(m.ID), // this is the actual ID - StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent + ID: tailcfg.NodeID(m.ID), // this is the actual ID + StableID: tailcfg.StableNodeID( + strconv.FormatUint(m.ID, 10), + ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, User: tailcfg.UserID(m.NamespaceID), Key: tailcfg.NodeKey(nKey), diff --git a/namespaces.go b/namespaces.go index c350e8c8..d7c1e035 100644 --- a/namespaces.go +++ b/namespaces.go @@ -246,6 +246,17 @@ func (n *Namespace) toUser() *tailcfg.User { return &u } +func (n *Namespace) toLogin() *tailcfg.Login { + l := tailcfg.Login{ + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + ProfilePicURL: "", + Domain: "headscale.net", + } + return &l +} + func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile { namespaceMap := make(map[string]Namespace) namespaceMap[m.Namespace.Name] = m.Namespace diff --git a/oidc.go b/oidc.go new file mode 100644 index 00000000..51c443db --- /dev/null +++ b/oidc.go @@ -0,0 +1,228 @@ +package headscale + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/gin-gonic/gin" + "github.com/patrickmn/go-cache" + "github.com/rs/zerolog/log" + "golang.org/x/oauth2" +) + +type IDTokenClaims struct { + Name string `json:"name,omitempty"` + Groups []string `json:"groups,omitempty"` + Email string `json:"email"` + Username string `json:"preferred_username,omitempty"` +} + +func (h *Headscale) initOIDC() error { + var err error + // grab oidc config if it hasn't been already + if h.oauth2Config == nil { + h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) + + if err != nil { + log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) + return err + } + + h.oauth2Config = &oauth2.Config{ + ClientID: h.cfg.OIDC.ClientID, + ClientSecret: h.cfg.OIDC.ClientSecret, + Endpoint: h.oidcProvider.Endpoint(), + RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + } + + // init the state cache if it hasn't been already + if h.oidcStateCache == nil { + h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10) + } + + return nil +} + +// RegisterOIDC redirects to the OIDC provider for authentication +// Puts machine key in cache so the callback can retrieve it using the oidc state param +// Listens in /oidc/register/:mKey +func (h *Headscale) RegisterOIDC(c *gin.Context) { + mKeyStr := c.Param("mkey") + if mKeyStr == "" { + c.String(http.StatusBadRequest, "Wrong params") + return + } + + b := make([]byte, 16) + _, err := rand.Read(b) + if err != nil { + log.Error().Msg("could not read 16 bytes from rand") + c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") + return + } + + stateStr := hex.EncodeToString(b)[:32] + + // place the machine key into the state cache, so it can be retrieved later + h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5) + + authUrl := h.oauth2Config.AuthCodeURL(stateStr) + log.Debug().Msgf("Redirecting to %s for authentication", authUrl) + + c.Redirect(http.StatusFound, authUrl) +} + +// OIDCCallback handles the callback from the OIDC endpoint +// Retrieves the mkey from the state cache and adds the machine to the users email namespace +// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities +// TODO: Add groups information from OIDC tokens into machine HostInfo +// Listens in /oidc/callback +func (h *Headscale) OIDCCallback(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + + if code == "" || state == "" { + c.String(http.StatusBadRequest, "Wrong params") + return + } + + oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) + if err != nil { + c.String(http.StatusBadRequest, "Could not exchange code for token") + return + } + + log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken) + + rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) + if !rawIDTokenOK { + c.String(http.StatusBadRequest, "Could not extract ID Token") + return + } + + verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) + + idToken, err := verifier.Verify(context.Background(), rawIDToken) + if err != nil { + c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + return + } + + // TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc) + //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) + //if err != nil { + // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err)) + // return + //} + + // Extract custom claims + var claims IDTokenClaims + if err = idToken.Claims(&claims); err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err)) + return + } + + // retrieve machinekey from state cache + mKeyIf, mKeyFound := h.oidcStateCache.Get(state) + + if !mKeyFound { + log.Error().Msg("requested machine state key expired before authorisation completed") + c.String(http.StatusBadRequest, "state has expired") + return + } + mKeyStr, mKeyOK := mKeyIf.(string) + + if !mKeyOK { + log.Error().Msg("could not get machine key from cache") + c.String(http.StatusInternalServerError, "could not get machine key from cache") + return + } + + // retrieve machine information + m, err := h.GetMachineByMachineKey(mKeyStr) + if err != nil { + log.Error().Msg("machine key not found in database") + c.String(http.StatusInternalServerError, "could not get machine info from database") + return + } + + now := time.Now().UTC() + + if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { + // register the machine if it's new + if !m.Registered { + + log.Debug().Msg("Registering new machine after successful callback") + + ns, err := h.GetNamespace(nsName) + if err != nil { + ns, err = h.CreateNamespace(nsName) + + if err != nil { + log.Error().Msgf("could not create new namespace '%s'", claims.Email) + c.String(http.StatusInternalServerError, "could not create new namespace") + return + } + } + + ip, err := h.getAvailableIP() + if err != nil { + c.String(http.StatusInternalServerError, "could not get an IP from the pool") + return + } + + m.IPAddress = ip.String() + m.NamespaceID = ns.ID + m.Registered = true + m.RegisterMethod = "oidc" + m.LastSuccessfulUpdate = &now + h.db.Save(&m) + } + + h.updateMachineExpiry(m) + + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + + +

headscale

+

+ Authenticated as %s, you can now close this window. +

+ + + +`, claims.Email))) + + } + + log.Error(). + Str("email", claims.Email). + Str("username", claims.Username). + Str("machine", m.Name). + Msg("Email could not be mapped to a namespace") + c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace") +} + +// getNamespaceFromEmail passes the users email through a list of "matchers" +// and iterates through them until it matches and returns a namespace. +// If no match is found, an empty string will be returned. +// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration. +func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) { + for match, namespace := range h.cfg.OIDC.MatchMap { + regex := regexp.MustCompile(match) + if regex.MatchString(email) { + return namespace, true + } + } + + return "", false +} diff --git a/oidc_test.go b/oidc_test.go new file mode 100644 index 00000000..b501ff14 --- /dev/null +++ b/oidc_test.go @@ -0,0 +1,174 @@ +package headscale + +import ( + "sync" + "testing" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/wgkey" +) + +func TestHeadscale_getNamespaceFromEmail(t *testing.T) { + type fields struct { + cfg Config + db *gorm.DB + dbString string + dbType string + dbDebug bool + publicKey *wgkey.Key + privateKey *wgkey.Private + aclPolicy *ACLPolicy + aclRules *[]tailcfg.FilterRule + lastStateChange sync.Map + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config + oidcStateCache *cache.Cache + } + type args struct { + email string + } + tests := []struct { + name string + fields fields + args args + want string + want1 bool + }{ + { + name: "match all", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*": "space", + }, + }, + }, + }, + args: args{ + email: "test@example.no", + }, + want: "space", + want1: true, + }, + { + name: "match user", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + "specific@user\\.no": "user-namespace", + }, + }, + }, + }, + args: args{ + email: "specific@user.no", + }, + want: "user-namespace", + want1: true, + }, + { + name: "match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@example\\.no": "example", + }, + }, + }, + }, + args: args{ + email: "test@example.no", + }, + want: "example", + want1: true, + }, + { + name: "multi match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@example\\.no": "exammple", + ".*@gmail\\.com": "gmail", + }, + }, + }, + }, + args: args{ + email: "someuser@gmail.com", + }, + want: "gmail", + want1: true, + }, + { + name: "no match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@dontknow.no": "never", + }, + }, + }, + }, + args: args{ + email: "test@wedontknow.no", + }, + want: "", + want1: false, + }, + { + name: "multi no match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@dontknow.no": "never", + ".*@wedontknow.no": "other", + ".*\\.no": "stuffy", + }, + }, + }, + }, + args: args{ + email: "tasy@nonofthem.com", + }, + want: "", + want1: false, + }, + } + //nolint + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Headscale{ + cfg: tt.fields.cfg, + db: tt.fields.db, + dbString: tt.fields.dbString, + dbType: tt.fields.dbType, + dbDebug: tt.fields.dbDebug, + publicKey: tt.fields.publicKey, + privateKey: tt.fields.privateKey, + aclPolicy: tt.fields.aclPolicy, + aclRules: tt.fields.aclRules, + lastStateChange: tt.fields.lastStateChange, + oidcProvider: tt.fields.oidcProvider, + oauth2Config: tt.fields.oauth2Config, + oidcStateCache: tt.fields.oidcStateCache, + } + got, got1 := h.getNamespaceFromEmail(tt.args.email) + if got != tt.want { + t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1) + } + }) + } +}