From e7a2501fe865dc04cb999be4d50332dc4e0f57f0 Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Sun, 26 Sep 2021 16:53:05 +0800 Subject: [PATCH 01/15] initial work on OIDC (SSO) integration --- api.go | 17 +- app.go | 6 + cmd/headscale/cli/utils.go | 4 + go.mod | 3 + go.sum | 11 ++ oidc.go | 310 +++++++++++++++++++++++++++++++++++++ 6 files changed, 347 insertions(+), 4 deletions(-) create mode 100644 oidc.go diff --git a/api.go b/api.go index e2a56185..2c5a1321 100644 --- a/api.go +++ b/api.go @@ -133,8 +133,13 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("Not registered and not NodeKey rotation. Sending a authurl to register") - resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, mKey.HexString()) + + if h.cfg.OIDCEndpoint != "" { + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) + } else { + resp.AuthURL = fmt.Sprintf("%s/register?key=%s", + h.cfg.ServerURL, mKey.HexString()) + } respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -199,8 +204,12 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, mKey.HexString()) + if h.cfg.OIDCEndpoint != "" { + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) + } else { + resp.AuthURL = fmt.Sprintf("%s/register?key=%s", + h.cfg.ServerURL, mKey.HexString()) + } respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index c903d83f..81871a87 100644 --- a/app.go +++ b/app.go @@ -45,6 +45,10 @@ type Config struct { TLSKeyPath string DNSConfig *tailcfg.DNSConfig + + OIDCEndpoint string + OIDCClientID string + OIDCClientSecret string } // Headscale represents the base app of the service @@ -168,6 +172,8 @@ func (h *Headscale) Serve() error { r.GET("/register", h.RegisterWebAPI) r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id", h.RegistrationHandler) + r.GET("/oidc/register/:mKey", h.RegisterOIDC) + r.GET("/oidc/callback", h.OIDCCallback) var err error timeout := 30 * time.Second diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 7ada6693..b7faad57 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -170,6 +170,10 @@ func getHeadscaleApp() (*headscale.Headscale, error) { TLSKeyPath: absPath(viper.GetString("tls_key_path")), DNSConfig: GetDNSConfig(), + + OIDCEndpoint: viper.GetString("oidc_endpoint"), + OIDCClientID: viper.GetString("oidc_client_id"), + OIDCClientSecret: viper.GetString("oidc_client_secret"), } h, err := headscale.NewHeadscale(cfg) diff --git a/go.mod b/go.mod index 8709119b..031460e8 100644 --- a/go.mod +++ b/go.mod @@ -17,8 +17,10 @@ require ( github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 // indirect github.com/opencontainers/runc v1.0.2 // indirect github.com/ory/dockertest/v3 v3.7.0 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pterm/pterm v0.12.30 github.com/rs/zerolog v1.25.0 + github.com/s12v/go-jwks v0.2.1 github.com/spf13/cobra v1.2.1 github.com/spf13/viper v1.8.1 github.com/stretchr/testify v1.7.0 @@ -28,6 +30,7 @@ require ( golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c + gopkg.in/square/go-jose.v2 v2.3.1 gopkg.in/yaml.v2 v2.4.0 gorm.io/datatypes v1.0.2 gorm.io/driver/postgres v1.1.1 diff --git a/go.sum b/go.sum index ac934dbe..195fb21d 100644 --- a/go.sum +++ b/go.sum @@ -711,6 +711,8 @@ github.com/ory/dockertest/v3 v3.7.0 h1:Bijzonc69Ont3OU0a3TWKJ1Rzlh3TsDXP1JrTAkSm github.com/ory/dockertest/v3 v3.7.0/go.mod h1:PvCCgnP7AfBZeVrzwiUTjZx/IUXlGLC1zQlUQrLIlUE= github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pborman/getopt v1.1.0/go.mod h1:FxXoW1Re00sQG/+KIkuSqRL/LwQgSkv7uyac+STFsbk= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= @@ -786,6 +788,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/ryancurrah/gomodguard v1.1.0/go.mod h1:4O8tr7hBODaGE6VIhfJDHcwzh5GUccKSJBU0UMXJFVM= github.com/ryanrolds/sqlclosecheck v0.3.0/go.mod h1:1gREqxyTGR3lVtpngyFo3hZAgk0KCtEdgEkHwDbigdA= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/s12v/go-jwks v0.2.1 h1:2zShofKJoSXztWyh5ASPfpzuQrE+b+Sum9JJdif05Po= +github.com/s12v/go-jwks v0.2.1/go.mod h1:DmmtP4Etd59Y90j8zmTS4z61MKu0QPvgioAXv+mqyjQ= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b/go.mod h1:am+Fp8Bt506lA3Rk3QCmSqmYmLMnPDhdDUcosQCAx+I= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= @@ -845,6 +849,8 @@ github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5q github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.8.1 h1:Kq1fyeebqsBfbjZj4EL7gj2IO0mMaiyjYUWcUsl2O44= github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns= +github.com/square/go-jose v2.5.1+incompatible h1:FC+BwI9FzJZWpKaE0yUhFNbp/CyFHndARzuGVME/LGk= +github.com/square/go-jose v2.5.1+incompatible/go.mod h1:7MxpAF/1WTVUu8Am+T5kNy+t0902CaLWM4Z745MkOa8= github.com/ssgreg/nlreturn/v2 v2.1.0/go.mod h1:E/iiPB78hV7Szg2YfRgyIrk1AD6JVMTRkkxBiELzh2I= github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= @@ -958,6 +964,7 @@ go4.org/mem v0.0.0-20201119185036-c04c5a6ff174/go.mod h1:reUoABIJ9ikfM5sgtSF3Wus go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222175341-b30ae309168e/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 h1:1tk03FUNpulq2cuWpXZWj649rwJpk0d20rxWiopKRmc= go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= +golang.org/x/crypto v0.0.0-20180621125126-a49355c7e3f8/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190219172222-a4c6cb3142f2/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -1021,6 +1028,7 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180729183719-c4299a1a0d85/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1097,6 +1105,7 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1436,6 +1445,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= +gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= diff --git a/oidc.go b/oidc.go new file mode 100644 index 00000000..0006cc32 --- /dev/null +++ b/oidc.go @@ -0,0 +1,310 @@ +package headscale + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/patrickmn/go-cache" + "github.com/rs/zerolog/log" + "github.com/s12v/go-jwks" + "gopkg.in/square/go-jose.v2/jwt" + "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +type OpenIDConfiguration struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKSURI string `json:"jwks_uri"` +} + +type OpenIDTokens struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + IdToken string `json:"id_token"` + NotBeforePolicy int `json:"not-before-policy,omitempty"` + RefreshExpiresIn int `json:"refresh_expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + SessionState string `json:"session_state,omitempty"` + TokenType string `json:"token_type,omitempty"` +} + +type AccessToken struct { + jwt.Claims + Name string `json:"name,omitempty"` + Groups []string `json:"groups,omitempty"` + Email string `json:"email"` + Username string `json:"preferred_username,omitempty"` +} + +var oidcConfig *OpenIDConfiguration +var stateCache *cache.Cache +var jwksSource *jwks.WebSource +var jwksClient jwks.JWKSClient + +func verifyToken(token string) (*AccessToken, error) { + + if jwksClient == nil { + jwksSource = jwks.NewWebSource(oidcConfig.JWKSURI) + jwksClient = jwks.NewDefaultClient( + jwksSource, + time.Hour, // Refresh keys every 1 hour + 12*time.Hour, // Expire keys after 12 hours + ) + } + + //decode jwt + tok, err := jwt.ParseSigned(token) + if err != nil { + return nil, err + } + + if tok.Headers[0].KeyID != "" { + log.Debug().Msgf("Checking KID %s\n", tok.Headers[0].KeyID) + + jwk, err := jwksClient.GetSignatureKey(tok.Headers[0].KeyID) + if err != nil { + return nil, err + } + + claims := AccessToken{} + + err = tok.Claims(jwk.Certificates[0].PublicKey, &claims) + if err != nil { + return nil, err + } else { + + err = claims.Validate(jwt.Expected{ + Time: time.Now(), + }) + if err != nil { + return nil, err + } + + return &claims, nil + } + + } else { + return nil, err + } +} + +func getOIDCConfig(oidcConfigURL string) (*OpenIDConfiguration, error) { + client := &http.Client{} + req, err := http.NewRequest("GET", oidcConfigURL, nil) + if err != nil { + log.Error().Msgf("%v", err) + return nil, err + } + + log.Debug().Msgf("Requesting OIDC Config from %s", oidcConfigURL) + + oidcConfigResp, err := client.Do(req) + if err != nil { + log.Error().Msgf("%v", err) + return nil, err + } + defer oidcConfigResp.Body.Close() + + var oidcConfig OpenIDConfiguration + + err = json.NewDecoder(oidcConfigResp.Body).Decode(&oidcConfig) + if err != nil { + log.Error().Msgf("%v", err) + return nil, err + } + return &oidcConfig, nil +} + +func (h *Headscale) exchangeCodeForTokens(code string, redirectURI string) (*OpenIDTokens, error) { + var err error + + if oidcConfig == nil { + oidcConfig, err = getOIDCConfig(fmt.Sprintf("%s.well-known/openid-configuration", h.cfg.OIDCEndpoint)) + if err != nil { + return nil, err + } + } + + params := url.Values{} + params.Add("grant_type", "authorization_code") + params.Add("code", code) + params.Add("client_id", h.cfg.OIDCClientID) + params.Add("client_secret", h.cfg.OIDCClientSecret) + params.Add("redirect_uri", redirectURI) + + client := &http.Client{} + req, err := http.NewRequest("POST", oidcConfig.TokenEndpoint, strings.NewReader(params.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + if err != nil { + log.Error().Msgf("%v", err) + return nil, err + } + + tokenResp, err := client.Do(req) + if err != nil { + log.Error().Msgf("%v", err) + return nil, err + } + defer tokenResp.Body.Close() + + if tokenResp.StatusCode != 200 { + b, _ := io.ReadAll(tokenResp.Body) + log.Error().Msgf("%s", b) + } + + var tokens OpenIDTokens + + err = json.NewDecoder(tokenResp.Body).Decode(&tokens) + if err != nil { + log.Error().Msgf("%v", err) + return nil, err + } + + log.Info().Msg("Successfully exchanged code for tokens") + + return &tokens, nil +} + +// RegisterOIDC redirects to the OIDC provider for authentication +// Puts machine key in cache so the callback can retrieve it using the oidc state param +// Listens in /oidc/register/:mKey +func (h *Headscale) RegisterOIDC(c *gin.Context) { + mKeyStr := c.Param("mKey") + if mKeyStr == "" { + c.String(http.StatusBadRequest, "Wrong params") + return + } + + var err error + + // grab oidc config if it hasn't been already + if oidcConfig == nil { + oidcConfig, err = getOIDCConfig(fmt.Sprintf("%s.well-known/openid-configuration", h.cfg.OIDCEndpoint)) + + if err != nil { + c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config") + return + } + } + + b := make([]byte, 16) + _, err = rand.Read(b) + stateStr := hex.EncodeToString(b)[:32] + + // init the state cache if it hasn't been already + if stateCache == nil { + stateCache = cache.New(time.Minute*5, time.Minute*10) + } + + // place the machine key into the state cache, so it can be retrieved later + stateCache.Set(stateStr, mKeyStr, time.Minute*5) + + params := url.Values{} + params.Add("response_type", "code") + params.Add("client_id", h.cfg.OIDCClientID) + params.Add("redirect_uri", fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL)) + params.Add("scope", "openid") + params.Add("state", stateStr) + + authUrl := fmt.Sprintf("%s?%s", oidcConfig.AuthorizationEndpoint, params.Encode()) + log.Debug().Msg(authUrl) + + c.Redirect(http.StatusFound, authUrl) +} + +// OIDCCallback handles the callback from the OIDC endpoint +// Retrieves the mkey from the state cache, if the machine is not registered, presents a confirmation +// Listens in /oidc/callback +func (h *Headscale) OIDCCallback(c *gin.Context) { + + code := c.Query("code") + state := c.Query("state") + + if code == "" || state == "" { + c.String(http.StatusBadRequest, "Wrong params") + return + } + + redirectURI := fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL) + + tokens, err := h.exchangeCodeForTokens(code, redirectURI) + + if err != nil { + c.String(http.StatusBadRequest, "Could not exchange code for token") + return + } + + //verify tokens + claims, err := verifyToken(tokens.AccessToken) + + if err != nil { + c.String(http.StatusBadRequest, "invalid tokens") + return + } + + //retrieve machinekey from state cache + mKeyIf, mKeyFound := stateCache.Get(state) + + if !mKeyFound { + c.String(http.StatusBadRequest, "state has expired") + return + } + mKeyStr, mKeyOK := mKeyIf.(string) + + if !mKeyOK { + c.String(http.StatusInternalServerError, "could not get machine key from cache") + return + } + + // retrieve machine information + var m Machine + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKeyStr); errors.Is(result.Error, gorm.ErrRecordNotFound) { + log.Error().Msg("machine key not found in database") + c.String(http.StatusInternalServerError, "could not get machine info from database") + return + } + + //look for a namespace of the users email for now + if !m.Registered { + + ns, err := h.GetNamespace(claims.Email) + if err != nil { + ns, err = h.CreateNamespace(claims.Email) + } + + ip, err := h.getAvailableIP() + if err != nil { + c.String(http.StatusInternalServerError, "could not get an IP from the pool") + return + } + + m.IPAddress = ip.String() + m.NamespaceID = ns.ID + m.Registered = true + m.RegisterMethod = "oidc" + h.db.Save(&m) + } + + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + + +

