fix: make state cookies valid when client uses multiple login URLs

On Windows, if the user clicks the Tailscale icon in the system tray,
it opens a login URL in the browser.

When the login URL is opened, `state/nonce` cookies are set for that particular URL.

If the user clicks the icon again, a new login URL is opened in the browser,
and new cookies are set.

If the user proceeds with auth in the first tab,
the redirect results in a "state did not match" error.

This patch ensures that each opened login URL sets an individual cookie
that remains valid on the `/oidc/callback` page.

`TestOIDCMultipleOpenedLoginUrls` illustrates and tests this behavior.
This commit is contained in:
Andrey Bobelev
2025-11-04 07:18:51 +02:00
committed by Kristoffer Dalby
parent 2024219bd1
commit 5cd15c3656
4 changed files with 287 additions and 26 deletions

View File

@@ -38,6 +38,7 @@ jobs:
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCFollowUpUrl
- TestOIDCMultipleOpenedLoginUrls
- TestOIDCReloginSameNodeSameUser
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndReloginSameUser

View File

@@ -213,7 +213,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
cookieState, err := req.Cookie("state")
stateCookieName := getCookieName("state", state)
cookieState, err := req.Cookie(stateCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
return
@@ -235,8 +236,13 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
httpError(writer, err)
return
}
if idToken.Nonce == "" {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
return
}
nonce, err := req.Cookie("nonce")
nonceCookieName := getCookieName("nonce", idToken.Nonce)
nonce, err := req.Cookie(nonceCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
return
@@ -584,6 +590,11 @@ func renderOIDCCallbackTemplate(
return &content, nil
}
// getCookieName generates a unique cookie name based on a cookie value.
func getCookieName(baseName, value string) string {
return fmt.Sprintf("%s_%s", baseName, value[:6])
}
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
val, err := util.GenerateRandomStringURLSafe(64)
if err != nil {
@@ -592,7 +603,7 @@ func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string,
c := &http.Cookie{
Path: "/oidc/callback",
Name: name,
Name: getCookieName(name, val),
Value: val,
MaxAge: int(time.Hour.Seconds()),
Secure: r.TLS != nil,

View File

@@ -953,6 +953,119 @@ func TestOIDCFollowUpUrl(t *testing.T) {
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
}
// TestOIDCMultipleOpenedLoginUrls tests the scenario:
// - client (mostly Windows) opens multiple browser tabs with different login URLs
// - client performs auth on the first opened browser tab
//
// This test makes sure that cookies are still valid for the first browser tab.
func TestOIDCMultipleOpenedLoginUrls(t *testing.T) {
IntegrationSkip(t)
scenario, err := NewScenario(
ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
},
)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
listUsers, err := headscale.ListUsers()
require.NoError(t, err)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
require.NoError(t, err)
u1, err := ts.LoginWithURL(headscale.GetEndpoint())
require.NoError(t, err)
u2, err := ts.LoginWithURL(headscale.GetEndpoint())
require.NoError(t, err)
// make sure login URLs are different
require.NotEqual(t, u1.String(), u2.String())
loginClient, err := newLoginHTTPClient(ts.Hostname())
require.NoError(t, err)
// open the first login URL "in browser"
_, redirect1, err := doLoginURLWithClient(ts.Hostname(), u1, loginClient, false)
require.NoError(t, err)
// open the second login URL "in browser"
_, redirect2, err := doLoginURLWithClient(ts.Hostname(), u2, loginClient, false)
require.NoError(t, err)
// two valid redirects with different state/nonce params
require.NotEqual(t, redirect1.String(), redirect2.String())
// complete auth with the first opened "browser tab"
_, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true)
require.NoError(t, err)
listUsers, err = headscale.ListUsers()
require.NoError(t, err)
assert.Len(t, listUsers, 1)
wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "user1@headscale.net",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}
sort.Slice(
listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
},
)
if diff := cmp.Diff(
wantUsers,
listUsers,
cmpopts.IgnoreUnexported(v1.User{}),
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}
assert.EventuallyWithT(
t, func(c *assert.CollectT) {
listNodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, 1)
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login",
)
}
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
//

