Set CSRF cookies for OIDC (#2328)

* set state and nounce in oidc to prevent csrf

Fixes #2276

* try to fix new postgres issue

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-01-08 16:29:37 +01:00 committed by GitHub
parent 41bad2b9fd
commit fa641e38b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 100 additions and 21 deletions

View File

@ -34,4 +34,10 @@ jobs:
- name: Run tests - name: Run tests
if: steps.changed-files.outputs.files == 'true' if: steps.changed-files.outputs.files == 'true'
env:
# As of 2025-01-06, these env vars was not automatically
# set anymore which breaks the initdb for postgres on
# some of the database migration tests.
LC_ALL: "en_US.UTF-8"
LC_CTYPE: "en_US.UTF-8"
run: nix develop --command -- gotestsum run: nix develop --command -- gotestsum

View File

@ -3,9 +3,7 @@ package hscontrol
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
_ "embed" _ "embed"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -157,13 +155,19 @@ func (a *AuthProviderOIDC) RegisterHandler(
return return
} }
randomBlob := make([]byte, randomByteSize) // Set the state and nonce cookies to protect against CSRF attacks
if _, err := rand.Read(randomBlob); err != nil { state, err := setCSRFCookie(writer, req, "state")
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
stateStr := hex.EncodeToString(randomBlob)[:32] // Set the state and nonce cookies to protect against CSRF attacks
nonce, err := setCSRFCookie(writer, req, "nonce")
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
// Initialize registration info with machine key // Initialize registration info with machine key
registrationInfo := RegistrationInfo{ registrationInfo := RegistrationInfo{
@ -191,11 +195,12 @@ func (a *AuthProviderOIDC) RegisterHandler(
for k, v := range a.cfg.ExtraParams { for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v)) extras = append(extras, oauth2.SetAuthURLParam(k, v))
} }
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info // Cache the registration info
a.registrationCache.Set(stateStr, registrationInfo) a.registrationCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...) authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound) http.Redirect(writer, req, authURL, http.StatusFound)
@ -228,11 +233,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return return
} }
log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback")
cookieState, err := req.Cookie("state")
if err != nil {
http.Error(writer, "state not found", http.StatusBadRequest)
return
}
if state != cookieState.Value {
http.Error(writer, "state did not match", http.StatusBadRequest)
return
}
idToken, err := a.extractIDToken(req.Context(), code, state) idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil { if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest) http.Error(writer, err.Error(), http.StatusBadRequest)
return return
} }
nonce, err := req.Cookie("nonce")
if err != nil {
http.Error(writer, "nonce not found", http.StatusBadRequest)
return
}
if idToken.Nonce != nonce.Value {
http.Error(writer, "nonce did not match", http.StatusBadRequest)
return
}
nodeExpiry := a.determineNodeExpiry(idToken.Expiry) nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
var claims types.OIDCClaims var claims types.OIDCClaims
@ -592,3 +620,22 @@ func getUserName(
return userName, nil return userName, nil
} }
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
val, err := util.GenerateRandomStringURLSafe(64)
if err != nil {
return val, err
}
c := &http.Cookie{
Path: "/oidc/callback",
Name: name,
Value: val,
MaxAge: int(time.Hour.Seconds()),
Secure: r.TLS != nil,
HttpOnly: true,
}
http.SetCookie(w, c)
return val, nil
}

View File

@ -10,6 +10,8 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/netip" "net/netip"
"sort" "sort"
"strconv" "strconv"
@ -747,6 +749,24 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc
}, nil }, nil
} }
type LoggingRoundTripper struct{}
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
noTls := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}
resp, err := noTls.RoundTrip(req)
if err != nil {
return nil, err
}
log.Printf("---")
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
return resp, nil
}
func (s *AuthOIDCScenario) runTailscaleUp( func (s *AuthOIDCScenario) runTailscaleUp(
userStr, loginServer string, userStr, loginServer string,
) error { ) error {
@ -758,35 +778,39 @@ func (s *AuthOIDCScenario) runTailscaleUp(
log.Printf("running tailscale up for user %s", userStr) log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok { if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients { for _, client := range user.Clients {
c := client tsc := client
user.joinWaitGroup.Go(func() error { user.joinWaitGroup.Go(func() error {
loginURL, err := c.LoginWithURL(loginServer) loginURL, err := tsc.LoginWithURL(loginServer)
if err != nil { if err != nil {
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
} }
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname())
loginURL.Scheme = "http" loginURL.Scheme = "http"
if len(headscale.GetCert()) > 0 { if len(headscale.GetCert()) > 0 {
loginURL.Scheme = "https" loginURL.Scheme = "https"
} }
insecureTransport := &http.Transport{ httptest.NewRecorder()
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint hc := &http.Client{
Transport: LoggingRoundTripper{},
}
hc.Jar, err = cookiejar.New(nil)
if err != nil {
log.Printf("failed to create cookie jar: %s", err)
} }
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String())
log.Printf("%s logging in with url", c.Hostname()) log.Printf("%s logging in with url", tsc.Hostname())
httpClient := &http.Client{Transport: insecureTransport}
ctx := context.Background() ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := httpClient.Do(req) resp, err := hc.Do(req)
if err != nil { if err != nil {
log.Printf( log.Printf(
"%s failed to login using url %s: %s", "%s failed to login using url %s: %s",
c.Hostname(), tsc.Hostname(),
loginURL, loginURL,
err, err,
) )
@ -794,8 +818,10 @@ func (s *AuthOIDCScenario) runTailscaleUp(
return err return err
} }
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status)
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body) log.Printf("body: %s", body)
@ -806,12 +832,12 @@ func (s *AuthOIDCScenario) runTailscaleUp(
_, err = io.ReadAll(resp.Body) _, err = io.ReadAll(resp.Body)
if err != nil { if err != nil {
log.Printf("%s failed to read response body: %s", c.Hostname(), err) log.Printf("%s failed to read response body: %s", tsc.Hostname(), err)
return err return err
} }
log.Printf("Finished request for %s to join tailnet", c.Hostname()) log.Printf("Finished request for %s to join tailnet", tsc.Hostname())
return nil return nil
}) })