mirror of
https://github.com/juanfont/headscale.git
synced 2025-05-21 09:33:52 -04:00
users: harden, test, and add cleaner of identifier (#2593)
* users: harden, test, and add cleaner of identifier Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * db: migrate badly joined provider identifiers Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
d7a503a34e
commit
2dc2f3b3f0
@ -695,6 +695,29 @@ AND auth_key_id NOT IN (
|
||||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
// Fix the provider identifier for users that have a double slash in the
|
||||
// provider identifier.
|
||||
{
|
||||
ID: "202505141324",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
users, err := ListUsers(tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing users: %w", err)
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
user.ProviderIdentifier.String = types.CleanIdentifier(user.ProviderIdentifier.String)
|
||||
|
||||
err := tx.Save(user).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -194,13 +194,110 @@ type OIDCClaims struct {
|
||||
Username string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
// Identifier returns a unique identifier string combining the Iss and Sub claims.
|
||||
// The format depends on whether Iss is a URL or not:
|
||||
// - For URLs: Joins the URL and sub path (e.g., "https://example.com/sub")
|
||||
// - For non-URLs: Joins with a slash (e.g., "oidc/sub")
|
||||
// - For empty Iss: Returns just "sub"
|
||||
// - For empty Sub: Returns just the Issuer
|
||||
// - For both empty: Returns empty string
|
||||
//
|
||||
// The result is cleaned using CleanIdentifier() to ensure consistent formatting.
|
||||
func (c *OIDCClaims) Identifier() string {
|
||||
if strings.HasPrefix(c.Iss, "http") {
|
||||
if i, err := url.JoinPath(c.Iss, c.Sub); err == nil {
|
||||
return i
|
||||
// Handle empty components special cases
|
||||
if c.Iss == "" && c.Sub == "" {
|
||||
return ""
|
||||
}
|
||||
if c.Iss == "" {
|
||||
return CleanIdentifier(c.Sub)
|
||||
}
|
||||
if c.Sub == "" {
|
||||
return CleanIdentifier(c.Iss)
|
||||
}
|
||||
|
||||
// We'll use the raw values and let CleanIdentifier handle all the whitespace
|
||||
issuer := c.Iss
|
||||
subject := c.Sub
|
||||
|
||||
var result string
|
||||
// Try to parse as URL to handle URL joining correctly
|
||||
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
|
||||
// For URLs, use proper URL path joining
|
||||
if joined, err := url.JoinPath(issuer, subject); err == nil {
|
||||
result = joined
|
||||
}
|
||||
}
|
||||
return c.Iss + "/" + c.Sub
|
||||
|
||||
// If URL joining failed or issuer wasn't a URL, do simple string join
|
||||
if result == "" {
|
||||
// Default case: simple string joining with slash
|
||||
issuer = strings.TrimSuffix(issuer, "/")
|
||||
subject = strings.TrimPrefix(subject, "/")
|
||||
result = issuer + "/" + subject
|
||||
}
|
||||
|
||||
// Clean the result and return it
|
||||
return CleanIdentifier(result)
|
||||
}
|
||||
|
||||
// CleanIdentifier cleans a potentially malformed identifier by removing double slashes
|
||||
// while preserving protocol specifications like http://. This function will:
|
||||
// - Trim all whitespace from the beginning and end of the identifier
|
||||
// - Remove whitespace within path segments
|
||||
// - Preserve the scheme (http://, https://, etc.) for URLs
|
||||
// - Remove any duplicate slashes in the path
|
||||
// - Remove empty path segments
|
||||
// - For non-URL identifiers, it joins non-empty segments with a single slash
|
||||
// - Returns empty string for identifiers with only slashes
|
||||
// - Normalize URL schemes to lowercase
|
||||
func CleanIdentifier(identifier string) string {
|
||||
if identifier == "" {
|
||||
return identifier
|
||||
}
|
||||
|
||||
// Trim leading/trailing whitespace
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
|
||||
// Handle URLs with schemes
|
||||
u, err := url.Parse(identifier)
|
||||
if err == nil && u.Scheme != "" {
|
||||
// Clean path by removing empty segments and whitespace within segments
|
||||
parts := strings.FieldsFunc(u.Path, func(c rune) bool { return c == '/' })
|
||||
for i, part := range parts {
|
||||
parts[i] = strings.TrimSpace(part)
|
||||
}
|
||||
// Remove empty parts after trimming
|
||||
cleanParts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if part != "" {
|
||||
cleanParts = append(cleanParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cleanParts) == 0 {
|
||||
u.Path = ""
|
||||
} else {
|
||||
u.Path = "/" + strings.Join(cleanParts, "/")
|
||||
}
|
||||
// Ensure scheme is lowercase
|
||||
u.Scheme = strings.ToLower(u.Scheme)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// Handle non-URL identifiers
|
||||
parts := strings.FieldsFunc(identifier, func(c rune) bool { return c == '/' })
|
||||
// Clean whitespace from each part
|
||||
cleanParts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
cleanParts = append(cleanParts, trimmed)
|
||||
}
|
||||
}
|
||||
if len(cleanParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(cleanParts, "/")
|
||||
}
|
||||
|
||||
type OIDCUserInfo struct {
|
||||
@ -231,7 +328,13 @@ func (u *User) FromClaim(claims *OIDCClaims) {
|
||||
}
|
||||
}
|
||||
|
||||
u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
|
||||
// Get provider identifier
|
||||
identifier := claims.Identifier()
|
||||
// Ensure provider identifier always has a leading slash for backward compatibility
|
||||
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
|
||||
identifier = "/" + identifier
|
||||
}
|
||||
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
|
||||
u.DisplayName = claims.Name
|
||||
u.ProfilePicURL = claims.ProfilePictureURL
|
||||
u.Provider = util.RegisterMethodOIDC
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnmarshallOIDCClaims(t *testing.T) {
|
||||
@ -76,6 +77,218 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCClaimsIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
iss string
|
||||
sub string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "standard URL with trailing slash",
|
||||
iss: "https://oidc.example.com/",
|
||||
sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
||||
expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
||||
},
|
||||
{
|
||||
name: "standard URL without trailing slash",
|
||||
iss: "https://oidc.example.com",
|
||||
sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
||||
expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
||||
},
|
||||
{
|
||||
name: "standard URL with uppercase protocol",
|
||||
iss: "HTTPS://oidc.example.com/",
|
||||
sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
||||
expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
||||
},
|
||||
{
|
||||
name: "standard URL with path and trailing slash",
|
||||
iss: "https://login.microsoftonline.com/v2.0/",
|
||||
sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
},
|
||||
{
|
||||
name: "standard URL with path without trailing slash",
|
||||
iss: "https://login.microsoftonline.com/v2.0",
|
||||
sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
},
|
||||
{
|
||||
name: "non-URL identifier with slash",
|
||||
iss: "oidc",
|
||||
sub: "sub",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "non-URL identifier with trailing slash",
|
||||
iss: "oidc/",
|
||||
sub: "sub",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "subject with slash",
|
||||
iss: "oidc/",
|
||||
sub: "sub/",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "whitespace",
|
||||
iss: " oidc/ ",
|
||||
sub: " sub ",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "newline",
|
||||
iss: "\noidc/\n",
|
||||
sub: "\nsub\n",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "tab",
|
||||
iss: "\toidc/\t",
|
||||
sub: "\tsub\t",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "empty issuer",
|
||||
iss: "",
|
||||
sub: "sub",
|
||||
expected: "sub",
|
||||
},
|
||||
{
|
||||
name: "empty subject",
|
||||
iss: "https://oidc.example.com",
|
||||
sub: "",
|
||||
expected: "https://oidc.example.com",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
iss: "",
|
||||
sub: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL with double slash",
|
||||
iss: "https://login.microsoftonline.com//v2.0",
|
||||
sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
},
|
||||
{
|
||||
name: "FTP URL protocol",
|
||||
iss: "ftp://example.com/directory",
|
||||
sub: "resource",
|
||||
expected: "ftp://example.com/directory/resource",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims := OIDCClaims{
|
||||
Iss: tt.iss,
|
||||
Sub: tt.sub,
|
||||
}
|
||||
result := claims.Identifier()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// Now clean the identifier and verify it's still the same
|
||||
cleaned := CleanIdentifier(result)
|
||||
|
||||
// Double-check with cmp.Diff for better error messages
|
||||
if diff := cmp.Diff(tt.expected, cleaned); diff != "" {
|
||||
t.Errorf("CleanIdentifier(Identifier()) mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
identifier string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty identifier",
|
||||
identifier: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "simple identifier",
|
||||
identifier: "oidc/sub",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "double slashes in the middle",
|
||||
identifier: "oidc//sub",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "trailing slash",
|
||||
identifier: "oidc/sub/",
|
||||
expected: "oidc/sub",
|
||||
},
|
||||
{
|
||||
name: "multiple double slashes",
|
||||
identifier: "oidc//sub///id//",
|
||||
expected: "oidc/sub/id",
|
||||
},
|
||||
{
|
||||
name: "HTTP URL with proper scheme",
|
||||
identifier: "http://example.com/path",
|
||||
expected: "http://example.com/path",
|
||||
},
|
||||
{
|
||||
name: "HTTP URL with double slashes in path",
|
||||
identifier: "http://example.com//path///resource",
|
||||
expected: "http://example.com/path/resource",
|
||||
},
|
||||
{
|
||||
name: "HTTPS URL with empty segments",
|
||||
identifier: "https://example.com///path//",
|
||||
expected: "https://example.com/path",
|
||||
},
|
||||
{
|
||||
name: "URL with double slashes in domain",
|
||||
identifier: "https://login.microsoftonline.com//v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
|
||||
},
|
||||
{
|
||||
name: "FTP URL with double slashes",
|
||||
identifier: "ftp://example.com//resource//",
|
||||
expected: "ftp://example.com/resource",
|
||||
},
|
||||
{
|
||||
name: "Just slashes",
|
||||
identifier: "///",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Leading slash without URL",
|
||||
identifier: "/path//to///resource",
|
||||
expected: "path/to/resource",
|
||||
},
|
||||
{
|
||||
name: "Non-standard protocol",
|
||||
identifier: "ldap://example.org//path//to//resource",
|
||||
expected: "ldap://example.org/path/to/resource",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CleanIdentifier(tt.identifier)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCClaimsJSONToUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
Loading…
x
Reference in New Issue
Block a user