headscale

+

+ Authenticated, you can now close this window. +

+ + + +`))) +} From b22a9781a22a41834dcb87b96b8ae2f87df17d55 Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Sun, 26 Sep 2021 21:12:36 +0800 Subject: [PATCH 02/15] fix linter errors, error out if jwt does not contain a key id --- oidc.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/oidc.go b/oidc.go index 0006cc32..dabd8b03 100644 --- a/oidc.go +++ b/oidc.go @@ -94,7 +94,7 @@ func verifyToken(token string) (*AccessToken, error) { } } else { - return nil, err + return nil, errors.New("JWT does not contain a key id") } } @@ -200,6 +200,13 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { b := make([]byte, 16) _, err = rand.Read(b) + + if err != nil { + log.Error().Msg("could not read 16 bytes from rand") + c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") + return + } + stateStr := hex.EncodeToString(b)[:32] // init the state cache if it hasn't been already @@ -281,6 +288,13 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { ns, err := h.GetNamespace(claims.Email) if err != nil { ns, err = h.CreateNamespace(claims.Email) + + if err != nil { + log.Error().Msgf("could not create new namespace '%s'", claims.Email) + c.String(http.StatusInternalServerError, "could not create new namespace") + return + } + } ip, err := h.getAvailableIP() @@ -301,10 +315,10 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {

headscale

- Authenticated, you can now close this window. + Authenticated as %s, you can now close this window.

-`))) +`, claims.Email))) } From c487591437afb292701c8a905dc7ce9ad0477562 Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Wed, 6 Oct 2021 17:19:15 +0800 Subject: [PATCH 03/15] use go-oidc instead of verifying and extracting tokens ourselves, rename oidc_endpoint to oidc_issuer to be more inline with spec --- api.go | 4 +- app.go | 6 +- cmd/headscale/cli/utils.go | 2 +- go.mod | 7 +- go.sum | 8 +- machine.go | 5 + oidc.go | 222 ++++++++----------------------------- 7 files changed, 69 insertions(+), 185 deletions(-) diff --git a/api.go b/api.go index 2c5a1321..fb54f3cf 100644 --- a/api.go +++ b/api.go @@ -134,7 +134,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("machine", m.Name). Msg("Not registered and not NodeKey rotation. Sending a authurl to register") - if h.cfg.OIDCEndpoint != "" { + if h.cfg.OIDCIssuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", @@ -204,7 +204,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - if h.cfg.OIDCEndpoint != "" { + if h.cfg.OIDCIssuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", diff --git a/app.go b/app.go index 6f24c82b..3c4b307a 100644 --- a/app.go +++ b/app.go @@ -46,7 +46,7 @@ type Config struct { DNSConfig *tailcfg.DNSConfig - OIDCEndpoint string + OIDCIssuer string OIDCClientID string OIDCClientSecret string } @@ -172,11 +172,11 @@ func (h *Headscale) Serve() error { r.GET("/register", h.RegisterWebAPI) r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id", h.RegistrationHandler) - r.GET("/oidc/register/:mKey", h.RegisterOIDC) + r.GET("/oidc/register/:mkey", h.RegisterOIDC) r.GET("/oidc/callback", h.OIDCCallback) r.GET("/apple", h.AppleMobileConfig) r.GET("/apple/:platform", h.ApplePlatformConfig) - + var err error timeout := 30 * time.Second diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index b7faad57..6ccdcdee 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -171,7 +171,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { DNSConfig: GetDNSConfig(), - OIDCEndpoint: viper.GetString("oidc_endpoint"), + OIDCIssuer: viper.GetString("oidc_issuer"), OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), } diff --git a/go.mod b/go.mod index c1d4561b..a770338b 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,12 @@ require ( github.com/Microsoft/go-winio v0.5.0 // indirect github.com/cenkalti/backoff/v4 v4.1.1 // indirect github.com/containerd/continuity v0.1.0 // indirect + github.com/coreos/go-oidc/v3 v3.1.0 github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect github.com/efekarakus/termcolor v1.0.1 github.com/gin-gonic/gin v1.7.4 - github.com/gofrs/uuid v4.0.0+incompatible // indirect + github.com/gofrs/uuid v4.0.0+incompatible github.com/google/go-github v17.0.0+incompatible // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b @@ -28,13 +29,13 @@ require ( github.com/spf13/viper v1.8.1 github.com/stretchr/testify v1.7.0 github.com/tailscale/hujson v0.0.0-20210818175511-7360507a6e88 - github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e // indirect + github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect + golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c - gopkg.in/square/go-jose.v2 v2.3.1 gopkg.in/yaml.v2 v2.4.0 gorm.io/datatypes v1.0.2 gorm.io/driver/postgres v1.1.1 diff --git a/go.sum b/go.sum index 9a76d176..fc498e75 100644 --- a/go.sum +++ b/go.sum @@ -143,6 +143,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw= +github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -1067,6 +1069,7 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= @@ -1101,6 +1104,7 @@ golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 h1:0Ja1LBD+yisY6RWM/BH7TJVXWsSjs2VwBSmvSX4HdBc= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1355,6 +1359,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1452,8 +1457,9 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= +gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= diff --git a/machine.go b/machine.go index 1d4939c1..b5c821f5 100644 --- a/machine.go +++ b/machine.go @@ -50,6 +50,11 @@ func (m Machine) isAlreadyRegistered() bool { return m.Registered } +// isExpired returns whether the machine registration has expired +func (m Machine) isExpired() bool { + return time.Now().UTC().After(*m.Expiry) +} + // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { diff --git a/oidc.go b/oidc.go index dabd8b03..aa80911b 100644 --- a/oidc.go +++ b/oidc.go @@ -1,186 +1,37 @@ package headscale import ( + "context" "crypto/rand" "encoding/hex" - "encoding/json" "errors" "fmt" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" - "github.com/s12v/go-jwks" - "gopkg.in/square/go-jose.v2/jwt" + "golang.org/x/oauth2" "gorm.io/gorm" - "io" "net/http" - "net/url" - "strings" "time" ) -type OpenIDConfiguration struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - JWKSURI string `json:"jwks_uri"` -} - -type OpenIDTokens struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - IdToken string `json:"id_token"` - NotBeforePolicy int `json:"not-before-policy,omitempty"` - RefreshExpiresIn int `json:"refresh_expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - SessionState string `json:"session_state,omitempty"` - TokenType string `json:"token_type,omitempty"` -} - -type AccessToken struct { - jwt.Claims +type IDTokenClaims struct { Name string `json:"name,omitempty"` Groups []string `json:"groups,omitempty"` Email string `json:"email"` Username string `json:"preferred_username,omitempty"` } -var oidcConfig *OpenIDConfiguration +var oidcProvider *oidc.Provider +var oauth2Config *oauth2.Config var stateCache *cache.Cache -var jwksSource *jwks.WebSource -var jwksClient jwks.JWKSClient - -func verifyToken(token string) (*AccessToken, error) { - - if jwksClient == nil { - jwksSource = jwks.NewWebSource(oidcConfig.JWKSURI) - jwksClient = jwks.NewDefaultClient( - jwksSource, - time.Hour, // Refresh keys every 1 hour - 12*time.Hour, // Expire keys after 12 hours - ) - } - - //decode jwt - tok, err := jwt.ParseSigned(token) - if err != nil { - return nil, err - } - - if tok.Headers[0].KeyID != "" { - log.Debug().Msgf("Checking KID %s\n", tok.Headers[0].KeyID) - - jwk, err := jwksClient.GetSignatureKey(tok.Headers[0].KeyID) - if err != nil { - return nil, err - } - - claims := AccessToken{} - - err = tok.Claims(jwk.Certificates[0].PublicKey, &claims) - if err != nil { - return nil, err - } else { - - err = claims.Validate(jwt.Expected{ - Time: time.Now(), - }) - if err != nil { - return nil, err - } - - return &claims, nil - } - - } else { - return nil, errors.New("JWT does not contain a key id") - } -} - -func getOIDCConfig(oidcConfigURL string) (*OpenIDConfiguration, error) { - client := &http.Client{} - req, err := http.NewRequest("GET", oidcConfigURL, nil) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - - log.Debug().Msgf("Requesting OIDC Config from %s", oidcConfigURL) - - oidcConfigResp, err := client.Do(req) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - defer oidcConfigResp.Body.Close() - - var oidcConfig OpenIDConfiguration - - err = json.NewDecoder(oidcConfigResp.Body).Decode(&oidcConfig) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - return &oidcConfig, nil -} - -func (h *Headscale) exchangeCodeForTokens(code string, redirectURI string) (*OpenIDTokens, error) { - var err error - - if oidcConfig == nil { - oidcConfig, err = getOIDCConfig(fmt.Sprintf("%s.well-known/openid-configuration", h.cfg.OIDCEndpoint)) - if err != nil { - return nil, err - } - } - - params := url.Values{} - params.Add("grant_type", "authorization_code") - params.Add("code", code) - params.Add("client_id", h.cfg.OIDCClientID) - params.Add("client_secret", h.cfg.OIDCClientSecret) - params.Add("redirect_uri", redirectURI) - - client := &http.Client{} - req, err := http.NewRequest("POST", oidcConfig.TokenEndpoint, strings.NewReader(params.Encode())) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - - tokenResp, err := client.Do(req) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - defer tokenResp.Body.Close() - - if tokenResp.StatusCode != 200 { - b, _ := io.ReadAll(tokenResp.Body) - log.Error().Msgf("%s", b) - } - - var tokens OpenIDTokens - - err = json.NewDecoder(tokenResp.Body).Decode(&tokens) - if err != nil { - log.Error().Msgf("%v", err) - return nil, err - } - - log.Info().Msg("Successfully exchanged code for tokens") - - return &tokens, nil -} // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param // Listens in /oidc/register/:mKey func (h *Headscale) RegisterOIDC(c *gin.Context) { - mKeyStr := c.Param("mKey") + mKeyStr := c.Param("mkey") if mKeyStr == "" { c.String(http.StatusBadRequest, "Wrong params") return @@ -189,13 +40,23 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { var err error // grab oidc config if it hasn't been already - if oidcConfig == nil { - oidcConfig, err = getOIDCConfig(fmt.Sprintf("%s.well-known/openid-configuration", h.cfg.OIDCEndpoint)) + if oauth2Config == nil { + oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) if err != nil { + log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config") return } + + oauth2Config = &oauth2.Config{ + ClientID: h.cfg.OIDCClientID, + ClientSecret: h.cfg.OIDCClientSecret, + Endpoint: oidcProvider.Endpoint(), + RedirectURL: fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + } b := make([]byte, 16) @@ -217,21 +78,16 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { // place the machine key into the state cache, so it can be retrieved later stateCache.Set(stateStr, mKeyStr, time.Minute*5) - params := url.Values{} - params.Add("response_type", "code") - params.Add("client_id", h.cfg.OIDCClientID) - params.Add("redirect_uri", fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL)) - params.Add("scope", "openid") - params.Add("state", stateStr) - - authUrl := fmt.Sprintf("%s?%s", oidcConfig.AuthorizationEndpoint, params.Encode()) - log.Debug().Msg(authUrl) + authUrl := oauth2Config.AuthCodeURL(stateStr) + log.Debug().Msgf("Redirecting to %s for authentication", authUrl) c.Redirect(http.StatusFound, authUrl) } // OIDCCallback handles the callback from the OIDC endpoint -// Retrieves the mkey from the state cache, if the machine is not registered, presents a confirmation +// Retrieves the mkey from the state cache and adds the machine to the users email namespace +// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities +// TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback func (h *Headscale) OIDCCallback(c *gin.Context) { @@ -243,20 +99,36 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - redirectURI := fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL) - - tokens, err := h.exchangeCodeForTokens(code, redirectURI) - + oauth2Token, err := oauth2Config.Exchange(context.Background(), code) if err != nil { c.String(http.StatusBadRequest, "Could not exchange code for token") return } - //verify tokens - claims, err := verifyToken(tokens.AccessToken) + rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) + if !rawIDTokenOK { + c.String(http.StatusBadRequest, "Could not extract ID Token") + return + } + verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + + idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { - c.String(http.StatusBadRequest, "invalid tokens") + c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + return + } + + //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) + //if err != nil { + // c.String(http.StatusBadRequest, "Failed to retrieve userinfo: "+err.Error()) + // return + //} + + // Extract custom claims + var claims IDTokenClaims + if err = idToken.Claims(&claims); err != nil { + c.String(http.StatusBadRequest, "Failed to decode id token claims: "+err.Error()) return } From 35795c79c367c80590deefe51415bc364f82a024 Mon Sep 17 00:00:00 2001 From: unreality Date: Fri, 8 Oct 2021 15:26:31 +0800 Subject: [PATCH 04/15] Handle trailing slash on uris Co-authored-by: Kristoffer Dalby --- api.go | 8 ++++---- oidc.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api.go b/api.go index fb54f3cf..ddabd937 100644 --- a/api.go +++ b/api.go @@ -135,10 +135,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Msg("Not registered and not NodeKey rotation. Sending a authurl to register") if h.cfg.OIDCIssuer != "" { - resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { @@ -205,10 +205,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDCIssuer != "" { - resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", h.cfg.ServerURL, mKey.HexString()) + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { diff --git a/oidc.go b/oidc.go index aa80911b..328731ec 100644 --- a/oidc.go +++ b/oidc.go @@ -53,7 +53,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { ClientID: h.cfg.OIDCClientID, ClientSecret: h.cfg.OIDCClientSecret, Endpoint: oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf("%s/oidc/callback", h.cfg.ServerURL), + RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } From e407d423d44c208d0989f9ec8e94ca29909a6f16 Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Fri, 8 Oct 2021 17:43:52 +0800 Subject: [PATCH 05/15] updates from code review --- api.go | 56 +++++++++++++++++++++--- app.go | 17 ++++++++ cmd/headscale/cli/utils.go | 13 ++++++ oidc.go | 88 ++++++++++++++++++++++---------------- 4 files changed, 131 insertions(+), 43 deletions(-) diff --git a/api.go b/api.go index ddabd937..02f28919 100644 --- a/api.go +++ b/api.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "github.com/rs/zerolog/log" @@ -83,7 +84,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") m = Machine{ - Expiry: &req.Expiry, + Expiry: &time.Time{}, MachineKey: mKey.HexString(), Name: req.Hostinfo.Hostname, NodeKey: wgkey.Key(req.NodeKey).HexString(), @@ -107,7 +108,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We have the updated key! if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { - if m.Registered { + + if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { + log.Debug(). + Str("handler", "Registration"). + Str("machine", m.Name). + Msg("Client requested logout") + + m.Expiry = &req.Expiry + h.db.Save(&m) + + resp.AuthURL = "" + resp.MachineAuthorized = false + resp.User = *m.Namespace.toUser() + respBody, err := encode(resp, &mKey, h.privateKey) + if err != nil { + log.Error(). + Str("handler", "Registration"). + Err(err). + Msg("Cannot encode message") + c.String(http.StatusInternalServerError, "") + return + } + c.Data(200, "application/json; charset=utf-8", respBody) + return + } + + if m.Registered && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -132,14 +159,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). - Msg("Not registered and not NodeKey rotation. Sending a authurl to register") + Msg("Not registered (or expired) and not NodeKey rotation. Sending a authurl to register") if h.cfg.OIDCIssuer != "" { - resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } + + m.Expiry = &req.Expiry // save the requested expiry time for retrieval later + h.db.Save(&m) + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -153,8 +185,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - // The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration - if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { + // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration + if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -179,14 +211,19 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We arrive here after a client is restarted without finalizing the authentication flow or // when headscale is stopped in the middle of the auth process. - if m.Registered { + if m.Registered && m.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map") + + m.NodeKey = wgkey.Key(req.NodeKey).HexString() + h.db.Save(&m) + resp.AuthURL = "" resp.MachineAuthorized = true resp.User = *m.Namespace.toUser() + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). @@ -210,6 +247,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } + + m.Expiry = &req.Expiry // save the requested expiry time for retrieval later + m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the new nodekey + h.db.Save(&m) + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index 3c4b307a..2ad72154 100644 --- a/app.go +++ b/app.go @@ -3,6 +3,9 @@ package headscale import ( "errors" "fmt" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" "net/http" "os" "strings" @@ -49,6 +52,9 @@ type Config struct { OIDCIssuer string OIDCClientID string OIDCClientSecret string + + MaxMachineExpiry time.Duration + DefaultMachineExpiry time.Duration } // Headscale represents the base app of the service @@ -68,6 +74,10 @@ type Headscale struct { clientsUpdateChannelMutex sync.Mutex lastStateChange sync.Map + + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config + oidcStateCache *cache.Cache } // NewHeadscale returns the Headscale app @@ -107,6 +117,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } + if cfg.OIDCIssuer != "" { + err = h.initOIDC() + if err != nil { + return nil, err + } + } + return &h, nil } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 6ccdcdee..67017aa0 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -144,6 +144,16 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } + maxMachineExpiry, _ := time.ParseDuration("8h") + if viper.GetDuration("max_machine_expiry") >= time.Second { + maxMachineExpiry = viper.GetDuration("max_machine_expiry") + } + + defaultMachineExpiry, _ := time.ParseDuration("8h") + if viper.GetDuration("default_machine_expiry") >= time.Second { + defaultMachineExpiry = viper.GetDuration("default_machine_expiry") + } + cfg := headscale.Config{ ServerURL: viper.GetString("server_url"), Addr: viper.GetString("listen_addr"), @@ -174,6 +184,9 @@ func getHeadscaleApp() (*headscale.Headscale, error) { OIDCIssuer: viper.GetString("oidc_issuer"), OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), + + MaxMachineExpiry: maxMachineExpiry, + DefaultMachineExpiry: defaultMachineExpiry, } h, err := headscale.NewHeadscale(cfg) diff --git a/oidc.go b/oidc.go index 328731ec..1220098c 100644 --- a/oidc.go +++ b/oidc.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "gorm.io/gorm" "net/http" + "strings" "time" ) @@ -23,9 +24,33 @@ type IDTokenClaims struct { Username string `json:"preferred_username,omitempty"` } -var oidcProvider *oidc.Provider -var oauth2Config *oauth2.Config -var stateCache *cache.Cache +func (h *Headscale) initOIDC() error { + var err error + // grab oidc config if it hasn't been already + if h.oauth2Config == nil { + h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) + + if err != nil { + log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) + return err + } + + h.oauth2Config = &oauth2.Config{ + ClientID: h.cfg.OIDCClientID, + ClientSecret: h.cfg.OIDCClientSecret, + Endpoint: h.oidcProvider.Endpoint(), + RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + } + + // init the state cache if it hasn't been already + if h.oidcStateCache == nil { + h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10) + } + + return nil +} // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param @@ -37,30 +62,8 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { return } - var err error - - // grab oidc config if it hasn't been already - if oauth2Config == nil { - oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) - - if err != nil { - log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) - c.String(http.StatusInternalServerError, "Could not retrieve OIDC Config") - return - } - - oauth2Config = &oauth2.Config{ - ClientID: h.cfg.OIDCClientID, - ClientSecret: h.cfg.OIDCClientSecret, - Endpoint: oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, - } - - } - b := make([]byte, 16) - _, err = rand.Read(b) + _, err := rand.Read(b) if err != nil { log.Error().Msg("could not read 16 bytes from rand") @@ -70,15 +73,10 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { stateStr := hex.EncodeToString(b)[:32] - // init the state cache if it hasn't been already - if stateCache == nil { - stateCache = cache.New(time.Minute*5, time.Minute*10) - } - // place the machine key into the state cache, so it can be retrieved later - stateCache.Set(stateStr, mKeyStr, time.Minute*5) + h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5) - authUrl := oauth2Config.AuthCodeURL(stateStr) + authUrl := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authUrl) c.Redirect(http.StatusFound, authUrl) @@ -99,7 +97,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - oauth2Token, err := oauth2Config.Exchange(context.Background(), code) + oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) if err != nil { c.String(http.StatusBadRequest, "Could not exchange code for token") return @@ -111,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - verifier := oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { @@ -133,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { } //retrieve machinekey from state cache - mKeyIf, mKeyFound := stateCache.Get(state) + mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { c.String(http.StatusBadRequest, "state has expired") @@ -157,6 +155,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { //look for a namespace of the users email for now if !m.Registered { + log.Debug().Msg("Registering new machine after successful callback") + ns, err := h.GetNamespace(claims.Email) if err != nil { ns, err = h.CreateNamespace(claims.Email) @@ -182,6 +182,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { h.db.Save(&m) } + if m.isExpired() { + maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry) + + // use the maximum expiry if it's sooner than the requested expiry + if maxExpiry.Before(*m.Expiry) { + log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) + m.Expiry = &maxExpiry + h.db.Save(&m) + } else if m.Expiry.IsZero() { + log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) + defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry) + m.Expiry = &defaultExpiry + h.db.Save(&m) + } + } + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` From 74e6c1479e64ea13e49fbb4ca87f668dd14068ab Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Sun, 10 Oct 2021 17:22:42 +0800 Subject: [PATCH 06/15] updates from code review --- api.go | 71 +++++++++++++------------------------- app.go | 4 +-- cli.go | 3 ++ cmd/headscale/cli/utils.go | 18 +++++----- go.mod | 6 ++-- machine.go | 30 ++++++++++++++-- oidc.go | 43 ++++++++++------------- 7 files changed, 88 insertions(+), 87 deletions(-) diff --git a/api.go b/api.go index a70df5b3..bda9d9bd 100644 --- a/api.go +++ b/api.go @@ -65,7 +65,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Cannot parse machine key") - machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() c.String(http.StatusInternalServerError, "Sad!") return } @@ -76,34 +76,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Cannot decode message") - machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() c.String(http.StatusInternalServerError, "Very sad!") return } now := time.Now().UTC() - var m Machine - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + m, err := h.GetMachineByMachineKey(mKey.HexString()) + if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") - m = Machine{ - Expiry: &time.Time{}, - MachineKey: mKey.HexString(), - Name: req.Hostinfo.Hostname, - NodeKey: wgkey.Key(req.NodeKey).HexString(), - LastSuccessfulUpdate: &now, + newMachine := Machine{ + Expiry: &time.Time{}, + MachineKey: mKey.HexString(), + Name: req.Hostinfo.Hostname, } - if err := h.db.Create(&m).Error; err != nil { + if err := h.db.Create(&newMachine).Error; err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Could not create row") - machineRegistrations.WithLabelValues("unkown", "web", "error", m.Namespace.Name).Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc() return } + m = &newMachine } if !m.Registered && req.Auth.AuthKey != "" { - h.handleAuthKey(c, h.db, mKey, req, m) + h.handleAuthKey(c, h.db, mKey, req, *m) return } @@ -112,13 +111,14 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We have the updated key! if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { + // The client sends an Expiry in the past if the client is requesting a logout if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { - log.Debug(). + log.Info(). Str("handler", "Registration"). Str("machine", m.Name). Msg("Client requested logout") - m.Expiry = &req.Expiry + m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired h.db.Save(&m) resp.AuthURL = "" @@ -138,6 +138,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { } if m.Registered && m.Expiry.UTC().After(now) { + // The machine registration is valid, respond with redirect to /map log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -161,10 +162,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } + // The client has registered before, but has expired log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). - Msg("Not registered (or expired) and not NodeKey rotation. Sending a authurl to register") + Msg("Machine registration has expired. Sending a authurl to register") if h.cfg.OIDCIssuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", @@ -174,7 +176,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } - m.Expiry = &req.Expiry // save the requested expiry time for retrieval later + m.RequestedExpiry = &req.Expiry // save the requested expiry time for retrieval later in the authentication flow h.db.Save(&m) respBody, err := encode(resp, &mKey, h.privateKey) @@ -216,34 +218,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { return } - // We arrive here after a client is restarted without finalizing the authentication flow or - // when headscale is stopped in the middle of the auth process. - if m.Registered && m.Expiry.UTC().After(now) { - log.Debug(). - Str("handler", "Registration"). - Str("machine", m.Name). - Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map") - - m.NodeKey = wgkey.Key(req.NodeKey).HexString() - h.db.Save(&m) - - resp.AuthURL = "" - resp.MachineAuthorized = true - resp.User = *m.Namespace.toUser() - - respBody, err := encode(resp, &mKey, h.privateKey) - if err != nil { - log.Error(). - Str("handler", "Registration"). - Err(err). - Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") - return - } - c.Data(200, "application/json; charset=utf-8", respBody) - return - } - + // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("handler", "Registration"). Str("machine", m.Name). @@ -255,8 +230,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } - m.Expiry = &req.Expiry // save the requested expiry time for retrieval later - m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the new nodekey + m.RequestedExpiry = &req.Expiry // save the requested expiry time for retrieval later in the authentication flow + m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey h.db.Save(&m) respBody, err := encode(resp, &mKey, h.privateKey) @@ -436,6 +411,8 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, m.RegisterMethod = "authKey" db.Save(&m) + h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? + resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() respBody, err := encode(resp, &idKey, h.privateKey) diff --git a/app.go b/app.go index 9e688fe1..239998c2 100644 --- a/app.go +++ b/app.go @@ -59,8 +59,8 @@ type Config struct { OIDCClientID string OIDCClientSecret string - MaxMachineExpiry time.Duration - DefaultMachineExpiry time.Duration + MaxMachineRegistrationDuration time.Duration + DefaultMachineRegistrationDuration time.Duration } // Headscale represents the base app of the service diff --git a/cli.go b/cli.go index 9c5b66e5..8610b334 100644 --- a/cli.go +++ b/cli.go @@ -23,6 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err return nil, errors.New("Machine not found") } + h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered + if m.isAlreadyRegistered() { return nil, errors.New("Machine already registered") } @@ -36,5 +38,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err m.Registered = true m.RegisterMethod = "cli" h.db.Save(&m) + return &m, nil } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 17bc37e7..366e9597 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -144,14 +144,16 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } - maxMachineExpiry, _ := time.ParseDuration("8h") - if viper.GetDuration("max_machine_expiry") >= time.Second { - maxMachineExpiry = viper.GetDuration("max_machine_expiry") + // maxMachineRegistrationDuration is the maximum time a client can request for a client registration + maxMachineRegistrationDuration, _ := time.ParseDuration("10h") + if viper.GetDuration("max_machine_registration_duration") >= time.Second { + maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") } - defaultMachineExpiry, _ := time.ParseDuration("8h") - if viper.GetDuration("default_machine_expiry") >= time.Second { - defaultMachineExpiry = viper.GetDuration("default_machine_expiry") + // defaultMachineRegistrationDuration is the default time assigned to a client registration if one is not specified by the client + defaultMachineRegistrationDuration, _ := time.ParseDuration("8h") + if viper.GetDuration("default_machine_registration_duration") >= time.Second { + defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") } cfg := headscale.Config{ @@ -188,8 +190,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), - MaxMachineExpiry: maxMachineExpiry, - DefaultMachineExpiry: defaultMachineExpiry, + MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time + DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration } h, err := headscale.NewHeadscale(cfg) diff --git a/go.mod b/go.mod index 7e137e19..5a116bb4 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect github.com/efekarakus/termcolor v1.0.1 - github.com/fatih/set v0.2.1 // indirect + github.com/fatih/set v0.2.1 github.com/gin-gonic/gin v1.7.4 github.com/gofrs/uuid v4.0.0+incompatible github.com/google/go-github v17.0.0+incompatible // indirect @@ -23,7 +23,7 @@ require ( github.com/opencontainers/runc v1.0.2 // indirect github.com/ory/dockertest/v3 v3.7.0 github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/prometheus/client_golang v1.11.0 // indirect + github.com/prometheus/client_golang v1.11.0 github.com/pterm/pterm v0.12.30 github.com/rs/zerolog v1.25.0 github.com/s12v/go-jwks v0.2.1 @@ -33,7 +33,7 @@ require ( github.com/tailscale/hujson v0.0.0-20210818175511-7360507a6e88 github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect - github.com/zsais/go-gin-prometheus v0.1.0 // indirect + github.com/zsais/go-gin-prometheus v0.1.0 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 diff --git a/machine.go b/machine.go index bd5caf0d..6eecbc6f 100644 --- a/machine.go +++ b/machine.go @@ -36,6 +36,7 @@ type Machine struct { LastSeen *time.Time LastSuccessfulUpdate *time.Time Expiry *time.Time + RequestedExpiry *time.Time // when a client connects, it may request a specific expiry time, use this field to store it HostInfo datatypes.JSON Endpoints datatypes.JSON @@ -59,8 +60,33 @@ func (m Machine) isAlreadyRegistered() bool { // isExpired returns whether the machine registration has expired func (m Machine) isExpired() bool { return time.Now().UTC().After(*m.Expiry) -} - +} + +// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, +// or the default duration if no Expiry time was requested by the client +func (h *Headscale) updateMachineExpiry(m *Machine) { + + if m.isExpired() { + now := time.Now().UTC() + maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry + defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry + + // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied + if maxExpiry.Before(*m.RequestedExpiry) { + log.Debug().Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) + m.Expiry = &maxExpiry + } else if m.RequestedExpiry.IsZero() { + log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) + m.Expiry = &defaultExpiry + } else { + log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) + m.Expiry = m.RequestedExpiry + } + + h.db.Save(&m) + } +} + func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { log.Trace(). Str("func", "getDirectPeers"). diff --git a/oidc.go b/oidc.go index 1220098c..01c54b44 100644 --- a/oidc.go +++ b/oidc.go @@ -4,14 +4,12 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" "fmt" "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "gorm.io/gorm" "net/http" "strings" "time" @@ -103,6 +101,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } + log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken) + rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { c.String(http.StatusBadRequest, "Could not extract ID Token") @@ -117,16 +117,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } + // TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc) //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) //if err != nil { - // c.String(http.StatusBadRequest, "Failed to retrieve userinfo: "+err.Error()) + // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err)) // return //} // Extract custom claims var claims IDTokenClaims if err = idToken.Claims(&claims); err != nil { - c.String(http.StatusBadRequest, "Failed to decode id token claims: "+err.Error()) + c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err)) return } @@ -134,39 +135,44 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { + log.Error().Msg("requested machine state key expired before authorisation completed") c.String(http.StatusBadRequest, "state has expired") return } mKeyStr, mKeyOK := mKeyIf.(string) if !mKeyOK { + log.Error().Msg("could not get machine key from cache") c.String(http.StatusInternalServerError, "could not get machine key from cache") return } // retrieve machine information - var m Machine - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKeyStr); errors.Is(result.Error, gorm.ErrRecordNotFound) { + m, err := h.GetMachineByMachineKey(mKeyStr) + + if err != nil { log.Error().Msg("machine key not found in database") c.String(http.StatusInternalServerError, "could not get machine info from database") return } - //look for a namespace of the users email for now + now := time.Now().UTC() + + // register the machine if it's new if !m.Registered { + nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation log.Debug().Msg("Registering new machine after successful callback") - ns, err := h.GetNamespace(claims.Email) + ns, err := h.GetNamespace(nsName) if err != nil { - ns, err = h.CreateNamespace(claims.Email) + ns, err = h.CreateNamespace(nsName) if err != nil { log.Error().Msgf("could not create new namespace '%s'", claims.Email) c.String(http.StatusInternalServerError, "could not create new namespace") return } - } ip, err := h.getAvailableIP() @@ -179,24 +185,11 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { m.NamespaceID = ns.ID m.Registered = true m.RegisterMethod = "oidc" + m.LastSuccessfulUpdate = &now h.db.Save(&m) } - if m.isExpired() { - maxExpiry := time.Now().UTC().Add(h.cfg.MaxMachineExpiry) - - // use the maximum expiry if it's sooner than the requested expiry - if maxExpiry.Before(*m.Expiry) { - log.Debug().Msgf("Clamping expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) - m.Expiry = &maxExpiry - h.db.Save(&m) - } else if m.Expiry.IsZero() { - log.Debug().Msgf("Using default machine expiry time: %v (%v)", maxExpiry, h.cfg.MaxMachineExpiry) - defaultExpiry := time.Now().UTC().Add(h.cfg.DefaultMachineExpiry) - m.Expiry = &defaultExpiry - h.db.Save(&m) - } - } + h.updateMachineExpiry(m) c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` From 8843188b8448cbcac4d603ff903a85144ab829a1 Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Sun, 10 Oct 2021 22:52:30 +0800 Subject: [PATCH 07/15] add notes to README.md about OIDC --- README.md | 55 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 5f691a6c..9d2ec159 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Headscale implements this coordination server. - [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10) - [x] DNS (passing DNS servers to nodes) - [x] Share nodes between ~~users~~ namespaces +- [x] SSO (via OIDC) - [ ] MagicDNS / Smart DNS ## Client OS support @@ -100,7 +101,21 @@ Suggestions/PRs welcomed! docker exec headscale create myfirstnamespace ``` -5. Run the server +5. (Optional) Configure an OIDC Issuer + + You can optionally configure an OIDC endpoint to which your users will be redirected to authenticate with headscale. In config.json set the following parameters: + + ```json + { + "oidc_issuer": "https://your-oidc.issuer.com/path", + "oidc_client_id": "your-oidc-client-id", + "oidc_client_secret": "your-oidc-client-secret" + } + ``` + + If `oidc_issuer` is set, headscale will attempt to send your users to the OIDC server for authentication, otherwise it will give instructions on how to authorise clients via the CLI. + +6. Run the server ```shell headscale serve @@ -114,7 +129,7 @@ Suggestions/PRs welcomed! docker run -v $(pwd)/private.key:/private.key -v $(pwd)/config.json:/config.json -v $(pwd)/derp.yaml:/derp.yaml -v $(pwd)/db.sqlite:/db.sqlite -p 127.0.0.1:8080:8080 headscale/headscale:x.x.x headscale serve ``` -6. If you used tailscale.com before in your nodes, make sure you clear the tailscald data folder +7. If you used tailscale.com before in your nodes, make sure you clear the tailscald data folder ```shell systemctl stop tailscaled @@ -122,26 +137,26 @@ Suggestions/PRs welcomed! systemctl start tailscaled ``` -7. Add your first machine +8. Add your first machine ```shell tailscale up --login-server YOUR_HEADSCALE_URL ``` -8. Navigate to the URL you will get with `tailscale up`, where you'll find your machine key. +9. Navigate to the URL you will get with `tailscale up`, where you'll find your machine key. If OIDC is configured, once you login your user will be added to a namespace automatically, and you can skip step 10. -9. In the server, register your machine to a namespace with the CLI - ```shell - headscale -n myfirstnamespace nodes register YOURMACHINEKEY - ``` - or docker: - ```shell - docker run -v $(pwd)/private.key:/private.key -v $(pwd)/config.json:/config.json -v $(pwd)/derp.yaml:/derp.yaml headscale/headscale:x.x.x headscale -n myfirstnamespace nodes register YOURMACHINEKEY - ``` - or if your server is already running in docker: - ```shell - docker exec headscale -n myfirstnamespace nodes register YOURMACHINEKEY - ``` +10. In the server, register your machine to a namespace with the CLI + ```shell + headscale -n myfirstnamespace nodes register YOURMACHINEKEY + ``` + or docker: + ```shell + docker run -v $(pwd)/private.key:/private.key -v $(pwd)/config.json:/config.json -v $(pwd)/derp.yaml:/derp.yaml headscale/headscale:x.x.x headscale -n myfirstnamespace nodes register YOURMACHINEKEY + ``` + or if your server is already running in docker: + ```shell + docker exec headscale -n myfirstnamespace nodes register YOURMACHINEKEY + ``` Alternatively, you can use Auth Keys to register your machines: @@ -218,6 +233,14 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal The fields starting with `db_` are used for the PostgreSQL connection information. +OpenID Connect settings: +``` + "oidc_issuer": "https://your-oidc.issuer.com/path", + "oidc_client_id": "your-oidc-client-id", + "oidc_client_secret": "your-oidc-client-secret" +``` + + ### Running the service via TLS (optional) ``` From 0603e29c46143b4de0d74a375d13cd7bfec4ab3c Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Fri, 15 Oct 2021 23:09:55 +0800 Subject: [PATCH 08/15] add login details to RegisterResponse so GUI clients show login display name --- api.go | 2 ++ namespaces.go | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/api.go b/api.go index bda9d9bd..d85221bc 100644 --- a/api.go +++ b/api.go @@ -147,6 +147,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { resp.AuthURL = "" resp.MachineAuthorized = true resp.User = *m.Namespace.toUser() + resp.Login = *m.Namespace.toLogin() + respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Error(). diff --git a/namespaces.go b/namespaces.go index 2bf62bb3..212df9a6 100644 --- a/namespaces.go +++ b/namespaces.go @@ -216,3 +216,14 @@ func (n *Namespace) toUser() *tailcfg.User { } return &u } + +func (n *Namespace) toLogin() *tailcfg.Login { + l := tailcfg.Login{ + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + ProfilePicURL: "", + Domain: "headscale.net", + } + return &l +} From d0cd5af419d2fee7b1ae6a7b8510b27eb25d4c8d Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Sat, 16 Oct 2021 22:34:11 +0800 Subject: [PATCH 09/15] fix incorrect merge --- cmd/headscale/cli/utils.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index ba8d34ad..f29c389d 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -161,8 +161,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } - - // maxMachineRegistrationDuration is the maximum time a client can request for a client registration + // maxMachineRegistrationDuration is the maximum time a client can request for a client registration maxMachineRegistrationDuration, _ := time.ParseDuration("10h") if viper.GetDuration("max_machine_registration_duration") >= time.Second { maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") @@ -174,7 +173,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") } - dnsConfig, baseDomain := GetDNSConfig() + dnsConfig, baseDomain := GetDNSConfig() cfg := headscale.Config{ ServerURL: viper.GetString("server_url"), @@ -207,8 +206,6 @@ func getHeadscaleApp() (*headscale.Headscale, error) { ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), - DNSConfig: GetDNSConfig(), - OIDCIssuer: viper.GetString("oidc_issuer"), OIDCClientID: viper.GetString("oidc_client_id"), OIDCClientSecret: viper.GetString("oidc_client_secret"), From a347d276bd650223109d9f12cc9829e99651b6b9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 18 Oct 2021 19:26:43 +0000 Subject: [PATCH 10/15] Fix broken machine test --- cli_test.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/cli_test.go b/cli_test.go index 528a115e..291b5df1 100644 --- a/cli_test.go +++ b/cli_test.go @@ -1,6 +1,8 @@ package headscale import ( + "time" + "gopkg.in/check.v1" ) @@ -8,14 +10,18 @@ func (s *Suite) TestRegisterMachine(c *check.C) { n, err := h.CreateNamespace("test") c.Assert(err, check.IsNil) + now := time.Now().UTC() + m := Machine{ - ID: 0, - MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", - NodeKey: "bar", - DiscoKey: "faa", - Name: "testmachine", - NamespaceID: n.ID, - IPAddress: "10.0.0.1", + ID: 0, + MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + IPAddress: "10.0.0.1", + Expiry: &now, + RequestedExpiry: &now, } h.db.Save(&m) From 677bd9b657d0ca229cb5e97875e58b6f43571bed Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 18 Oct 2021 19:27:52 +0000 Subject: [PATCH 11/15] Implement namespace matching --- api.go | 6 +- app.go | 26 +++--- cmd/headscale/cli/utils.go | 24 ++++- oidc.go | 93 ++++++++++++-------- oidc_test.go | 173 +++++++++++++++++++++++++++++++++++++ 5 files changed, 267 insertions(+), 55 deletions(-) create mode 100644 oidc_test.go diff --git a/api.go b/api.go index cbe48072..c542b3aa 100644 --- a/api.go +++ b/api.go @@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("machine", m.Name). Msg("Machine registration has expired. Sending a authurl to register") - if h.cfg.OIDCIssuer != "" { + if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { @@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - if h.cfg.OIDCIssuer != "" { + if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", @@ -424,7 +424,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, db.Save(&m) h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? - + pak.Used = true db.Save(&pak) diff --git a/app.go b/app.go index 89c43589..c89856f8 100644 --- a/app.go +++ b/app.go @@ -3,9 +3,6 @@ package headscale import ( "errors" "fmt" - "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" "net/http" "os" "sort" @@ -13,6 +10,10 @@ import ( "sync" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "github.com/rs/zerolog/log" "github.com/gin-gonic/gin" @@ -57,14 +58,19 @@ type Config struct { DNSConfig *tailcfg.DNSConfig - OIDCIssuer string - OIDCClientID string - OIDCClientSecret string + OIDC OIDCConfig MaxMachineRegistrationDuration time.Duration DefaultMachineRegistrationDuration time.Duration } +type OIDCConfig struct { + Issuer string + ClientID string + ClientSecret string + MatchMap map[string]string +} + // Headscale represents the base app of the service type Headscale struct { cfg Config @@ -122,14 +128,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } - if cfg.OIDCIssuer != "" { + if cfg.OIDC.Issuer != "" { err = h.initOIDC() if err != nil { return nil, err } - } + } - if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS + if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) if err != nil { return nil, err @@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { times = append(times, lastChange) } - } sort.Slice(times, func(i, j int) bool { @@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { if len(times) == 0 { return time.Now().UTC() - } else { return times[0] } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index f29c389d..4a598e77 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -7,6 +7,7 @@ import ( "io" "os" "path/filepath" + "regexp" "strings" "time" @@ -73,7 +74,6 @@ func LoadConfig(path string) error { } else { return nil } - } func GetDNSConfig() (*tailcfg.DNSConfig, string) { @@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) { ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), - OIDCIssuer: viper.GetString("oidc_issuer"), - OIDCClientID: viper.GetString("oidc_client_id"), - OIDCClientSecret: viper.GetString("oidc_client_secret"), + OIDC: headscale.OIDCConfig{ + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + }, MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration } + cfg.OIDC.MatchMap = loadOIDCMatchMap() + h, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err @@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool { } return false } + +// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in +// the match map is valid regex strings. +func loadOIDCMatchMap() map[string]string { + strMap := viper.GetStringMapString("oidc.domain_map") + + for oidcMatcher := range strMap { + _ = regexp.MustCompile(oidcMatcher) + } + + return strMap +} diff --git a/oidc.go b/oidc.go index 01c54b44..1b13963c 100644 --- a/oidc.go +++ b/oidc.go @@ -5,14 +5,16 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/http" + "regexp" + "strings" + "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "net/http" - "strings" - "time" ) type IDTokenClaims struct { @@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error { var err error // grab oidc config if it hasn't been already if h.oauth2Config == nil { - h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) + h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) if err != nil { log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) @@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error { } h.oauth2Config = &oauth2.Config{ - ClientID: h.cfg.OIDCClientID, - ClientSecret: h.cfg.OIDCClientSecret, + ClientID: h.cfg.OIDC.ClientID, + ClientSecret: h.cfg.OIDC.ClientSecret, Endpoint: h.oidcProvider.Endpoint(), RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, @@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { b := make([]byte, 16) _, err := rand.Read(b) - if err != nil { log.Error().Msg("could not read 16 bytes from rand") c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") @@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback func (h *Headscale) OIDCCallback(c *gin.Context) { - code := c.Query("code") state := c.Query("state") @@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { @@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - //retrieve machinekey from state cache + // retrieve machinekey from state cache mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { @@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { // retrieve machine information m, err := h.GetMachineByMachineKey(mKeyStr) - if err != nil { log.Error().Msg("machine key not found in database") c.String(http.StatusInternalServerError, "could not get machine info from database") @@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { now := time.Now().UTC() - // register the machine if it's new - if !m.Registered { - nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation + if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { + // register the machine if it's new + if !m.Registered { - log.Debug().Msg("Registering new machine after successful callback") - - ns, err := h.GetNamespace(nsName) - if err != nil { - ns, err = h.CreateNamespace(nsName) + log.Debug().Msg("Registering new machine after successful callback") + ns, err := h.GetNamespace(nsName) if err != nil { - log.Error().Msgf("could not create new namespace '%s'", claims.Email) - c.String(http.StatusInternalServerError, "could not create new namespace") + ns, err = h.CreateNamespace(nsName) + + if err != nil { + log.Error().Msgf("could not create new namespace '%s'", claims.Email) + c.String(http.StatusInternalServerError, "could not create new namespace") + return + } + } + + ip, err := h.getAvailableIP() + if err != nil { + c.String(http.StatusInternalServerError, "could not get an IP from the pool") return } + + m.IPAddress = ip.String() + m.NamespaceID = ns.ID + m.Registered = true + m.RegisterMethod = "oidc" + m.LastSuccessfulUpdate = &now + h.db.Save(&m) } - ip, err := h.getAvailableIP() - if err != nil { - c.String(http.StatusInternalServerError, "could not get an IP from the pool") - return - } + h.updateMachineExpiry(m) - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "oidc" - m.LastSuccessfulUpdate = &now - h.db.Save(&m) - } - - h.updateMachineExpiry(m) - - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

@@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { `, claims.Email))) + + } + + log.Error(). + Str("email", claims.Email). + Str("username", claims.Username). + Str("machine", m.Name). + Msg("Email could not be mapped to a namespace") + c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace") +} + +func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) { + for match, namespace := range h.cfg.OIDC.MatchMap { + regex := regexp.MustCompile(match) + if regex.MatchString(email) { + return namespace, true + } + } + + return "", false } diff --git a/oidc_test.go b/oidc_test.go new file mode 100644 index 00000000..ddb44e4c --- /dev/null +++ b/oidc_test.go @@ -0,0 +1,173 @@ +package headscale + +import ( + "sync" + "testing" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/wgkey" +) + +func TestHeadscale_getNamespaceFromEmail(t *testing.T) { + type fields struct { + cfg Config + db *gorm.DB + dbString string + dbType string + dbDebug bool + publicKey *wgkey.Key + privateKey *wgkey.Private + aclPolicy *ACLPolicy + aclRules *[]tailcfg.FilterRule + lastStateChange sync.Map + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config + oidcStateCache *cache.Cache + } + type args struct { + email string + } + tests := []struct { + name string + fields fields + args args + want string + want1 bool + }{ + { + name: "match all", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*": "space", + }, + }, + }, + }, + args: args{ + email: "test@example.no", + }, + want: "space", + want1: true, + }, + { + name: "match user", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + "specific@user\\.no": "user-namespace", + }, + }, + }, + }, + args: args{ + email: "specific@user.no", + }, + want: "user-namespace", + want1: true, + }, + { + name: "match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@example\\.no": "example", + }, + }, + }, + }, + args: args{ + email: "test@example.no", + }, + want: "example", + want1: true, + }, + { + name: "multi match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@example\\.no": "exammple", + ".*@gmail\\.com": "gmail", + }, + }, + }, + }, + args: args{ + email: "someuser@gmail.com", + }, + want: "gmail", + want1: true, + }, + { + name: "no match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@dontknow.no": "never", + }, + }, + }, + }, + args: args{ + email: "test@wedontknow.no", + }, + want: "", + want1: false, + }, + { + name: "multi no match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@dontknow.no": "never", + ".*@wedontknow.no": "other", + ".*\\.no": "stuffy", + }, + }, + }, + }, + args: args{ + email: "tasy@nonofthem.com", + }, + want: "", + want1: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Headscale{ + cfg: tt.fields.cfg, + db: tt.fields.db, + dbString: tt.fields.dbString, + dbType: tt.fields.dbType, + dbDebug: tt.fields.dbDebug, + publicKey: tt.fields.publicKey, + privateKey: tt.fields.privateKey, + aclPolicy: tt.fields.aclPolicy, + aclRules: tt.fields.aclRules, + lastStateChange: tt.fields.lastStateChange, + oidcProvider: tt.fields.oidcProvider, + oauth2Config: tt.fields.oauth2Config, + oidcStateCache: tt.fields.oidcStateCache, + } + got, got1 := h.getNamespaceFromEmail(tt.args.email) + if got != tt.want { + t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} From dbe193ad1783ccc323f4227d1fd849e3290f9454 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 19 Oct 2021 18:25:59 +0100 Subject: [PATCH 12/15] Fix up leftovers from kradalby PR --- README.md | 24 ++++++++++++++---------- oidc.go | 4 ++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 0cb41bcb..16b6c67f 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Headscale implements this coordination server. - [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10) - [x] DNS (passing DNS servers to nodes) - [x] Share nodes between ~~users~~ namespaces -- [x] SSO (via OIDC) +- [x] Single-Sign-On (via Open ID Connect) - [x] MagicDNS (see `docs/`) ## Client OS support @@ -109,13 +109,14 @@ Suggestions/PRs welcomed! ```json { - "oidc_issuer": "https://your-oidc.issuer.com/path", - "oidc_client_id": "your-oidc-client-id", - "oidc_client_secret": "your-oidc-client-secret" + "oidc": { + "issuer": "https://your-oidc.issuer.com/path", + "client_id": "your-oidc-client-id", + "client_secret": "your-oidc-client-secret", + "domain_map": { + ".*": "default-namespace" + } } - ``` - - If `oidc_issuer` is set, headscale will attempt to send your users to the OIDC server for authentication, otherwise it will give instructions on how to authorise clients via the CLI. 6. Run the server @@ -237,9 +238,12 @@ The fields starting with `db_` are used for the PostgreSQL connection informatio OpenID Connect settings: ``` - "oidc_issuer": "https://your-oidc.issuer.com/path", - "oidc_client_id": "your-oidc-client-id", - "oidc_client_secret": "your-oidc-client-secret" + oidc: + issuer: "https://your-oidc.issuer.com/path" + client_id: "your-oidc-client-id" + client_secret: "your-oidc-client-secret" + domain_map: + ".*": default-namespace ``` diff --git a/oidc.go b/oidc.go index 1b13963c..51c443db 100644 --- a/oidc.go +++ b/oidc.go @@ -212,6 +212,10 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace") } +// getNamespaceFromEmail passes the users email through a list of "matchers" +// and iterates through them until it matches and returns a namespace. +// If no match is found, an empty string will be returned. +// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration. func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) { for match, namespace := range h.cfg.OIDC.MatchMap { regex := regexp.MustCompile(match) From 2d252da221b3db0bb2d605f618d2d685f0035158 Mon Sep 17 00:00:00 2001 From: Raal Goff Date: Fri, 29 Oct 2021 21:35:07 +0800 Subject: [PATCH 13/15] suggested documentation and comments --- api.go | 11 +++++++++-- cmd/headscale/cli/utils.go | 17 ++++++++++------- machine.go | 6 ++++-- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/api.go b/api.go index c7ae122c..36af5a06 100644 --- a/api.go +++ b/api.go @@ -111,7 +111,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We have the updated key! if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { - // The client sends an Expiry in the past if the client is requesting a logout + // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) + // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { log.Info(). Str("handler", "Registration"). @@ -178,7 +179,13 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } - m.RequestedExpiry = &req.Expiry // save the requested expiry time for retrieval later in the authentication flow + // When a client connects, it may request a specific expiry time in its + // RegisterRequest (https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L634) + m.RequestedExpiry = &req.Expiry // RequestedExpiry is used to store the clients requested expiry time since the authentication flow is broken + // into two steps (which cant pass arbitrary data between them easily) and needs to be + // retrieved again after the user has authenticated. After the authentication flow + // completes, RequestedExpiry is copied into Expiry. + h.db.Save(&m) respBody, err := encode(resp, &mKey, h.privateKey) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 4a598e77..0ba43b28 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -161,14 +161,18 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } - // maxMachineRegistrationDuration is the maximum time a client can request for a client registration - maxMachineRegistrationDuration, _ := time.ParseDuration("10h") + // maxMachineRegistrationDuration is the maximum time headscale will allow a client to (optionally) request for + // the machine key expiry time. RegisterRequests with Expiry times that are more than + // maxMachineRegistrationDuration in the future will be clamped to (now + maxMachineRegistrationDuration) + maxMachineRegistrationDuration, _ := time.ParseDuration("10h") // use 10h here because it is the length of a standard business day plus a small amount of leeway if viper.GetDuration("max_machine_registration_duration") >= time.Second { maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") } - // defaultMachineRegistrationDuration is the default time assigned to a client registration if one is not specified by the client - defaultMachineRegistrationDuration, _ := time.ParseDuration("8h") + // defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not + // specified by the tailscale client. It is the default amount of time a machine registration is valid for + // (ie the amount of time before the user has to re-authenticate when requesting a connection) + defaultMachineRegistrationDuration, _ := time.ParseDuration("8h") // use 8h here because it's the length of a standard business day if viper.GetDuration("default_machine_registration_duration") >= time.Second { defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") } @@ -212,9 +216,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { ClientSecret: viper.GetString("oidc.client_secret"), }, - MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time - DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration - + MaxMachineRegistrationDuration: maxMachineRegistrationDuration, + DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, } cfg.OIDC.MatchMap = loadOIDCMatchMap() diff --git a/machine.go b/machine.go index a43fa8de..f4ce0afb 100644 --- a/machine.go +++ b/machine.go @@ -36,7 +36,7 @@ type Machine struct { LastSeen *time.Time LastSuccessfulUpdate *time.Time Expiry *time.Time - RequestedExpiry *time.Time // when a client connects, it may request a specific expiry time, use this field to store it + RequestedExpiry *time.Time HostInfo datatypes.JSON Endpoints datatypes.JSON @@ -63,7 +63,9 @@ func (m Machine) isExpired() bool { } // If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, -// or the default duration if no Expiry time was requested by the client +// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause +// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the +// expiry time. func (h *Headscale) updateMachineExpiry(m *Machine) { if m.isExpired() { From cd2914dbc9c1af4c929b363f56d0c3b6d5edfd3e Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 30 Oct 2021 15:35:58 +0000 Subject: [PATCH 14/15] Make note about oidc being experimental --- config-example.yaml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/config-example.yaml b/config-example.yaml index f28b4191..d4aa7815 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -66,10 +66,16 @@ dns_config: base_domain: example.com -# Experimental: OpenID Connect +# headscale supports experimental OpenID connect support, +# it is still being tested and might have some bugs, please +# help us test it. +# OpenID Connect # oidc: # issuer: "https://your-oidc.issuer.com/path" # client_id: "your-oidc-client-id" # client_secret: "your-oidc-client-secret" +# +# # Domain map is used to map incomming users (by their email) to +# # a namespace. The key can be a string, or regex. # domain_map: # ".*": default-namespace From bac81176b25b97d1cb0fb9b48657bf1aaf8d8fdb Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 30 Oct 2021 15:39:05 +0000 Subject: [PATCH 15/15] Remove lint from generated testcode --- oidc_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/oidc_test.go b/oidc_test.go index ddb44e4c..b501ff14 100644 --- a/oidc_test.go +++ b/oidc_test.go @@ -144,6 +144,7 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) { want1: false, }, } + //nolint for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &Headscale{