View File

@@ -860,47 +860,183 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}
// doLoginURL visits the given login URL and returns the body as a
// string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
type debugJar struct {
inner *cookiejar.Jar
mu sync.RWMutex
store map[string]map[string]map[string]*http.Cookie // domain -> path -> name -> cookie
}
var err error
func newDebugJar() (*debugJar, error) {
jar, err := cookiejar.New(nil)
if err != nil {
return nil, err
}
return &debugJar{
inner: jar,
store: make(map[string]map[string]map[string]*http.Cookie),
}, nil
}
func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
j.inner.SetCookies(u, cookies)
j.mu.Lock()
defer j.mu.Unlock()
for _, c := range cookies {
if c == nil || c.Name == "" {
continue
}
domain := c.Domain
if domain == "" {
domain = u.Hostname()
}
path := c.Path
if path == "" {
path = "/"
}
if _, ok := j.store[domain]; !ok {
j.store[domain] = make(map[string]map[string]*http.Cookie)
}
if _, ok := j.store[domain][path]; !ok {
j.store[domain][path] = make(map[string]*http.Cookie)
}
j.store[domain][path][c.Name] = copyCookie(c)
}
}
func (j *debugJar) Cookies(u *url.URL) []*http.Cookie {
return j.inner.Cookies(u)
}
func (j *debugJar) Dump(w io.Writer) {
j.mu.RLock()
defer j.mu.RUnlock()
for domain, paths := range j.store {
fmt.Fprintf(w, "Domain: %s\n", domain)
for path, byName := range paths {
fmt.Fprintf(w, " Path: %s\n", path)
for _, c := range byName {
fmt.Fprintf(
w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n",
c.Name, c.Value, c.Expires, c.Secure, c.HttpOnly, c.SameSite,
)
}
}
}
}
func copyCookie(c *http.Cookie) *http.Cookie {
cc := *c
return &cc
}
func newLoginHTTPClient(hostname string) (*http.Client, error) {
hc := &http.Client{
Transport: LoggingRoundTripper{Hostname: hostname},
}
hc.Jar, err = cookiejar.New(nil)
jar, err := newDebugJar()
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
return nil, fmt.Errorf("%s failed to create cookiejar: %w", hostname, err)
}
hc.Jar = jar
return hc, nil
}
// doLoginURL visits the given login URL and returns the body as a string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
hc, err := newLoginHTTPClient(hostname)
if err != nil {
return "", err
}
body, _, err := doLoginURLWithClient(hostname, loginURL, hc, true)
if err != nil {
return "", err
}
return body, nil
}
// doLoginURLWithClient performs the login request using the provided HTTP client.
// When followRedirects is false, it will return the first redirect without following it.
func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, followRedirects bool) (
string,
*url.URL,
error,
) {
if hc == nil {
return "", nil, fmt.Errorf("%s http client is nil", hostname)
}
if loginURL == nil {
return "", nil, fmt.Errorf("%s login url is nil", hostname)
}
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
if err != nil {
return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err)
}
originalRedirect := hc.CheckRedirect
if !followRedirects {
hc.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
defer func() {
hc.CheckRedirect = originalRedirect
}()
resp, err := hc.Do(req)
if err != nil {
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
return "", nil, fmt.Errorf("%s failed to send http request: %w", hostname, err)
}
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body)
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("%s failed to read response body: %s", hostname, err)
return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
body := string(bodyBytes)
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
var redirectURL *url.URL
if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest {
redirectURL, err = resp.Location()
if err != nil {
return body, nil, fmt.Errorf("%s failed to resolve redirect location: %w", hostname, err)
}
}
return string(body), nil
if followRedirects && resp.StatusCode != http.StatusOK {
log.Printf("body: %s", body)
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
}
if resp.StatusCode >= http.StatusBadRequest {
log.Printf("body: %s", body)
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
}
if hc.Jar != nil {
if jar, ok := hc.Jar.(*debugJar); ok {
jar.Dump(os.Stdout)
} else {
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
}
}
return body, redirectURL, nil
}
var errParseAuthPage = errors.New("failed to parse auth page")