mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-11 20:23:18 -05:00
Set CSRF cookies for OIDC (#2328)
* set state and nounce in oidc to prevent csrf Fixes #2276 * try to fix new postgres issue Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
41bad2b9fd
commit
fa641e38b8
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@ -34,4 +34,10 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.files == 'true'
|
||||
env:
|
||||
# As of 2025-01-06, these env vars was not automatically
|
||||
# set anymore which breaks the initdb for postgres on
|
||||
# some of the database migration tests.
|
||||
LC_ALL: "en_US.UTF-8"
|
||||
LC_CTYPE: "en_US.UTF-8"
|
||||
run: nix develop --command -- gotestsum
|
||||
|
@ -3,9 +3,7 @@ package hscontrol
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
@ -157,13 +155,19 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
return
|
||||
}
|
||||
|
||||
randomBlob := make([]byte, randomByteSize)
|
||||
if _, err := rand.Read(randomBlob); err != nil {
|
||||
// Set the state and nonce cookies to protect against CSRF attacks
|
||||
state, err := setCSRFCookie(writer, req, "state")
|
||||
if err != nil {
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
// Set the state and nonce cookies to protect against CSRF attacks
|
||||
nonce, err := setCSRFCookie(writer, req, "nonce")
|
||||
if err != nil {
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize registration info with machine key
|
||||
registrationInfo := RegistrationInfo{
|
||||
@ -191,11 +195,12 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
a.registrationCache.Set(stateStr, registrationInfo)
|
||||
a.registrationCache.Set(state, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
||||
http.Redirect(writer, req, authURL, http.StatusFound)
|
||||
@ -228,11 +233,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback")
|
||||
cookieState, err := req.Cookie("state")
|
||||
if err != nil {
|
||||
http.Error(writer, "state not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != cookieState.Value {
|
||||
http.Error(writer, "state did not match", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := a.extractIDToken(req.Context(), code, state)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
nonce, err := req.Cookie("nonce")
|
||||
if err != nil {
|
||||
http.Error(writer, "nonce not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if idToken.Nonce != nonce.Value {
|
||||
http.Error(writer, "nonce did not match", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
||||
|
||||
var claims types.OIDCClaims
|
||||
@ -592,3 +620,22 @@ func getUserName(
|
||||
|
||||
return userName, nil
|
||||
}
|
||||
|
||||
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
|
||||
val, err := util.GenerateRandomStringURLSafe(64)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
|
||||
c := &http.Cookie{
|
||||
Path: "/oidc/callback",
|
||||
Name: name,
|
||||
Value: val,
|
||||
MaxAge: int(time.Hour.Seconds()),
|
||||
Secure: r.TLS != nil,
|
||||
HttpOnly: true,
|
||||
}
|
||||
http.SetCookie(w, c)
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
@ -747,6 +749,24 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc
|
||||
}, nil
|
||||
}
|
||||
|
||||
type LoggingRoundTripper struct{}
|
||||
|
||||
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
noTls := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||
}
|
||||
resp, err := noTls.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("---")
|
||||
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
|
||||
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
userStr, loginServer string,
|
||||
) error {
|
||||
@ -758,35 +778,39 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
log.Printf("running tailscale up for user %s", userStr)
|
||||
if user, ok := s.users[userStr]; ok {
|
||||
for _, client := range user.Clients {
|
||||
c := client
|
||||
tsc := client
|
||||
user.joinWaitGroup.Go(func() error {
|
||||
loginURL, err := c.LoginWithURL(loginServer)
|
||||
loginURL, err := tsc.LoginWithURL(loginServer)
|
||||
if err != nil {
|
||||
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
|
||||
log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
|
||||
}
|
||||
|
||||
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
|
||||
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname())
|
||||
loginURL.Scheme = "http"
|
||||
|
||||
if len(headscale.GetCert()) > 0 {
|
||||
loginURL.Scheme = "https"
|
||||
}
|
||||
|
||||
insecureTransport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||
httptest.NewRecorder()
|
||||
hc := &http.Client{
|
||||
Transport: LoggingRoundTripper{},
|
||||
}
|
||||
hc.Jar, err = cookiejar.New(nil)
|
||||
if err != nil {
|
||||
log.Printf("failed to create cookie jar: %s", err)
|
||||
}
|
||||
|
||||
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
||||
log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String())
|
||||
|
||||
log.Printf("%s logging in with url", c.Hostname())
|
||||
httpClient := &http.Client{Transport: insecureTransport}
|
||||
log.Printf("%s logging in with url", tsc.Hostname())
|
||||
ctx := context.Background()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||
resp, err := httpClient.Do(req)
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
log.Printf(
|
||||
"%s failed to login using url %s: %s",
|
||||
c.Hostname(),
|
||||
tsc.Hostname(),
|
||||
loginURL,
|
||||
err,
|
||||
)
|
||||
@ -794,8 +818,10 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
||||
log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Printf("body: %s", body)
|
||||
|
||||
@ -806,12 +832,12 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("%s failed to read response body: %s", c.Hostname(), err)
|
||||
log.Printf("%s failed to read response body: %s", tsc.Hostname(), err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Finished request for %s to join tailnet", c.Hostname())
|
||||
log.Printf("Finished request for %s to join tailnet", tsc.Hostname())
|
||||
return nil
|
||||
})
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user