diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 74f51ddc..bab0061e 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -695,6 +695,29 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + // Fix the provider identifier for users that have a double slash in the + // provider identifier. + { + ID: "202505141324", + Migrate: func(tx *gorm.DB) error { + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("listing users: %w", err) + } + + for _, user := range users { + user.ProviderIdentifier.String = types.CleanIdentifier(user.ProviderIdentifier.String) + + err := tx.Save(user).Error + if err != nil { + return fmt.Errorf("saving user: %w", err) + } + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 471cb1e5..6cd2c41a 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -194,13 +194,110 @@ type OIDCClaims struct { Username string `json:"preferred_username,omitempty"` } +// Identifier returns a unique identifier string combining the Iss and Sub claims. +// The format depends on whether Iss is a URL or not: +// - For URLs: Joins the URL and sub path (e.g., "https://example.com/sub") +// - For non-URLs: Joins with a slash (e.g., "oidc/sub") +// - For empty Iss: Returns just "sub" +// - For empty Sub: Returns just the Issuer +// - For both empty: Returns empty string +// +// The result is cleaned using CleanIdentifier() to ensure consistent formatting. func (c *OIDCClaims) Identifier() string { - if strings.HasPrefix(c.Iss, "http") { - if i, err := url.JoinPath(c.Iss, c.Sub); err == nil { - return i + // Handle empty components special cases + if c.Iss == "" && c.Sub == "" { + return "" + } + if c.Iss == "" { + return CleanIdentifier(c.Sub) + } + if c.Sub == "" { + return CleanIdentifier(c.Iss) + } + + // We'll use the raw values and let CleanIdentifier handle all the whitespace + issuer := c.Iss + subject := c.Sub + + var result string + // Try to parse as URL to handle URL joining correctly + if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { + // For URLs, use proper URL path joining + if joined, err := url.JoinPath(issuer, subject); err == nil { + result = joined } } - return c.Iss + "/" + c.Sub + + // If URL joining failed or issuer wasn't a URL, do simple string join + if result == "" { + // Default case: simple string joining with slash + issuer = strings.TrimSuffix(issuer, "/") + subject = strings.TrimPrefix(subject, "/") + result = issuer + "/" + subject + } + + // Clean the result and return it + return CleanIdentifier(result) +} + +// CleanIdentifier cleans a potentially malformed identifier by removing double slashes +// while preserving protocol specifications like http://. This function will: +// - Trim all whitespace from the beginning and end of the identifier +// - Remove whitespace within path segments +// - Preserve the scheme (http://, https://, etc.) for URLs +// - Remove any duplicate slashes in the path +// - Remove empty path segments +// - For non-URL identifiers, it joins non-empty segments with a single slash +// - Returns empty string for identifiers with only slashes +// - Normalize URL schemes to lowercase +func CleanIdentifier(identifier string) string { + if identifier == "" { + return identifier + } + + // Trim leading/trailing whitespace + identifier = strings.TrimSpace(identifier) + + // Handle URLs with schemes + u, err := url.Parse(identifier) + if err == nil && u.Scheme != "" { + // Clean path by removing empty segments and whitespace within segments + parts := strings.FieldsFunc(u.Path, func(c rune) bool { return c == '/' }) + for i, part := range parts { + parts[i] = strings.TrimSpace(part) + } + // Remove empty parts after trimming + cleanParts := make([]string, 0, len(parts)) + for _, part := range parts { + if part != "" { + cleanParts = append(cleanParts, part) + } + } + + if len(cleanParts) == 0 { + u.Path = "" + } else { + u.Path = "/" + strings.Join(cleanParts, "/") + } + // Ensure scheme is lowercase + u.Scheme = strings.ToLower(u.Scheme) + return u.String() + } + + // Handle non-URL identifiers + parts := strings.FieldsFunc(identifier, func(c rune) bool { return c == '/' }) + // Clean whitespace from each part + cleanParts := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + cleanParts = append(cleanParts, trimmed) + } + } + if len(cleanParts) == 0 { + return "" + } + return strings.Join(cleanParts, "/") } type OIDCUserInfo struct { @@ -231,7 +328,13 @@ func (u *User) FromClaim(claims *OIDCClaims) { } } - u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} + // Get provider identifier + identifier := claims.Identifier() + // Ensure provider identifier always has a leading slash for backward compatibility + if claims.Iss == "" && !strings.HasPrefix(identifier, "/") { + identifier = "/" + identifier + } + u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true} u.DisplayName = claims.Name u.ProfilePicURL = claims.ProfilePictureURL u.Provider = util.RegisterMethodOIDC diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index 12029701..f36489a3 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" ) func TestUnmarshallOIDCClaims(t *testing.T) { @@ -76,6 +77,218 @@ func TestUnmarshallOIDCClaims(t *testing.T) { } } +func TestOIDCClaimsIdentifier(t *testing.T) { + tests := []struct { + name string + iss string + sub string + expected string + }{ + { + name: "standard URL with trailing slash", + iss: "https://oidc.example.com/", + sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + }, + { + name: "standard URL without trailing slash", + iss: "https://oidc.example.com", + sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + }, + { + name: "standard URL with uppercase protocol", + iss: "HTTPS://oidc.example.com/", + sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + }, + { + name: "standard URL with path and trailing slash", + iss: "https://login.microsoftonline.com/v2.0/", + sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "standard URL with path without trailing slash", + iss: "https://login.microsoftonline.com/v2.0", + sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "non-URL identifier with slash", + iss: "oidc", + sub: "sub", + expected: "oidc/sub", + }, + { + name: "non-URL identifier with trailing slash", + iss: "oidc/", + sub: "sub", + expected: "oidc/sub", + }, + { + name: "subject with slash", + iss: "oidc/", + sub: "sub/", + expected: "oidc/sub", + }, + { + name: "whitespace", + iss: " oidc/ ", + sub: " sub ", + expected: "oidc/sub", + }, + { + name: "newline", + iss: "\noidc/\n", + sub: "\nsub\n", + expected: "oidc/sub", + }, + { + name: "tab", + iss: "\toidc/\t", + sub: "\tsub\t", + expected: "oidc/sub", + }, + { + name: "empty issuer", + iss: "", + sub: "sub", + expected: "sub", + }, + { + name: "empty subject", + iss: "https://oidc.example.com", + sub: "", + expected: "https://oidc.example.com", + }, + { + name: "both empty", + iss: "", + sub: "", + expected: "", + }, + { + name: "URL with double slash", + iss: "https://login.microsoftonline.com//v2.0", + sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "FTP URL protocol", + iss: "ftp://example.com/directory", + sub: "resource", + expected: "ftp://example.com/directory/resource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := OIDCClaims{ + Iss: tt.iss, + Sub: tt.sub, + } + result := claims.Identifier() + assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { + t.Errorf("Identifier() mismatch (-want +got):\n%s", diff) + } + + // Now clean the identifier and verify it's still the same + cleaned := CleanIdentifier(result) + + // Double-check with cmp.Diff for better error messages + if diff := cmp.Diff(tt.expected, cleaned); diff != "" { + t.Errorf("CleanIdentifier(Identifier()) mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestCleanIdentifier(t *testing.T) { + tests := []struct { + name string + identifier string + expected string + }{ + { + name: "empty identifier", + identifier: "", + expected: "", + }, + { + name: "simple identifier", + identifier: "oidc/sub", + expected: "oidc/sub", + }, + { + name: "double slashes in the middle", + identifier: "oidc//sub", + expected: "oidc/sub", + }, + { + name: "trailing slash", + identifier: "oidc/sub/", + expected: "oidc/sub", + }, + { + name: "multiple double slashes", + identifier: "oidc//sub///id//", + expected: "oidc/sub/id", + }, + { + name: "HTTP URL with proper scheme", + identifier: "http://example.com/path", + expected: "http://example.com/path", + }, + { + name: "HTTP URL with double slashes in path", + identifier: "http://example.com//path///resource", + expected: "http://example.com/path/resource", + }, + { + name: "HTTPS URL with empty segments", + identifier: "https://example.com///path//", + expected: "https://example.com/path", + }, + { + name: "URL with double slashes in domain", + identifier: "https://login.microsoftonline.com//v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "FTP URL with double slashes", + identifier: "ftp://example.com//resource//", + expected: "ftp://example.com/resource", + }, + { + name: "Just slashes", + identifier: "///", + expected: "", + }, + { + name: "Leading slash without URL", + identifier: "/path//to///resource", + expected: "path/to/resource", + }, + { + name: "Non-standard protocol", + identifier: "ldap://example.org//path//to//resource", + expected: "ldap://example.org/path/to/resource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CleanIdentifier(tt.identifier) + assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { + t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestOIDCClaimsJSONToUser(t *testing.T) { tests := []struct { name string