mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-20 17:56:02 -05:00
fix: make state cookies valid when client uses multiple login URLs
On Windows, if the user clicks the Tailscale icon in the system tray, it opens a login URL in the browser. When the login URL is opened, `state/nonce` cookies are set for that particular URL. If the user clicks the icon again, a new login URL is opened in the browser, and new cookies are set. If the user proceeds with auth in the first tab, the redirect results in a "state did not match" error. This patch ensures that each opened login URL sets an individual cookie that remains valid on the `/oidc/callback` page. `TestOIDCMultipleOpenedLoginUrls` illustrates and tests this behavior.
This commit is contained in:
committed by
Kristoffer Dalby
parent
2024219bd1
commit
5cd15c3656
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
@@ -38,6 +38,7 @@ jobs:
|
|||||||
- TestOIDCAuthenticationWithPKCE
|
- TestOIDCAuthenticationWithPKCE
|
||||||
- TestOIDCReloginSameNodeNewUser
|
- TestOIDCReloginSameNodeNewUser
|
||||||
- TestOIDCFollowUpUrl
|
- TestOIDCFollowUpUrl
|
||||||
|
- TestOIDCMultipleOpenedLoginUrls
|
||||||
- TestOIDCReloginSameNodeSameUser
|
- TestOIDCReloginSameNodeSameUser
|
||||||
- TestAuthWebFlowAuthenticationPingAll
|
- TestAuthWebFlowAuthenticationPingAll
|
||||||
- TestAuthWebFlowLogoutAndReloginSameUser
|
- TestAuthWebFlowLogoutAndReloginSameUser
|
||||||
|
|||||||
@@ -213,7 +213,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieState, err := req.Cookie("state")
|
stateCookieName := getCookieName("state", state)
|
||||||
|
cookieState, err := req.Cookie(stateCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||||
return
|
return
|
||||||
@@ -235,8 +236,13 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if idToken.Nonce == "" {
|
||||||
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
nonce, err := req.Cookie("nonce")
|
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
||||||
|
nonce, err := req.Cookie(nonceCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||||
return
|
return
|
||||||
@@ -584,6 +590,11 @@ func renderOIDCCallbackTemplate(
|
|||||||
return &content, nil
|
return &content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getCookieName generates a unique cookie name based on a cookie value.
|
||||||
|
func getCookieName(baseName, value string) string {
|
||||||
|
return fmt.Sprintf("%s_%s", baseName, value[:6])
|
||||||
|
}
|
||||||
|
|
||||||
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
|
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
|
||||||
val, err := util.GenerateRandomStringURLSafe(64)
|
val, err := util.GenerateRandomStringURLSafe(64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -592,7 +603,7 @@ func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string,
|
|||||||
|
|
||||||
c := &http.Cookie{
|
c := &http.Cookie{
|
||||||
Path: "/oidc/callback",
|
Path: "/oidc/callback",
|
||||||
Name: name,
|
Name: getCookieName(name, val),
|
||||||
Value: val,
|
Value: val,
|
||||||
MaxAge: int(time.Hour.Seconds()),
|
MaxAge: int(time.Hour.Seconds()),
|
||||||
Secure: r.TLS != nil,
|
Secure: r.TLS != nil,
|
||||||
|
|||||||
@@ -953,6 +953,119 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
|||||||
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
|
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOIDCMultipleOpenedLoginUrls tests the scenario:
|
||||||
|
// - client (mostly Windows) opens multiple browser tabs with different login URLs
|
||||||
|
// - client performs auth on the first opened browser tab
|
||||||
|
//
|
||||||
|
// This test makes sure that cookies are still valid for the first browser tab.
|
||||||
|
func TestOIDCMultipleOpenedLoginUrls(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
scenario, err := NewScenario(
|
||||||
|
ScenarioSpec{
|
||||||
|
OIDCUsers: []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||||
|
nil,
|
||||||
|
hsic.WithTestName("oidcauthrelog"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||||
|
hsic.WithEmbeddedDERPServerOnly(),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listUsers, err := headscale.ListUsers()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, listUsers)
|
||||||
|
|
||||||
|
ts, err := scenario.CreateTailscaleNode(
|
||||||
|
"unstable",
|
||||||
|
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u1, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u2, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// make sure login URLs are different
|
||||||
|
require.NotEqual(t, u1.String(), u2.String())
|
||||||
|
|
||||||
|
loginClient, err := newLoginHTTPClient(ts.Hostname())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// open the first login URL "in browser"
|
||||||
|
_, redirect1, err := doLoginURLWithClient(ts.Hostname(), u1, loginClient, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// open the second login URL "in browser"
|
||||||
|
_, redirect2, err := doLoginURLWithClient(ts.Hostname(), u2, loginClient, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// two valid redirects with different state/nonce params
|
||||||
|
require.NotEqual(t, redirect1.String(), redirect2.String())
|
||||||
|
|
||||||
|
// complete auth with the first opened "browser tab"
|
||||||
|
_, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listUsers, err = headscale.ListUsers()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, listUsers, 1)
|
||||||
|
|
||||||
|
wantUsers := []*v1.User{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(
|
||||||
|
listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(
|
||||||
|
wantUsers,
|
||||||
|
listUsers,
|
||||||
|
cmpopts.IgnoreUnexported(v1.User{}),
|
||||||
|
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
|
||||||
|
); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.EventuallyWithT(
|
||||||
|
t, func(c *assert.CollectT) {
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assert.NoError(c, err)
|
||||||
|
assert.Len(c, listNodes, 1)
|
||||||
|
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
|
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
|
||||||
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
|
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -860,47 +860,183 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
|
|||||||
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
// doLoginURL visits the given login URL and returns the body as a
|
type debugJar struct {
|
||||||
// string.
|
inner *cookiejar.Jar
|
||||||
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
|
mu sync.RWMutex
|
||||||
log.Printf("%s login url: %s\n", hostname, loginURL.String())
|
store map[string]map[string]map[string]*http.Cookie // domain -> path -> name -> cookie
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
func newDebugJar() (*debugJar, error) {
|
||||||
|
jar, err := cookiejar.New(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &debugJar{
|
||||||
|
inner: jar,
|
||||||
|
store: make(map[string]map[string]map[string]*http.Cookie),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
|
||||||
|
j.inner.SetCookies(u, cookies)
|
||||||
|
|
||||||
|
j.mu.Lock()
|
||||||
|
defer j.mu.Unlock()
|
||||||
|
|
||||||
|
for _, c := range cookies {
|
||||||
|
if c == nil || c.Name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domain := c.Domain
|
||||||
|
if domain == "" {
|
||||||
|
domain = u.Hostname()
|
||||||
|
}
|
||||||
|
path := c.Path
|
||||||
|
if path == "" {
|
||||||
|
path = "/"
|
||||||
|
}
|
||||||
|
if _, ok := j.store[domain]; !ok {
|
||||||
|
j.store[domain] = make(map[string]map[string]*http.Cookie)
|
||||||
|
}
|
||||||
|
if _, ok := j.store[domain][path]; !ok {
|
||||||
|
j.store[domain][path] = make(map[string]*http.Cookie)
|
||||||
|
}
|
||||||
|
j.store[domain][path][c.Name] = copyCookie(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *debugJar) Cookies(u *url.URL) []*http.Cookie {
|
||||||
|
return j.inner.Cookies(u)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *debugJar) Dump(w io.Writer) {
|
||||||
|
j.mu.RLock()
|
||||||
|
defer j.mu.RUnlock()
|
||||||
|
|
||||||
|
for domain, paths := range j.store {
|
||||||
|
fmt.Fprintf(w, "Domain: %s\n", domain)
|
||||||
|
for path, byName := range paths {
|
||||||
|
fmt.Fprintf(w, " Path: %s\n", path)
|
||||||
|
for _, c := range byName {
|
||||||
|
fmt.Fprintf(
|
||||||
|
w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n",
|
||||||
|
c.Name, c.Value, c.Expires, c.Secure, c.HttpOnly, c.SameSite,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyCookie(c *http.Cookie) *http.Cookie {
|
||||||
|
cc := *c
|
||||||
|
return &cc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLoginHTTPClient(hostname string) (*http.Client, error) {
|
||||||
hc := &http.Client{
|
hc := &http.Client{
|
||||||
Transport: LoggingRoundTripper{Hostname: hostname},
|
Transport: LoggingRoundTripper{Hostname: hostname},
|
||||||
}
|
}
|
||||||
hc.Jar, err = cookiejar.New(nil)
|
|
||||||
|
jar, err := newDebugJar()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
|
return nil, fmt.Errorf("%s failed to create cookiejar: %w", hostname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.Jar = jar
|
||||||
|
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doLoginURL visits the given login URL and returns the body as a string.
|
||||||
|
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
|
||||||
|
log.Printf("%s login url: %s\n", hostname, loginURL.String())
|
||||||
|
|
||||||
|
hc, err := newLoginHTTPClient(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _, err := doLoginURLWithClient(hostname, loginURL, hc, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doLoginURLWithClient performs the login request using the provided HTTP client.
|
||||||
|
// When followRedirects is false, it will return the first redirect without following it.
|
||||||
|
func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, followRedirects bool) (
|
||||||
|
string,
|
||||||
|
*url.URL,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
if hc == nil {
|
||||||
|
return "", nil, fmt.Errorf("%s http client is nil", hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginURL == nil {
|
||||||
|
return "", nil, fmt.Errorf("%s login url is nil", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
|
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
originalRedirect := hc.CheckRedirect
|
||||||
|
if !followRedirects {
|
||||||
|
hc.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
hc.CheckRedirect = originalRedirect
|
||||||
|
}()
|
||||||
|
|
||||||
resp, err := hc.Do(req)
|
resp, err := hc.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
|
return "", nil, fmt.Errorf("%s failed to send http request: %w", hostname, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
log.Printf("body: %s", body)
|
|
||||||
|
|
||||||
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("%s failed to read response body: %s", hostname, err)
|
return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
|
||||||
|
}
|
||||||
|
body := string(bodyBytes)
|
||||||
|
|
||||||
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
|
var redirectURL *url.URL
|
||||||
|
if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest {
|
||||||
|
redirectURL, err = resp.Location()
|
||||||
|
if err != nil {
|
||||||
|
return body, nil, fmt.Errorf("%s failed to resolve redirect location: %w", hostname, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(body), nil
|
if followRedirects && resp.StatusCode != http.StatusOK {
|
||||||
|
log.Printf("body: %s", body)
|
||||||
|
|
||||||
|
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= http.StatusBadRequest {
|
||||||
|
log.Printf("body: %s", body)
|
||||||
|
|
||||||
|
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hc.Jar != nil {
|
||||||
|
if jar, ok := hc.Jar.(*debugJar); ok {
|
||||||
|
jar.Dump(os.Stdout)
|
||||||
|
} else {
|
||||||
|
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, redirectURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errParseAuthPage = errors.New("failed to parse auth page")
|
var errParseAuthPage = errors.New("failed to parse auth page")
|
||||||
|
|||||||
Reference in New Issue
Block a user