From 5cd15c36568acc562ca6639176b38969c64308a7 Mon Sep 17 00:00:00 2001 From: Andrey Bobelev Date: Tue, 4 Nov 2025 07:18:51 +0200 Subject: [PATCH] fix: make state cookies valid when client uses multiple login URLs On Windows, if the user clicks the Tailscale icon in the system tray, it opens a login URL in the browser. When the login URL is opened, `state/nonce` cookies are set for that particular URL. If the user clicks the icon again, a new login URL is opened in the browser, and new cookies are set. If the user proceeds with auth in the first tab, the redirect results in a "state did not match" error. This patch ensures that each opened login URL sets an individual cookie that remains valid on the `/oidc/callback` page. `TestOIDCMultipleOpenedLoginUrls` illustrates and tests this behavior. --- .github/workflows/test-integration.yaml | 1 + hscontrol/oidc.go | 17 ++- integration/auth_oidc_test.go | 113 +++++++++++++++ integration/scenario.go | 182 +++++++++++++++++++++--- 4 files changed, 287 insertions(+), 26 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 735c50bf..fe934aab 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -38,6 +38,7 @@ jobs: - TestOIDCAuthenticationWithPKCE - TestOIDCReloginSameNodeNewUser - TestOIDCFollowUpUrl + - TestOIDCMultipleOpenedLoginUrls - TestOIDCReloginSameNodeSameUser - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndReloginSameUser diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 84d00712..7c7895c6 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -213,7 +213,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - cookieState, err := req.Cookie("state") + stateCookieName := getCookieName("state", state) + cookieState, err := req.Cookie(stateCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) return @@ -235,8 +236,13 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( httpError(writer, err) return } + if idToken.Nonce == "" { + httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) + return + } - nonce, err := req.Cookie("nonce") + nonceCookieName := getCookieName("nonce", idToken.Nonce) + nonce, err := req.Cookie(nonceCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return @@ -584,6 +590,11 @@ func renderOIDCCallbackTemplate( return &content, nil } +// getCookieName generates a unique cookie name based on a cookie value. +func getCookieName(baseName, value string) string { + return fmt.Sprintf("%s_%s", baseName, value[:6]) +} + func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) { val, err := util.GenerateRandomStringURLSafe(64) if err != nil { @@ -592,7 +603,7 @@ func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, c := &http.Cookie{ Path: "/oidc/callback", - Name: name, + Name: getCookieName(name, val), Value: val, MaxAge: int(time.Hour.Seconds()), Secure: r.TLS != nil, diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 0a0b5b95..eebb8165 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -953,6 +953,119 @@ func TestOIDCFollowUpUrl(t *testing.T) { }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login") } +// TestOIDCMultipleOpenedLoginUrls tests the scenario: +// - client (mostly Windows) opens multiple browser tabs with different login URLs +// - client performs auth on the first opened browser tab +// +// This test makes sure that cookies are still valid for the first browser tab. +func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { + IntegrationSkip(t) + + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + listUsers, err := headscale.ListUsers() + require.NoError(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + u1, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + u2, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // make sure login URLs are different + require.NotEqual(t, u1.String(), u2.String()) + + loginClient, err := newLoginHTTPClient(ts.Hostname()) + require.NoError(t, err) + + // open the first login URL "in browser" + _, redirect1, err := doLoginURLWithClient(ts.Hostname(), u1, loginClient, false) + require.NoError(t, err) + // open the second login URL "in browser" + _, redirect2, err := doLoginURLWithClient(ts.Hostname(), u2, loginClient, false) + require.NoError(t, err) + + // two valid redirects with different state/nonce params + require.NotEqual(t, redirect1.String(), redirect2.String()) + + // complete auth with the first opened "browser tab" + _, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) + require.NoError(t, err) + + listUsers, err = headscale.ListUsers() + require.NoError(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + assert.EventuallyWithT( + t, func(c *assert.CollectT) { + listNodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, 1) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login", + ) +} + // TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client // authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user. // diff --git a/integration/scenario.go b/integration/scenario.go index aa844a7e..c3b5549c 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -860,47 +860,183 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) } -// doLoginURL visits the given login URL and returns the body as a -// string. -func doLoginURL(hostname string, loginURL *url.URL) (string, error) { - log.Printf("%s login url: %s\n", hostname, loginURL.String()) +type debugJar struct { + inner *cookiejar.Jar + mu sync.RWMutex + store map[string]map[string]map[string]*http.Cookie // domain -> path -> name -> cookie +} - var err error +func newDebugJar() (*debugJar, error) { + jar, err := cookiejar.New(nil) + if err != nil { + return nil, err + } + return &debugJar{ + inner: jar, + store: make(map[string]map[string]map[string]*http.Cookie), + }, nil +} + +func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.inner.SetCookies(u, cookies) + + j.mu.Lock() + defer j.mu.Unlock() + + for _, c := range cookies { + if c == nil || c.Name == "" { + continue + } + domain := c.Domain + if domain == "" { + domain = u.Hostname() + } + path := c.Path + if path == "" { + path = "/" + } + if _, ok := j.store[domain]; !ok { + j.store[domain] = make(map[string]map[string]*http.Cookie) + } + if _, ok := j.store[domain][path]; !ok { + j.store[domain][path] = make(map[string]*http.Cookie) + } + j.store[domain][path][c.Name] = copyCookie(c) + } +} + +func (j *debugJar) Cookies(u *url.URL) []*http.Cookie { + return j.inner.Cookies(u) +} + +func (j *debugJar) Dump(w io.Writer) { + j.mu.RLock() + defer j.mu.RUnlock() + + for domain, paths := range j.store { + fmt.Fprintf(w, "Domain: %s\n", domain) + for path, byName := range paths { + fmt.Fprintf(w, " Path: %s\n", path) + for _, c := range byName { + fmt.Fprintf( + w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n", + c.Name, c.Value, c.Expires, c.Secure, c.HttpOnly, c.SameSite, + ) + } + } + } +} + +func copyCookie(c *http.Cookie) *http.Cookie { + cc := *c + return &cc +} + +func newLoginHTTPClient(hostname string) (*http.Client, error) { hc := &http.Client{ Transport: LoggingRoundTripper{Hostname: hostname}, } - hc.Jar, err = cookiejar.New(nil) + + jar, err := newDebugJar() if err != nil { - return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err) + return nil, fmt.Errorf("%s failed to create cookiejar: %w", hostname, err) + } + + hc.Jar = jar + + return hc, nil +} + +// doLoginURL visits the given login URL and returns the body as a string. +func doLoginURL(hostname string, loginURL *url.URL) (string, error) { + log.Printf("%s login url: %s\n", hostname, loginURL.String()) + + hc, err := newLoginHTTPClient(hostname) + if err != nil { + return "", err + } + + body, _, err := doLoginURLWithClient(hostname, loginURL, hc, true) + if err != nil { + return "", err + } + + return body, nil +} + +// doLoginURLWithClient performs the login request using the provided HTTP client. +// When followRedirects is false, it will return the first redirect without following it. +func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, followRedirects bool) ( + string, + *url.URL, + error, +) { + if hc == nil { + return "", nil, fmt.Errorf("%s http client is nil", hostname) + } + + if loginURL == nil { + return "", nil, fmt.Errorf("%s login url is nil", hostname) } log.Printf("%s logging in with url: %s", hostname, loginURL.String()) ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + if err != nil { + return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err) + } + + originalRedirect := hc.CheckRedirect + if !followRedirects { + hc.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + defer func() { + hc.CheckRedirect = originalRedirect + }() + resp, err := hc.Do(req) if err != nil { - return "", fmt.Errorf("%s failed to send http request: %w", hostname, err) + return "", nil, fmt.Errorf("%s failed to send http request: %w", hostname, err) } - - log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - log.Printf("body: %s", body) - - return "", fmt.Errorf("%s response code of login request was %w", hostname, err) - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - log.Printf("%s failed to read response body: %s", hostname, err) + return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err) + } + body := string(bodyBytes) - return "", fmt.Errorf("%s failed to read response body: %w", hostname, err) + var redirectURL *url.URL + if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest { + redirectURL, err = resp.Location() + if err != nil { + return body, nil, fmt.Errorf("%s failed to resolve redirect location: %w", hostname, err) + } } - return string(body), nil + if followRedirects && resp.StatusCode != http.StatusOK { + log.Printf("body: %s", body) + + return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + } + + if resp.StatusCode >= http.StatusBadRequest { + log.Printf("body: %s", body) + + return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + } + + if hc.Jar != nil { + if jar, ok := hc.Jar.(*debugJar); ok { + jar.Dump(os.Stdout) + } else { + log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) + } + } + + return body, redirectURL, nil } var errParseAuthPage = errors.New("failed to parse auth page")