Redo OIDC configuration (#2020)
expand user, add claims to user This commit expands the user table with additional fields that can be retrieved from OIDC providers (and other places) and uses this data in various tailscale response objects if it is available. This is the beginning of implementing https://docs.google.com/document/d/1X85PMxIaVWDF6T_UPji3OeeUqVBcGj_uHRM5CI-AwlY/edit trying to make OIDC more coherant and maintainable in addition to giving the user a better experience and integration with a provider. remove usernames in magic dns, normalisation of emails this commit removes the option to have usernames as part of MagicDNS domains and headscale will now align with Tailscale, where there is a root domain, and the machine name. In addition, the various normalisation functions for dns names has been made lighter not caring about username and special character that wont occur. Email are no longer normalised as part of the policy processing. untagle oidc and regcache, use typed cache This commits stops reusing the registration cache for oidc purposes and switches the cache to be types and not use any allowing the removal of a bunch of casting. try to make reauth/register branches clearer in oidc Currently there was a function that did a bunch of stuff, finding the machine key, trying to find the node, reauthing the node, returning some status, and it was called validate which was very confusing. This commit tries to split this into what to do if the node exists, if it needs to register etc. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
bc9e83b52e
commit
218138afee
14
CHANGELOG.md
14
CHANGELOG.md
|
@ -2,11 +2,25 @@
|
||||||
|
|
||||||
## Next
|
## Next
|
||||||
|
|
||||||
|
### BREAKING
|
||||||
|
|
||||||
|
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
|
||||||
|
- Having usernames in magic DNS is no longer possible.
|
||||||
|
- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020)
|
||||||
|
- `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC.
|
||||||
|
- Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated.
|
||||||
|
- User has been extended to store username, display name, profile picture url and email.
|
||||||
|
- These fields are forwarded to the client, and shows up nicely in the user switcher.
|
||||||
|
- These fields can be made available via the API/CLI for non-OIDC users in the future.
|
||||||
- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149)
|
- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149)
|
||||||
- Clean up old code required by old versions
|
- Clean up old code required by old versions
|
||||||
|
|
||||||
|
### Changes
|
||||||
|
|
||||||
- Improved compatibilty of built-in DERP server with clients connecting over WebSocket.
|
- Improved compatibilty of built-in DERP server with clients connecting over WebSocket.
|
||||||
- Allow nodes to use SSH agent forwarding [#2145](https://github.com/juanfont/headscale/pull/2145)
|
- Allow nodes to use SSH agent forwarding [#2145](https://github.com/juanfont/headscale/pull/2145)
|
||||||
|
|
||||||
|
|
||||||
## 0.23.0 (2024-09-18)
|
## 0.23.0 (2024-09-18)
|
||||||
|
|
||||||
This release was intended to be mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.
|
This release was intended to be mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.
|
||||||
|
|
|
@ -32,7 +32,7 @@
|
||||||
|
|
||||||
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
||||||
# update this if you have a mismatch after doing a change to thos files.
|
# update this if you have a mismatch after doing a change to thos files.
|
||||||
vendorHash = "sha256-/CPUkLLCwNKK3z3UZyF+AY0ArMnLaDmH0HV3/RYHo4c=";
|
vendorHash = "sha256-SDJSFji6498WI9bJLmY62VGt21TtD2GxrxRAWyYyr0c=";
|
||||||
|
|
||||||
subPackages = ["cmd/headscale"];
|
subPackages = ["cmd/headscale"];
|
||||||
|
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -7,7 +7,6 @@ require (
|
||||||
github.com/coder/websocket v1.8.12
|
github.com/coder/websocket v1.8.12
|
||||||
github.com/coreos/go-oidc/v3 v3.11.0
|
github.com/coreos/go-oidc/v3 v3.11.0
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
|
||||||
github.com/deckarep/golang-set/v2 v2.6.0
|
|
||||||
github.com/glebarez/sqlite v1.11.0
|
github.com/glebarez/sqlite v1.11.0
|
||||||
github.com/go-gormigrate/gormigrate/v2 v2.1.2
|
github.com/go-gormigrate/gormigrate/v2 v2.1.2
|
||||||
github.com/gofrs/uuid/v5 v5.3.0
|
github.com/gofrs/uuid/v5 v5.3.0
|
||||||
|
@ -19,7 +18,6 @@ require (
|
||||||
github.com/klauspost/compress v1.17.9
|
github.com/klauspost/compress v1.17.9
|
||||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||||
github.com/ory/dockertest/v3 v3.11.0
|
github.com/ory/dockertest/v3 v3.11.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
|
||||||
github.com/philip-bui/grpc-zerolog v1.0.1
|
github.com/philip-bui/grpc-zerolog v1.0.1
|
||||||
github.com/pkg/profile v1.7.0
|
github.com/pkg/profile v1.7.0
|
||||||
github.com/prometheus/client_golang v1.20.2
|
github.com/prometheus/client_golang v1.20.2
|
||||||
|
@ -49,6 +47,7 @@ require (
|
||||||
gorm.io/driver/postgres v1.5.9
|
gorm.io/driver/postgres v1.5.9
|
||||||
gorm.io/gorm v1.25.11
|
gorm.io/gorm v1.25.11
|
||||||
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7
|
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7
|
||||||
|
zgo.at/zcache/v2 v2.1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
12
go.sum
12
go.sum
|
@ -128,8 +128,6 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0 h1:vrC07UZcgPzu/OjWsmQKMGg3LoPSz9jh/pQXIrHjUj4=
|
github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0 h1:vrC07UZcgPzu/OjWsmQKMGg3LoPSz9jh/pQXIrHjUj4=
|
||||||
github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ=
|
github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ=
|
||||||
github.com/deckarep/golang-set/v2 v2.6.0 h1:XfcQbWM1LlMB8BsJ8N9vW5ehnnPVIw0je80NsVHagjM=
|
|
||||||
github.com/deckarep/golang-set/v2 v2.6.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4=
|
|
||||||
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q=
|
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q=
|
||||||
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A=
|
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A=
|
||||||
github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c=
|
github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c=
|
||||||
|
@ -364,8 +362,6 @@ github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFSt
|
||||||
github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0=
|
github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0=
|
||||||
github.com/ory/dockertest/v3 v3.11.0 h1:OiHcxKAvSDUwsEVh2BjxQQc/5EHz9n0va9awCtNGuyA=
|
github.com/ory/dockertest/v3 v3.11.0 h1:OiHcxKAvSDUwsEVh2BjxQQc/5EHz9n0va9awCtNGuyA=
|
||||||
github.com/ory/dockertest/v3 v3.11.0/go.mod h1:VIPxS1gwT9NpPOrfD3rACs8Y9Z7yhzO4SB194iUDnUI=
|
github.com/ory/dockertest/v3 v3.11.0/go.mod h1:VIPxS1gwT9NpPOrfD3rACs8Y9Z7yhzO4SB194iUDnUI=
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
|
||||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||||
github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw=
|
github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw=
|
||||||
|
@ -729,11 +725,7 @@ modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
|
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
|
||||||
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||||
tailscale.com v1.75.0-pre.0.20240925091311-031f291c98fe h1:3+E/vlEsZa2FpWBz2Ly6/L4zh4utVO8z54Ms75HitrQ=
|
|
||||||
tailscale.com v1.75.0-pre.0.20240925091311-031f291c98fe/go.mod h1:G4R9objdXe2zAcLaLkDOcHfqN9XnspBifyBHGNwTzKg=
|
|
||||||
tailscale.com v1.75.0-pre.0.20240925102642-c17c476c0d59 h1:GSuB+bmPiVfBLRqVyLOFSU+9V00lXBz9HakAewevYZA=
|
|
||||||
tailscale.com v1.75.0-pre.0.20240925102642-c17c476c0d59/go.mod h1:G4R9objdXe2zAcLaLkDOcHfqN9XnspBifyBHGNwTzKg=
|
|
||||||
tailscale.com v1.75.0-pre.0.20240926030905-c90c9938c8a2 h1:ivZ1GEXMzCNI1VRp2TjUWmLuOtno7TqW26lZf7MlF4k=
|
|
||||||
tailscale.com v1.75.0-pre.0.20240926030905-c90c9938c8a2/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg=
|
|
||||||
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVstjlEi1BWktzv3FuxWpyyx0=
|
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVstjlEi1BWktzv3FuxWpyyx0=
|
||||||
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg=
|
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg=
|
||||||
|
zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs=
|
||||||
|
zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
|
@ -33,7 +32,6 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||||
"github.com/pkg/profile"
|
"github.com/pkg/profile"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
@ -41,7 +39,6 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/crypto/acme"
|
"golang.org/x/crypto/acme"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
@ -57,6 +54,7 @@ import (
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/util/dnsname"
|
"tailscale.com/util/dnsname"
|
||||||
|
zcache "zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -95,10 +93,9 @@ type Headscale struct {
|
||||||
mapper *mapper.Mapper
|
mapper *mapper.Mapper
|
||||||
nodeNotifier *notifier.Notifier
|
nodeNotifier *notifier.Notifier
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
registrationCache *zcache.Cache[string, types.Node]
|
||||||
oauth2Config *oauth2.Config
|
|
||||||
|
|
||||||
registrationCache *cache.Cache
|
authProvider AuthProvider
|
||||||
|
|
||||||
pollNetMapStreamWG sync.WaitGroup
|
pollNetMapStreamWG sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
@ -123,7 +120,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
|
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationCache := cache.New(
|
registrationCache := zcache.New[string, types.Node](
|
||||||
registerCacheExpiration,
|
registerCacheExpiration,
|
||||||
registerCacheCleanup,
|
registerCacheCleanup,
|
||||||
)
|
)
|
||||||
|
@ -138,7 +135,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
|
|
||||||
app.db, err = db.NewHeadscaleDatabase(
|
app.db, err = db.NewHeadscaleDatabase(
|
||||||
cfg.Database,
|
cfg.Database,
|
||||||
cfg.BaseDomain)
|
cfg.BaseDomain,
|
||||||
|
registrationCache,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -154,16 +153,30 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var authProvider AuthProvider
|
||||||
|
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
err = app.initOIDC()
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
oidcProvider, err := NewAuthProviderOIDC(
|
||||||
|
ctx,
|
||||||
|
cfg.ServerURL,
|
||||||
|
&cfg.OIDC,
|
||||||
|
app.db,
|
||||||
|
app.nodeNotifier,
|
||||||
|
app.ipAlloc,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
|
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
authProvider = oidcProvider
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
app.authProvider = authProvider
|
||||||
|
|
||||||
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||||
// TODO(kradalby): revisit why this takes a list.
|
// TODO(kradalby): revisit why this takes a list.
|
||||||
|
@ -429,10 +442,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||||
|
|
||||||
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
|
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
|
||||||
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
|
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
|
||||||
router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet)
|
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
|
||||||
|
|
||||||
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
|
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||||
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
|
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
|
||||||
|
}
|
||||||
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
||||||
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
||||||
Methods(http.MethodGet)
|
Methods(http.MethodGet)
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
@ -19,6 +18,11 @@ import (
|
||||||
"tailscale.com/types/ptr"
|
"tailscale.com/types/ptr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AuthProvider interface {
|
||||||
|
RegisterHandler(http.ResponseWriter, *http.Request)
|
||||||
|
AuthURL(key.MachinePublic) string
|
||||||
|
}
|
||||||
|
|
||||||
func logAuthFunc(
|
func logAuthFunc(
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
|
@ -125,7 +129,6 @@ func (h *Headscale) handleRegister(
|
||||||
h.registrationCache.Set(
|
h.registrationCache.Set(
|
||||||
machineKey.String(),
|
machineKey.String(),
|
||||||
newNode,
|
newNode,
|
||||||
registerCacheExpiration,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
h.handleNewNode(writer, regReq, machineKey)
|
h.handleNewNode(writer, regReq, machineKey)
|
||||||
|
@ -164,7 +167,7 @@ func (h *Headscale) handleRegister(
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||||
if !regReq.Expiry.IsZero() &&
|
if !regReq.Expiry.IsZero() &&
|
||||||
regReq.Expiry.UTC().Before(now) {
|
regReq.Expiry.UTC().Before(now) {
|
||||||
h.handleNodeLogOut(writer, *node, machineKey)
|
h.handleNodeLogOut(writer, *node)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -172,7 +175,7 @@ func (h *Headscale) handleRegister(
|
||||||
// If node is not expired, and it is register, we have a already accepted this node,
|
// If node is not expired, and it is register, we have a already accepted this node,
|
||||||
// let it proceed with a valid registration
|
// let it proceed with a valid registration
|
||||||
if !node.IsExpired() {
|
if !node.IsExpired() {
|
||||||
h.handleNodeWithValidRegistration(writer, *node, machineKey)
|
h.handleNodeWithValidRegistration(writer, *node)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -185,7 +188,6 @@ func (h *Headscale) handleRegister(
|
||||||
writer,
|
writer,
|
||||||
regReq,
|
regReq,
|
||||||
*node,
|
*node,
|
||||||
machineKey,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -198,7 +200,6 @@ func (h *Headscale) handleRegister(
|
||||||
writer,
|
writer,
|
||||||
regReq,
|
regReq,
|
||||||
*node,
|
*node,
|
||||||
machineKey,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -226,7 +227,6 @@ func (h *Headscale) handleRegister(
|
||||||
h.registrationCache.Set(
|
h.registrationCache.Set(
|
||||||
machineKey.String(),
|
machineKey.String(),
|
||||||
*node,
|
*node,
|
||||||
registerCacheExpiration,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -386,7 +386,7 @@ func (h *Headscale) handleAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.db.Write(func(tx *gorm.DB) error {
|
err = h.db.Write(func(tx *gorm.DB) error {
|
||||||
return db.UsePreAuthKey(tx, pak)
|
return db.UsePreAuthKey(tx, pak)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -447,17 +447,7 @@ func (h *Headscale) handleNewNode(
|
||||||
// The node registration is new, redirect the client to the registration URL
|
// The node registration is new, redirect the client to the registration URL
|
||||||
logTrace("The node seems to be new, sending auth url")
|
logTrace("The node seems to be new, sending auth url")
|
||||||
|
|
||||||
if h.oauth2Config != nil {
|
resp.AuthURL = h.authProvider.AuthURL(machineKey)
|
||||||
resp.AuthURL = fmt.Sprintf(
|
|
||||||
"%s/oidc/register/%s",
|
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
|
||||||
machineKey.String(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register/%s",
|
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
|
||||||
machineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
respBody, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -480,7 +470,6 @@ func (h *Headscale) handleNewNode(
|
||||||
func (h *Headscale) handleNodeLogOut(
|
func (h *Headscale) handleNodeLogOut(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
node types.Node,
|
node types.Node,
|
||||||
machineKey key.MachinePublic,
|
|
||||||
) {
|
) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
|
@ -563,7 +552,6 @@ func (h *Headscale) handleNodeLogOut(
|
||||||
func (h *Headscale) handleNodeWithValidRegistration(
|
func (h *Headscale) handleNodeWithValidRegistration(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
node types.Node,
|
node types.Node,
|
||||||
machineKey key.MachinePublic,
|
|
||||||
) {
|
) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
|
@ -609,7 +597,6 @@ func (h *Headscale) handleNodeKeyRefresh(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
registerRequest tailcfg.RegisterRequest,
|
||||||
node types.Node,
|
node types.Node,
|
||||||
machineKey key.MachinePublic,
|
|
||||||
) {
|
) {
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
|
|
||||||
|
@ -685,15 +672,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
|
||||||
Str("node_key_old", regReq.OldNodeKey.ShortString()).
|
Str("node_key_old", regReq.OldNodeKey.ShortString()).
|
||||||
Msg("Node registration has expired or logged out. Sending a auth url to register")
|
Msg("Node registration has expired or logged out. Sending a auth url to register")
|
||||||
|
|
||||||
if h.oauth2Config != nil {
|
resp.AuthURL = h.authProvider.AuthURL(machineKey)
|
||||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
|
||||||
machineKey.String())
|
|
||||||
} else {
|
|
||||||
resp.AuthURL = fmt.Sprintf("%s/register/%s",
|
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
|
||||||
machineKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
respBody, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"tailscale.com/util/set"
|
"tailscale.com/util/set"
|
||||||
|
"zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -38,8 +39,9 @@ type KV struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
cfg *types.DatabaseConfig
|
cfg *types.DatabaseConfig
|
||||||
|
regCache *zcache.Cache[string, types.Node]
|
||||||
|
|
||||||
baseDomain string
|
baseDomain string
|
||||||
}
|
}
|
||||||
|
@ -49,6 +51,7 @@ type HSDatabase struct {
|
||||||
func NewHeadscaleDatabase(
|
func NewHeadscaleDatabase(
|
||||||
cfg types.DatabaseConfig,
|
cfg types.DatabaseConfig,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
|
regCache *zcache.Cache[string, types.Node],
|
||||||
) (*HSDatabase, error) {
|
) (*HSDatabase, error) {
|
||||||
dbConn, err := openDB(cfg)
|
dbConn, err := openDB(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -264,9 +267,6 @@ func NewHeadscaleDatabase(
|
||||||
|
|
||||||
for item, node := range nodes {
|
for item, node := range nodes {
|
||||||
if node.GivenName == "" {
|
if node.GivenName == "" {
|
||||||
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
|
|
||||||
node.Hostname,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -276,7 +276,7 @@ func NewHeadscaleDatabase(
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Model(nodes[item]).Updates(types.Node{
|
err = tx.Model(nodes[item]).Updates(types.Node{
|
||||||
GivenName: normalizedHostname,
|
GivenName: node.Hostname,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -469,6 +469,17 @@ func NewHeadscaleDatabase(
|
||||||
|
|
||||||
// Drop the old table.
|
// Drop the old table.
|
||||||
_ = tx.Migrator().DropTable(&preAuthKeyACLTag{})
|
_ = tx.Migrator().DropTable(&preAuthKeyACLTag{})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "202407191627",
|
||||||
|
Migrate: func(tx *gorm.DB) error {
|
||||||
|
err := tx.AutoMigrate(&types.User{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -482,8 +493,9 @@ func NewHeadscaleDatabase(
|
||||||
}
|
}
|
||||||
|
|
||||||
db := HSDatabase{
|
db := HSDatabase{
|
||||||
DB: dbConn,
|
DB: dbConn,
|
||||||
cfg: &cfg,
|
cfg: &cfg,
|
||||||
|
regCache: regCache,
|
||||||
|
|
||||||
baseDomain: baseDomain,
|
baseDomain: baseDomain,
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMigrations(t *testing.T) {
|
func TestMigrations(t *testing.T) {
|
||||||
|
@ -206,7 +208,7 @@ func TestMigrations(t *testing.T) {
|
||||||
Sqlite: types.SqliteConfig{
|
Sqlite: types.SqliteConfig{
|
||||||
Path: dbPath,
|
Path: dbPath,
|
||||||
},
|
},
|
||||||
}, "")
|
}, "", emptyCache())
|
||||||
if err != nil && tt.wantErr != err.Error() {
|
if err != nil && tt.wantErr != err.Error() {
|
||||||
t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
|
@ -250,3 +252,7 @@ func testCopyOfDatabase(src string) (string, error) {
|
||||||
_, err = io.Copy(destination, source)
|
_, err = io.Copy(destination, source)
|
||||||
return dst, err
|
return dst, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func emptyCache() *zcache.Cache[string, types.Node] {
|
||||||
|
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -320,26 +319,17 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
||||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterNodeFromAuthCallback(
|
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
|
||||||
tx *gorm.DB,
|
|
||||||
cache *cache.Cache,
|
|
||||||
mkey key.MachinePublic,
|
mkey key.MachinePublic,
|
||||||
userName string,
|
userID types.UserID,
|
||||||
nodeExpiry *time.Time,
|
nodeExpiry *time.Time,
|
||||||
registrationMethod string,
|
registrationMethod string,
|
||||||
ipv4 *netip.Addr,
|
ipv4 *netip.Addr,
|
||||||
ipv6 *netip.Addr,
|
ipv6 *netip.Addr,
|
||||||
) (*types.Node, error) {
|
) (*types.Node, error) {
|
||||||
log.Debug().
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
Str("machine_key", mkey.ShortString()).
|
if node, ok := hsdb.regCache.Get(mkey.String()); ok {
|
||||||
Str("userName", userName).
|
user, err := GetUserByID(tx, userID)
|
||||||
Str("registrationMethod", registrationMethod).
|
|
||||||
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
|
||||||
Msg("Registering node from API/CLI or auth callback")
|
|
||||||
|
|
||||||
if nodeInterface, ok := cache.Get(mkey.String()); ok {
|
|
||||||
if registrationNode, ok := nodeInterface.(types.Node); ok {
|
|
||||||
user, err := GetUser(tx, userName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"failed to find user in register node from auth callback, %w",
|
"failed to find user in register node from auth callback, %w",
|
||||||
|
@ -347,37 +337,42 @@ func RegisterNodeFromAuthCallback(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("machine_key", mkey.ShortString()).
|
||||||
|
Str("username", user.Username()).
|
||||||
|
Str("registrationMethod", registrationMethod).
|
||||||
|
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
|
||||||
|
Msg("Registering node from API/CLI or auth callback")
|
||||||
|
|
||||||
// Registration of expired node with different user
|
// Registration of expired node with different user
|
||||||
if registrationNode.ID != 0 &&
|
if node.ID != 0 &&
|
||||||
registrationNode.UserID != user.ID {
|
node.UserID != user.ID {
|
||||||
return nil, ErrDifferentRegisteredUser
|
return nil, ErrDifferentRegisteredUser
|
||||||
}
|
}
|
||||||
|
|
||||||
registrationNode.UserID = user.ID
|
node.UserID = user.ID
|
||||||
registrationNode.User = *user
|
node.User = *user
|
||||||
registrationNode.RegisterMethod = registrationMethod
|
node.RegisterMethod = registrationMethod
|
||||||
|
|
||||||
if nodeExpiry != nil {
|
if nodeExpiry != nil {
|
||||||
registrationNode.Expiry = nodeExpiry
|
node.Expiry = nodeExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
node, err := RegisterNode(
|
node, err := RegisterNode(
|
||||||
tx,
|
tx,
|
||||||
registrationNode,
|
node,
|
||||||
ipv4, ipv6,
|
ipv4, ipv6,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cache.Delete(mkey.String())
|
hsdb.regCache.Delete(mkey.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return node, err
|
return node, err
|
||||||
} else {
|
|
||||||
return nil, ErrCouldNotConvertNodeInterface
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrNodeNotFoundRegistrationCache
|
return nil, ErrNodeNotFoundRegistrationCache
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||||
|
@ -392,7 +387,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str("machine_key", node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString()).
|
Str("node_key", node.NodeKey.ShortString()).
|
||||||
Str("user", node.User.Name).
|
Str("user", node.User.Username()).
|
||||||
Msg("Registering node")
|
Msg("Registering node")
|
||||||
|
|
||||||
// If the node exists and it already has IP(s), we just save it
|
// If the node exists and it already has IP(s), we just save it
|
||||||
|
@ -408,7 +403,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str("machine_key", node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString()).
|
Str("node_key", node.NodeKey.ShortString()).
|
||||||
Str("user", node.User.Name).
|
Str("user", node.User.Username()).
|
||||||
Msg("Node authorized again")
|
Msg("Node authorized again")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
|
@ -612,18 +607,15 @@ func enableRoutes(tx *gorm.DB,
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||||
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
|
if len(suppliedName) > util.LabelHostnameLength {
|
||||||
suppliedName,
|
return "", types.ErrHostnameTooLong
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if randomSuffix {
|
if randomSuffix {
|
||||||
// Trim if a hostname will be longer than 63 chars after adding the hash.
|
// Trim if a hostname will be longer than 63 chars after adding the hash.
|
||||||
trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize
|
trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize
|
||||||
if len(normalizedHostname) > trimmedHostnameLength {
|
if len(suppliedName) > trimmedHostnameLength {
|
||||||
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
|
suppliedName = suppliedName[:trimmedHostnameLength]
|
||||||
}
|
}
|
||||||
|
|
||||||
suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength)
|
suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength)
|
||||||
|
@ -631,10 +623,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizedHostname += "-" + suffix
|
suppliedName += "-" + suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
return normalizedHostname, nil
|
return suppliedName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isUnqiueName(tx *gorm.DB, name string) (bool, error) {
|
func isUnqiueName(tx *gorm.DB, name string) (bool, error) {
|
||||||
|
|
|
@ -23,6 +23,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
|
// TODO(kradalby): Should be ID, not name
|
||||||
userName string,
|
userName string,
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
|
@ -37,13 +38,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||||
func CreatePreAuthKey(
|
func CreatePreAuthKey(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
|
// TODO(kradalby): Should be ID, not name
|
||||||
userName string,
|
userName string,
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
user, err := GetUser(tx, userName)
|
user, err := GetUserByUsername(tx, userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -95,7 +97,7 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er
|
||||||
|
|
||||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||||
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
||||||
user, err := GetUser(tx, userName)
|
user, err := GetUserByUsername(tx, userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -645,7 +645,7 @@ func EnableAutoApprovedRoutes(
|
||||||
Msg("looking up route for autoapproving")
|
Msg("looking up route for autoapproving")
|
||||||
|
|
||||||
for _, approvedAlias := range routeApprovers {
|
for _, approvedAlias := range routeApprovers {
|
||||||
if approvedAlias == node.User.Name {
|
if approvedAlias == node.User.Username() {
|
||||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
} else {
|
} else {
|
||||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||||
|
|
|
@ -336,6 +336,7 @@ func dbForTest(t *testing.T, testName string) *HSDatabase {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"",
|
"",
|
||||||
|
emptyCache(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("setting up database: %s", err)
|
t.Fatalf("setting up database: %s", err)
|
||||||
|
|
|
@ -59,6 +59,7 @@ func newTestDB() (*HSDatabase, error) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"",
|
"",
|
||||||
|
emptyCache(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -49,7 +49,7 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||||
// DestroyUser destroys a User. Returns error if the User does
|
// DestroyUser destroys a User. Returns error if the User does
|
||||||
// not exist or if there are nodes associated with it.
|
// not exist or if there are nodes associated with it.
|
||||||
func DestroyUser(tx *gorm.DB, name string) error {
|
func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
user, err := GetUser(tx, name)
|
user, err := GetUserByUsername(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrUserNotFound
|
return ErrUserNotFound
|
||||||
}
|
}
|
||||||
|
@ -90,7 +90,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
var err error
|
var err error
|
||||||
oldUser, err := GetUser(tx, oldName)
|
oldUser, err := GetUserByUsername(tx, oldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -98,7 +98,7 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = GetUser(tx, newName)
|
_, err = GetUserByUsername(tx, newName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return ErrUserExists
|
return ErrUserExists
|
||||||
}
|
}
|
||||||
|
@ -115,13 +115,13 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
return GetUser(rx, name)
|
return GetUserByUsername(rx, name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
user := types.User{}
|
user := types.User{}
|
||||||
if result := tx.First(&user, "name = ?", name); errors.Is(
|
if result := tx.First(&user, "name = ?", name); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
|
@ -133,6 +133,42 @@ func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) {
|
||||||
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
|
return GetUserByID(rx, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) {
|
||||||
|
user := types.User{}
|
||||||
|
if result := tx.First(&user, "id = ?", id); errors.Is(
|
||||||
|
result.Error,
|
||||||
|
gorm.ErrRecordNotFound,
|
||||||
|
) {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) GetUserByOIDCIdentifier(id string) (*types.User, error) {
|
||||||
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
|
return GetUserByOIDCIdentifier(rx, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||||
|
user := types.User{}
|
||||||
|
if result := tx.First(&user, "provider_identifier = ?", id); errors.Is(
|
||||||
|
result.Error,
|
||||||
|
gorm.ErrRecordNotFound,
|
||||||
|
) {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||||
return ListUsers(rx)
|
return ListUsers(rx)
|
||||||
|
@ -155,7 +191,7 @@ func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user, err := GetUser(tx, name)
|
user, err := GetUserByUsername(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -180,7 +216,7 @@ func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
user, err := GetUser(tx, username)
|
user, err := GetUserByUsername(tx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUser("test")
|
_, err = db.GetUserByName("test")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,10 +73,10 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
err = db.RenameUser("test", "test-renamed")
|
err = db.RenameUser("test", "test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUser("test")
|
_, err = db.GetUserByName("test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
_, err = db.GetUser("test-renamed")
|
_, err = db.GetUserByName("test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.RenameUser("test-does-not-exit", "test")
|
err = db.RenameUser("test-does-not-exit", "test")
|
||||||
|
|
|
@ -41,7 +41,7 @@ func (api headscaleV1APIServer) GetUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.GetUserRequest,
|
request *v1.GetUserRequest,
|
||||||
) (*v1.GetUserResponse, error) {
|
) (*v1.GetUserResponse, error) {
|
||||||
user, err := api.h.db.GetUser(request.GetName())
|
user, err := api.h.db.GetUserByName(request.GetName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ func (api headscaleV1APIServer) RenameUser(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := api.h.db.GetUser(request.GetNewName())
|
user, err := api.h.db.GetUserByName(request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -205,17 +205,18 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
return db.RegisterNodeFromAuthCallback(
|
if err != nil {
|
||||||
tx,
|
return nil, fmt.Errorf("looking up user: %w", err)
|
||||||
api.h.registrationCache,
|
}
|
||||||
mkey,
|
|
||||||
request.GetUser(),
|
node, err := api.h.db.RegisterNodeFromAuthCallback(
|
||||||
nil,
|
mkey,
|
||||||
util.RegisterMethodCLI,
|
types.UserID(user.ID),
|
||||||
ipv4, ipv6,
|
nil,
|
||||||
)
|
util.RegisterMethodCLI,
|
||||||
})
|
ipv4, ipv6,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -774,7 +775,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DebugCreateNodeRequest,
|
request *v1.DebugCreateNodeRequest,
|
||||||
) (*v1.DebugCreateNodeResponse, error) {
|
) (*v1.DebugCreateNodeResponse, error) {
|
||||||
user, err := api.h.db.GetUser(request.GetUser())
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -823,7 +824,6 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||||
api.h.registrationCache.Set(
|
api.h.registrationCache.Set(
|
||||||
mkey.String(),
|
mkey.String(),
|
||||||
newNode,
|
newNode,
|
||||||
registerCacheExpiration,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
|
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
@ -167,12 +168,29 @@ var registerWebAPITemplate = template.Must(
|
||||||
</html>
|
</html>
|
||||||
`))
|
`))
|
||||||
|
|
||||||
|
type AuthProviderWeb struct {
|
||||||
|
serverURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
|
||||||
|
return &AuthProviderWeb{
|
||||||
|
serverURL: serverURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"%s/register/%s",
|
||||||
|
strings.TrimSuffix(a.serverURL, "/"),
|
||||||
|
mKey.String())
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||||
// Listens in /register/:nkey.
|
// Listens in /register/:nkey.
|
||||||
//
|
//
|
||||||
// This is not part of the Tailscale control API, as we could send whatever URL
|
// This is not part of the Tailscale control API, as we could send whatever URL
|
||||||
// in the RegisterResponse.AuthURL field.
|
// in the RegisterResponse.AuthURL field.
|
||||||
func (h *Headscale) RegisterWebAPI(
|
func (a *AuthProviderWeb) RegisterHandler(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
|
@ -187,7 +205,7 @@ func (h *Headscale) RegisterWebAPI(
|
||||||
[]byte(machineKeyStr),
|
[]byte(machineKeyStr),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
|
log.Warn().Err(err).Msg("Failed to parse incoming machinekey")
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
mapset "github.com/deckarep/golang-set/v2"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
@ -95,10 +94,10 @@ func generateUserProfiles(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
) []tailcfg.UserProfile {
|
) []tailcfg.UserProfile {
|
||||||
userMap := make(map[string]types.User)
|
userMap := make(map[uint]types.User)
|
||||||
userMap[node.User.Name] = node.User
|
userMap[node.User.ID] = node.User
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
userMap[peer.User.ID] = peer.User // not worth checking if already is there
|
||||||
}
|
}
|
||||||
|
|
||||||
var profiles []tailcfg.UserProfile
|
var profiles []tailcfg.UserProfile
|
||||||
|
@ -122,32 +121,6 @@ func generateDNSConfig(
|
||||||
|
|
||||||
dnsConfig := cfg.DNSConfig.Clone()
|
dnsConfig := cfg.DNSConfig.Clone()
|
||||||
|
|
||||||
// if MagicDNS is enabled
|
|
||||||
if dnsConfig.Proxied {
|
|
||||||
if cfg.DNSUserNameInMagicDNS {
|
|
||||||
// Only inject the Search Domain of the current user
|
|
||||||
// shared nodes should use their full FQDN
|
|
||||||
dnsConfig.Domains = append(
|
|
||||||
dnsConfig.Domains,
|
|
||||||
fmt.Sprintf(
|
|
||||||
"%s.%s",
|
|
||||||
node.User.Name,
|
|
||||||
baseDomain,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
userSet := mapset.NewSet[types.User]()
|
|
||||||
userSet.Add(node.User)
|
|
||||||
for _, p := range peers {
|
|
||||||
userSet.Add(p.User)
|
|
||||||
}
|
|
||||||
for _, user := range userSet.ToSlice() {
|
|
||||||
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
|
||||||
dnsConfig.Routes[dnsRoute] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
addNextDNSMetadata(dnsConfig.Resolvers, node)
|
addNextDNSMetadata(dnsConfig.Resolvers, node)
|
||||||
|
|
||||||
return dnsConfig
|
return dnsConfig
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
|
"gorm.io/gorm"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
|
@ -29,6 +30,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
UserID: userid,
|
UserID: userid,
|
||||||
User: types.User{
|
User: types.User{
|
||||||
|
Model: gorm.Model{
|
||||||
|
ID: userid,
|
||||||
|
},
|
||||||
Name: username,
|
Name: username,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -73,14 +77,9 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||||
{
|
{
|
||||||
magicDNS: true,
|
magicDNS: true,
|
||||||
want: &tailcfg.DNSConfig{
|
want: &tailcfg.DNSConfig{
|
||||||
Routes: map[string][]*dnstype.Resolver{
|
Routes: map[string][]*dnstype.Resolver{},
|
||||||
"shared1.foobar.headscale.net": {},
|
|
||||||
"shared2.foobar.headscale.net": {},
|
|
||||||
"shared3.foobar.headscale.net": {},
|
|
||||||
},
|
|
||||||
Domains: []string{
|
Domains: []string{
|
||||||
"foobar.headscale.net",
|
"foobar.headscale.net",
|
||||||
"shared1.foobar.headscale.net",
|
|
||||||
},
|
},
|
||||||
Proxied: true,
|
Proxied: true,
|
||||||
},
|
},
|
||||||
|
@ -128,8 +127,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||||
|
|
||||||
got := generateDNSConfig(
|
got := generateDNSConfig(
|
||||||
&types.Config{
|
&types.Config{
|
||||||
DNSConfig: &dnsConfigOrig,
|
DNSConfig: &dnsConfigOrig,
|
||||||
DNSUserNameInMagicDNS: true,
|
|
||||||
},
|
},
|
||||||
baseDomain,
|
baseDomain,
|
||||||
nodeInShared1,
|
nodeInShared1,
|
||||||
|
|
|
@ -76,7 +76,7 @@ func tailNode(
|
||||||
keyExpiry = time.Time{}
|
keyExpiry = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostname, err := node.GetFQDN(cfg, cfg.BaseDomain)
|
hostname, err := node.GetFQDN(cfg.BaseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,12 +17,13 @@ import (
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gorm.io/gorm"
|
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
"zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -45,49 +46,81 @@ var (
|
||||||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type AuthProviderOIDC struct {
|
||||||
Name string `json:"name,omitempty"`
|
serverURL string
|
||||||
Groups []string `json:"groups,omitempty"`
|
cfg *types.OIDCConfig
|
||||||
Email string `json:"email"`
|
db *db.HSDatabase
|
||||||
Username string `json:"preferred_username,omitempty"`
|
registrationCache *zcache.Cache[string, key.MachinePublic]
|
||||||
|
notifier *notifier.Notifier
|
||||||
|
ipAlloc *db.IPAllocator
|
||||||
|
|
||||||
|
oidcProvider *oidc.Provider
|
||||||
|
oauth2Config *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) initOIDC() error {
|
func NewAuthProviderOIDC(
|
||||||
|
ctx context.Context,
|
||||||
|
serverURL string,
|
||||||
|
cfg *types.OIDCConfig,
|
||||||
|
db *db.HSDatabase,
|
||||||
|
notif *notifier.Notifier,
|
||||||
|
ipAlloc *db.IPAllocator,
|
||||||
|
) (*AuthProviderOIDC, error) {
|
||||||
var err error
|
var err error
|
||||||
// grab oidc config if it hasn't been already
|
// grab oidc config if it hasn't been already
|
||||||
if h.oauth2Config == nil {
|
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||||
return fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.oauth2Config = &oauth2.Config{
|
|
||||||
ClientID: h.cfg.OIDC.ClientID,
|
|
||||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
|
||||||
Endpoint: h.oidcProvider.Endpoint(),
|
|
||||||
RedirectURL: fmt.Sprintf(
|
|
||||||
"%s/oidc/callback",
|
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
|
||||||
),
|
|
||||||
Scopes: h.cfg.OIDC.Scope,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
oauth2Config := &oauth2.Config{
|
||||||
|
ClientID: cfg.ClientID,
|
||||||
|
ClientSecret: cfg.ClientSecret,
|
||||||
|
Endpoint: oidcProvider.Endpoint(),
|
||||||
|
RedirectURL: fmt.Sprintf(
|
||||||
|
"%s/oidc/callback",
|
||||||
|
strings.TrimSuffix(serverURL, "/"),
|
||||||
|
),
|
||||||
|
Scopes: cfg.Scope,
|
||||||
|
}
|
||||||
|
|
||||||
|
registrationCache := zcache.New[string, key.MachinePublic](
|
||||||
|
registerCacheExpiration,
|
||||||
|
registerCacheCleanup,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &AuthProviderOIDC{
|
||||||
|
serverURL: serverURL,
|
||||||
|
cfg: cfg,
|
||||||
|
db: db,
|
||||||
|
registrationCache: registrationCache,
|
||||||
|
notifier: notif,
|
||||||
|
ipAlloc: ipAlloc,
|
||||||
|
|
||||||
|
oidcProvider: oidcProvider,
|
||||||
|
oauth2Config: oauth2Config,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
|
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
|
||||||
if h.cfg.OIDC.UseExpiryFromToken {
|
return fmt.Sprintf(
|
||||||
|
"%s/register/%s",
|
||||||
|
strings.TrimSuffix(a.serverURL, "/"),
|
||||||
|
mKey.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
|
||||||
|
if a.cfg.UseExpiryFromToken {
|
||||||
return idTokenExpiration
|
return idTokenExpiration
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Now().Add(h.cfg.OIDC.Expiry)
|
return time.Now().Add(a.cfg.Expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||||
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
|
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
|
||||||
// Listens in /oidc/register/:mKey.
|
// Listens in /register/:mKey.
|
||||||
func (h *Headscale) RegisterOIDC(
|
func (a *AuthProviderOIDC) RegisterHandler(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
|
@ -108,46 +141,32 @@ func (h *Headscale) RegisterOIDC(
|
||||||
[]byte(machineKeyStr),
|
[]byte(machineKeyStr),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().
|
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||||
Err(err).
|
|
||||||
Msg("Failed to parse incoming nodekey in OIDC registration")
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("Wrong params"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
randomBlob := make([]byte, randomByteSize)
|
randomBlob := make([]byte, randomByteSize)
|
||||||
if _, err := rand.Read(randomBlob); err != nil {
|
if _, err := rand.Read(randomBlob); err != nil {
|
||||||
util.LogErr(err, "could not read 16 bytes from rand")
|
|
||||||
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||||
|
|
||||||
// place the node key into the state cache, so it can be retrieved later
|
// place the node key into the state cache, so it can be retrieved later
|
||||||
h.registrationCache.Set(
|
a.registrationCache.Set(
|
||||||
stateStr,
|
stateStr,
|
||||||
machineKey,
|
machineKey,
|
||||||
registerCacheExpiration,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
|
||||||
|
|
||||||
for k, v := range h.cfg.OIDC.ExtraParams {
|
for k, v := range a.cfg.ExtraParams {
|
||||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
|
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||||
|
|
||||||
http.Redirect(writer, req, authURL, http.StatusFound)
|
http.Redirect(writer, req, authURL, http.StatusFound)
|
||||||
|
@ -165,216 +184,165 @@ var oidcCallbackTemplate = template.Must(
|
||||||
template.New("oidccallback").Parse(oidcCallbackTemplateContent),
|
template.New("oidccallback").Parse(oidcCallbackTemplateContent),
|
||||||
)
|
)
|
||||||
|
|
||||||
// OIDCCallback handles the callback from the OIDC endpoint
|
// OIDCCallbackHandler handles the callback from the OIDC endpoint
|
||||||
// Retrieves the nkey from the state cache and adds the node to the users email user
|
// Retrieves the nkey from the state cache and adds the node to the users email user
|
||||||
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
|
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
|
||||||
// TODO: Add groups information from OIDC tokens into node HostInfo
|
// TODO: Add groups information from OIDC tokens into node HostInfo
|
||||||
// Listens in /oidc/callback.
|
// Listens in /oidc/callback.
|
||||||
func (h *Headscale) OIDCCallback(
|
func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
code, state, err := validateOIDCCallbackParams(writer, req)
|
code, state, err := extractCodeAndStateParamFromRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
|
idToken, err := a.extractIDToken(req.Context(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
||||||
|
|
||||||
|
var claims types.OIDCClaims
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
|
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
|
||||||
if err != nil {
|
http.Error(writer, err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
idTokenExpiry := h.determineTokenExpiration(idToken.Expiry)
|
|
||||||
|
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
|
||||||
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
|
http.Error(writer, err.Error(), http.StatusUnauthorized)
|
||||||
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
|
return
|
||||||
// if err != nil {
|
}
|
||||||
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
|
|
||||||
// return
|
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
|
||||||
// }
|
http.Error(writer, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
claims, err := extractIDTokenClaims(writer, idToken)
|
}
|
||||||
|
|
||||||
|
user, err := a.createOrUpdateUserFromClaim(&claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
|
// Retrieve the node and the machine key from the state cache and
|
||||||
|
// database.
|
||||||
|
// If the node exists, then the node should be reauthenticated,
|
||||||
|
// if the node does not exist, and the machine key exists, then
|
||||||
|
// this is a new node that should be registered.
|
||||||
|
node, mKey := a.getMachineKeyFromState(state)
|
||||||
|
|
||||||
|
// Reauthenticate the node if it does exists.
|
||||||
|
if node != nil {
|
||||||
|
err := a.reauthenticateNode(node, nodeExpiry)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): replace with go-elem
|
||||||
|
var content bytes.Buffer
|
||||||
|
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||||
|
User: user.DisplayNameOrUsername(),
|
||||||
|
Verb: "Reauthenticated",
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
_, err = writer.Write(content.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
util.LogErr(err, "Failed to write response")
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
|
// Register the node if it does not exist.
|
||||||
|
if mKey != nil {
|
||||||
|
if err := a.registerNode(user, mKey, nodeExpiry); err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := renderOIDCCallbackTemplate(user)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
if _, err := writer.Write(content.Bytes()); err != nil {
|
||||||
|
util.LogErr(err, "Failed to write response")
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
|
// Neither node nor machine key was found in the state cache meaning
|
||||||
return
|
// that we could not reauth nor register the node.
|
||||||
}
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
|
|
||||||
writer,
|
|
||||||
state,
|
|
||||||
claims,
|
|
||||||
idTokenExpiry,
|
|
||||||
)
|
|
||||||
if err != nil || nodeExists {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// register the node if it's new
|
|
||||||
log.Debug().Msg("Registering new node after successful callback")
|
|
||||||
|
|
||||||
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := renderOIDCCallbackTemplate(writer, claims)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusOK)
|
|
||||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateOIDCCallbackParams(
|
func extractCodeAndStateParamFromRequest(
|
||||||
writer http.ResponseWriter,
|
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
code := req.URL.Query().Get("code")
|
code := req.URL.Query().Get("code")
|
||||||
state := req.URL.Query().Get("state")
|
state := req.URL.Query().Get("state")
|
||||||
|
|
||||||
if code == "" || state == "" {
|
if code == "" || state == "" {
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("Wrong params"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", "", errEmptyOIDCCallbackParams
|
return "", "", errEmptyOIDCCallbackParams
|
||||||
}
|
}
|
||||||
|
|
||||||
return code, state, nil
|
return code, state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) getIDTokenForOIDCCallback(
|
// extractIDToken takes the code parameter from the callback
|
||||||
|
// and extracts the ID token from the oauth2 token.
|
||||||
|
func (a *AuthProviderOIDC) extractIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
writer http.ResponseWriter,
|
code string,
|
||||||
code, state string,
|
|
||||||
) (string, error) {
|
|
||||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Could not exchange code for token")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, werr := writer.Write([]byte("Could not exchange code for token"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("code", code).
|
|
||||||
Str("state", state).
|
|
||||||
Msg("Got oidc callback")
|
|
||||||
|
|
||||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
|
||||||
if !rawIDTokenOK {
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("Could not extract ID Token"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", errNoOIDCIDToken
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawIDToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) verifyIDTokenForOIDCCallback(
|
|
||||||
ctx context.Context,
|
|
||||||
writer http.ResponseWriter,
|
|
||||||
rawIDToken string,
|
|
||||||
) (*oidc.IDToken, error) {
|
) (*oidc.IDToken, error) {
|
||||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not exchange code for token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errNoOIDCIDToken
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogErr(err, "failed to verify id token")
|
return nil, fmt.Errorf("failed to verify ID token: %w", err)
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, werr := writer.Write([]byte("Failed to verify id token"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return idToken, nil
|
return idToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractIDTokenClaims(
|
|
||||||
writer http.ResponseWriter,
|
|
||||||
idToken *oidc.IDToken,
|
|
||||||
) (*IDTokenClaims, error) {
|
|
||||||
var claims IDTokenClaims
|
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
|
||||||
util.LogErr(err, "Failed to decode id token claims")
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, werr := writer.Write([]byte("Failed to decode id token claims"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
||||||
// that the authenticated principal ends with @<alloweddomain>.
|
// that the authenticated principal ends with @<alloweddomain>.
|
||||||
func validateOIDCAllowedDomains(
|
func validateOIDCAllowedDomains(
|
||||||
writer http.ResponseWriter,
|
|
||||||
allowedDomains []string,
|
allowedDomains []string,
|
||||||
claims *IDTokenClaims,
|
claims *types.OIDCClaims,
|
||||||
) error {
|
) error {
|
||||||
if len(allowedDomains) > 0 {
|
if len(allowedDomains) > 0 {
|
||||||
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
||||||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
|
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
|
||||||
log.Trace().Msg("authenticated principal does not match any allowed domain")
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return errOIDCAllowedDomains
|
return errOIDCAllowedDomains
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -387,9 +355,8 @@ func validateOIDCAllowedDomains(
|
||||||
// claims.Groups can be populated by adding a client scope named
|
// claims.Groups can be populated by adding a client scope named
|
||||||
// 'groups' that contains group membership.
|
// 'groups' that contains group membership.
|
||||||
func validateOIDCAllowedGroups(
|
func validateOIDCAllowedGroups(
|
||||||
writer http.ResponseWriter,
|
|
||||||
allowedGroups []string,
|
allowedGroups []string,
|
||||||
claims *IDTokenClaims,
|
claims *types.OIDCClaims,
|
||||||
) error {
|
) error {
|
||||||
if len(allowedGroups) > 0 {
|
if len(allowedGroups) > 0 {
|
||||||
for _, group := range allowedGroups {
|
for _, group := range allowedGroups {
|
||||||
|
@ -398,14 +365,6 @@ func validateOIDCAllowedGroups(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Msg("authenticated principal not in any allowed groups")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return errOIDCAllowedGroups
|
return errOIDCAllowedGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -415,249 +374,129 @@ func validateOIDCAllowedGroups(
|
||||||
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
|
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
|
||||||
// that the authenticated principal is part of that list.
|
// that the authenticated principal is part of that list.
|
||||||
func validateOIDCAllowedUsers(
|
func validateOIDCAllowedUsers(
|
||||||
writer http.ResponseWriter,
|
|
||||||
allowedUsers []string,
|
allowedUsers []string,
|
||||||
claims *IDTokenClaims,
|
claims *types.OIDCClaims,
|
||||||
) error {
|
) error {
|
||||||
if len(allowedUsers) > 0 &&
|
if len(allowedUsers) > 0 &&
|
||||||
!slices.Contains(allowedUsers, claims.Email) {
|
!slices.Contains(allowedUsers, claims.Email) {
|
||||||
log.Trace().Msg("authenticated principal does not match any allowed user")
|
log.Trace().Msg("authenticated principal does not match any allowed user")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return errOIDCAllowedUsers
|
return errOIDCAllowedUsers
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateNode retrieves node information if it exist
|
// getMachineKeyFromState retrieves the machine key from the state
|
||||||
// The error is not important, because if it does not
|
// cache. If the machine key is found, it will try retrieve the
|
||||||
// exist, then this is a new node and we will move
|
// node information from the database.
|
||||||
// on to registration.
|
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
|
||||||
func (h *Headscale) validateNodeForOIDCCallback(
|
machineKey, ok := a.registrationCache.Get(state)
|
||||||
writer http.ResponseWriter,
|
if !ok {
|
||||||
state string,
|
return nil, nil
|
||||||
claims *IDTokenClaims,
|
|
||||||
expiry time.Time,
|
|
||||||
) (*key.MachinePublic, bool, error) {
|
|
||||||
// retrieve nodekey from state cache
|
|
||||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
|
||||||
if !machineKeyFound {
|
|
||||||
log.Trace().
|
|
||||||
Msg("requested node state key expired before authorisation completed")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("state has expired"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false, errOIDCNodeKeyMissing
|
|
||||||
}
|
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
|
||||||
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
|
|
||||||
if !machineKeyOK {
|
|
||||||
log.Trace().
|
|
||||||
Interface("got", machineKeyIf).
|
|
||||||
Msg("requested node state key is not a nodekey")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, err := writer.Write([]byte("state is invalid"))
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false, errOIDCInvalidNodeState
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve node information if it exist
|
// retrieve node information if it exist
|
||||||
// The error is not important, because if it does not
|
// The error is not important, because if it does not
|
||||||
// exist, then this is a new node and we will move
|
// exist, then this is a new node and we will move
|
||||||
// on to registration.
|
// on to registration.
|
||||||
node, _ := h.db.GetNodeByMachineKey(machineKey)
|
node, _ := a.db.GetNodeByMachineKey(machineKey)
|
||||||
|
|
||||||
if node != nil {
|
return node, &machineKey
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("node already registered, reauthenticating")
|
|
||||||
|
|
||||||
err := h.db.NodeSetExpiry(node.ID, expiry)
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to refresh node")
|
|
||||||
http.Error(
|
|
||||||
writer,
|
|
||||||
"Failed to refresh node",
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil, true, err
|
|
||||||
}
|
|
||||||
log.Debug().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Str("expiresAt", fmt.Sprintf("%v", expiry)).
|
|
||||||
Msg("successfully refreshed node")
|
|
||||||
|
|
||||||
var content bytes.Buffer
|
|
||||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
|
||||||
User: claims.Email,
|
|
||||||
Verb: "Reauthenticated",
|
|
||||||
}); err != nil {
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
|
||||||
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, true, fmt.Errorf("rendering OIDC callback template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusOK)
|
|
||||||
_, err = writer.Write(content.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
|
||||||
h.nodeNotifier.NotifyByNodeID(
|
|
||||||
ctx,
|
|
||||||
types.StateUpdate{
|
|
||||||
Type: types.StateSelfUpdate,
|
|
||||||
ChangeNodes: []types.NodeID{node.ID},
|
|
||||||
},
|
|
||||||
node.ID,
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
|
||||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
|
||||||
|
|
||||||
return nil, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &machineKey, false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserName(
|
// reauthenticateNode updates the node expiry in the database
|
||||||
writer http.ResponseWriter,
|
// and notifies the node and its peers about the change.
|
||||||
claims *IDTokenClaims,
|
func (a *AuthProviderOIDC) reauthenticateNode(
|
||||||
stripEmaildomain bool,
|
node *types.Node,
|
||||||
) (string, error) {
|
expiry time.Time,
|
||||||
userName, err := util.NormalizeToFQDNRules(
|
) error {
|
||||||
claims.Email,
|
err := a.db.NodeSetExpiry(node.ID, expiry)
|
||||||
stripEmaildomain,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogErr(err, "couldn't normalize email")
|
return err
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
|
||||||
_, werr := writer.Write([]byte("couldn't normalize email"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return userName, nil
|
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||||
|
a.notifier.NotifyByNodeID(
|
||||||
|
ctx,
|
||||||
|
types.StateUpdate{
|
||||||
|
Type: types.StateSelfUpdate,
|
||||||
|
ChangeNodes: []types.NodeID{node.ID},
|
||||||
|
},
|
||||||
|
node.ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||||
|
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||||
writer http.ResponseWriter,
|
claims *types.OIDCClaims,
|
||||||
userName string,
|
|
||||||
) (*types.User, error) {
|
) (*types.User, error) {
|
||||||
user, err := h.db.GetUser(userName)
|
var user *types.User
|
||||||
if errors.Is(err, db.ErrUserNotFound) {
|
var err error
|
||||||
user, err = h.db.CreateUser(userName)
|
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
}
|
||||||
_, werr := writer.Write([]byte("could not create user"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("creating new user: %w", err)
|
// This check is for legacy, if the user cannot be found by the OIDC identifier
|
||||||
}
|
// look it up by username. This should only be needed once.
|
||||||
} else if err != nil {
|
if user == nil {
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
user, err = a.db.GetUserByName(claims.Username)
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
_, werr := writer.Write([]byte("could not find or create user"))
|
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("find or create user: %w", err)
|
// if the user is still not found, create a new empty user.
|
||||||
|
if user == nil {
|
||||||
|
user = &types.User{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
user.FromClaim(claims)
|
||||||
|
err = a.db.DB.Save(user).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) registerNodeForOIDCCallback(
|
func (a *AuthProviderOIDC) registerNode(
|
||||||
writer http.ResponseWriter,
|
|
||||||
user *types.User,
|
user *types.User,
|
||||||
machineKey *key.MachinePublic,
|
machineKey *key.MachinePublic,
|
||||||
expiry time.Time,
|
expiry time.Time,
|
||||||
) error {
|
) error {
|
||||||
ipv4, ipv6, err := h.ipAlloc.Next()
|
ipv4, ipv6, err := a.ipAlloc.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Write(func(tx *gorm.DB) error {
|
if _, err := a.db.RegisterNodeFromAuthCallback(
|
||||||
if _, err := db.RegisterNodeFromAuthCallback(
|
*machineKey,
|
||||||
// TODO(kradalby): find a better way to use the cache across modules
|
types.UserID(user.ID),
|
||||||
tx,
|
&expiry,
|
||||||
h.registrationCache,
|
util.RegisterMethodOIDC,
|
||||||
*machineKey,
|
ipv4, ipv6,
|
||||||
user.Name,
|
); err != nil {
|
||||||
&expiry,
|
return fmt.Errorf("could not register node: %w", err)
|
||||||
util.RegisterMethodOIDC,
|
|
||||||
ipv4, ipv6,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
util.LogErr(err, "could not register node")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
|
||||||
_, werr := writer.Write([]byte("could not register node"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby):
|
||||||
|
// Rewrite in elem-go
|
||||||
func renderOIDCCallbackTemplate(
|
func renderOIDCCallbackTemplate(
|
||||||
writer http.ResponseWriter,
|
user *types.User,
|
||||||
claims *IDTokenClaims,
|
|
||||||
) (*bytes.Buffer, error) {
|
) (*bytes.Buffer, error) {
|
||||||
var content bytes.Buffer
|
var content bytes.Buffer
|
||||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||||
User: claims.Email,
|
User: user.DisplayNameOrUsername(),
|
||||||
Verb: "Authenticated",
|
Verb: "Authenticated",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
|
||||||
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
|
|
||||||
if werr != nil {
|
|
||||||
util.LogErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
|
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -743,15 +743,7 @@ func (pol *ACLPolicy) expandUsersFromGroup(
|
||||||
ErrInvalidGroup,
|
ErrInvalidGroup,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
|
users = append(users, group)
|
||||||
if err != nil {
|
|
||||||
return []string{}, fmt.Errorf(
|
|
||||||
"failed to normalize group %q, err: %w",
|
|
||||||
group,
|
|
||||||
ErrInvalidGroup,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
users = append(users, grp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
|
@ -940,7 +932,7 @@ func (pol *ACLPolicy) TagsOfNode(
|
||||||
}
|
}
|
||||||
var found bool
|
var found bool
|
||||||
for _, owner := range owners {
|
for _, owner := range owners {
|
||||||
if node.User.Name == owner {
|
if node.User.Username() == owner {
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -964,7 +956,7 @@ func (pol *ACLPolicy) TagsOfNode(
|
||||||
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
|
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
|
||||||
var out types.Nodes
|
var out types.Nodes
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
if node.User.Name == user {
|
if node.User.Username() == user {
|
||||||
out = append(out, node)
|
out = append(out, node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -635,25 +635,6 @@ func Test_expandGroup(t *testing.T) {
|
||||||
want: []string{},
|
want: []string{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Expand emails in group strip domains",
|
|
||||||
field: field{
|
|
||||||
pol: ACLPolicy{
|
|
||||||
Groups: Groups{
|
|
||||||
"group:admin": []string{
|
|
||||||
"joe.bar@gmail.com",
|
|
||||||
"john.doe@yahoo.fr",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
group: "group:admin",
|
|
||||||
stripEmail: true,
|
|
||||||
},
|
|
||||||
want: []string{"joe.bar", "john.doe"},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Expand emails in group",
|
name: "Expand emails in group",
|
||||||
field: field{
|
field: field{
|
||||||
|
@ -669,7 +650,7 @@ func Test_expandGroup(t *testing.T) {
|
||||||
args: args{
|
args: args{
|
||||||
group: "group:admin",
|
group: "group:admin",
|
||||||
},
|
},
|
||||||
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
|
want: []string{"joe.bar@gmail.com", "john.doe@yahoo.fr"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,9 +46,7 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
Path: tmpDir + "/headscale_test.db",
|
Path: tmpDir + "/headscale_test.db",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OIDC: types.OIDCConfig{
|
OIDC: types.OIDCConfig{},
|
||||||
StripEmaildomain: false,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
app, err = NewHeadscale(&cfg)
|
app, err = NewHeadscale(&cfg)
|
||||||
|
|
|
@ -71,8 +71,7 @@ type Config struct {
|
||||||
ACMEURL string
|
ACMEURL string
|
||||||
ACMEEmail string
|
ACMEEmail string
|
||||||
|
|
||||||
DNSConfig *tailcfg.DNSConfig
|
DNSConfig *tailcfg.DNSConfig
|
||||||
DNSUserNameInMagicDNS bool
|
|
||||||
|
|
||||||
UnixSocket string
|
UnixSocket string
|
||||||
UnixSocketPermission fs.FileMode
|
UnixSocketPermission fs.FileMode
|
||||||
|
@ -90,12 +89,11 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSConfig struct {
|
type DNSConfig struct {
|
||||||
MagicDNS bool `mapstructure:"magic_dns"`
|
MagicDNS bool `mapstructure:"magic_dns"`
|
||||||
BaseDomain string `mapstructure:"base_domain"`
|
BaseDomain string `mapstructure:"base_domain"`
|
||||||
Nameservers Nameservers
|
Nameservers Nameservers
|
||||||
SearchDomains []string `mapstructure:"search_domains"`
|
SearchDomains []string `mapstructure:"search_domains"`
|
||||||
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
|
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
|
||||||
UserNameInMagicDNS bool `mapstructure:"use_username_in_magic_dns"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Nameservers struct {
|
type Nameservers struct {
|
||||||
|
@ -164,7 +162,6 @@ type OIDCConfig struct {
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
StripEmaildomain bool
|
|
||||||
Expiry time.Duration
|
Expiry time.Duration
|
||||||
UseExpiryFromToken bool
|
UseExpiryFromToken bool
|
||||||
}
|
}
|
||||||
|
@ -274,7 +271,6 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
||||||
|
|
||||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||||
viper.SetDefault("oidc.strip_email_domain", true)
|
|
||||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||||
viper.SetDefault("oidc.expiry", "180d")
|
viper.SetDefault("oidc.expiry", "180d")
|
||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
|
@ -321,8 +317,22 @@ func validateServerConfig() error {
|
||||||
depr.warn("dns_config.use_username_in_magic_dns")
|
depr.warn("dns_config.use_username_in_magic_dns")
|
||||||
depr.warn("dns.use_username_in_magic_dns")
|
depr.warn("dns.use_username_in_magic_dns")
|
||||||
|
|
||||||
|
depr.fatal("oidc.strip_email_domain")
|
||||||
|
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
||||||
|
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
||||||
|
|
||||||
depr.Log()
|
depr.Log()
|
||||||
|
|
||||||
|
for _, removed := range []string{
|
||||||
|
"oidc.strip_email_domain",
|
||||||
|
"dns_config.use_username_in_musername_in_magic_dns",
|
||||||
|
} {
|
||||||
|
if viper.IsSet(removed) {
|
||||||
|
log.Fatal().
|
||||||
|
Msgf("Fatal config error: %s has been removed. Please remove it from your config file", removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Collect any validation errors and return them all at once
|
// Collect any validation errors and return them all at once
|
||||||
var errorText string
|
var errorText string
|
||||||
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
|
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
|
||||||
|
@ -572,12 +582,9 @@ func dns() (DNSConfig, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
|
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dns.ExtraRecords = extraRecords
|
dns.ExtraRecords = extraRecords
|
||||||
}
|
}
|
||||||
|
|
||||||
dns.UserNameInMagicDNS = viper.GetBool("dns.use_username_in_magic_dns")
|
|
||||||
|
|
||||||
return dns, nil
|
return dns, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -780,7 +787,12 @@ func LoadServerConfig() (*Config, error) {
|
||||||
case string(IPAllocationStrategyRandom):
|
case string(IPAllocationStrategyRandom):
|
||||||
alloc = IPAllocationStrategyRandom
|
alloc = IPAllocationStrategyRandom
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
|
return nil, fmt.Errorf(
|
||||||
|
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
|
||||||
|
allocStr,
|
||||||
|
IPAllocationStrategySequential,
|
||||||
|
IPAllocationStrategyRandom,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsConfig, err := dns()
|
dnsConfig, err := dns()
|
||||||
|
@ -814,10 +826,11 @@ func LoadServerConfig() (*Config, error) {
|
||||||
// - DERP run on their own domains
|
// - DERP run on their own domains
|
||||||
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
||||||
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
||||||
//
|
if dnsConfig.BaseDomain != "" &&
|
||||||
// TODO(kradalby): remove dnsConfig.UserNameInMagicDNS check when removed.
|
strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
||||||
if !dnsConfig.UserNameInMagicDNS && dnsConfig.BaseDomain != "" && strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
return nil, errors.New(
|
||||||
return nil, errors.New("server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.")
|
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
|
@ -847,8 +860,7 @@ func LoadServerConfig() (*Config, error) {
|
||||||
|
|
||||||
TLS: tlsConfig(),
|
TLS: tlsConfig(),
|
||||||
|
|
||||||
DNSConfig: dnsToTailcfgDNS(dnsConfig),
|
DNSConfig: dnsToTailcfgDNS(dnsConfig),
|
||||||
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
|
|
||||||
|
|
||||||
ACMEEmail: viper.GetString("acme_email"),
|
ACMEEmail: viper.GetString("acme_email"),
|
||||||
ACMEURL: viper.GetString("acme_url"),
|
ACMEURL: viper.GetString("acme_url"),
|
||||||
|
@ -860,15 +872,14 @@ func LoadServerConfig() (*Config, error) {
|
||||||
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
||||||
"oidc.only_start_if_oidc_is_available",
|
"oidc.only_start_if_oidc_is_available",
|
||||||
),
|
),
|
||||||
Issuer: viper.GetString("oidc.issuer"),
|
Issuer: viper.GetString("oidc.issuer"),
|
||||||
ClientID: viper.GetString("oidc.client_id"),
|
ClientID: viper.GetString("oidc.client_id"),
|
||||||
ClientSecret: oidcClientSecret,
|
ClientSecret: oidcClientSecret,
|
||||||
Scope: viper.GetStringSlice("oidc.scope"),
|
Scope: viper.GetStringSlice("oidc.scope"),
|
||||||
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
||||||
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
||||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
|
||||||
Expiry: func() time.Duration {
|
Expiry: func() time.Duration {
|
||||||
// if set to 0, we assume no expiry
|
// if set to 0, we assume no expiry
|
||||||
if value := viper.GetString("oidc.expiry"); value == "0" {
|
if value := viper.GetString("oidc.expiry"); value == "0" {
|
||||||
|
@ -903,9 +914,11 @@ func LoadServerConfig() (*Config, error) {
|
||||||
|
|
||||||
// TODO(kradalby): Document these settings when more stable
|
// TODO(kradalby): Document these settings when more stable
|
||||||
Tuning: Tuning{
|
Tuning: Tuning{
|
||||||
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
|
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
|
||||||
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
|
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
|
||||||
NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
|
NodeMapSessionBufferedChanSize: viper.GetInt(
|
||||||
|
"tuning.node_mapsession_buffered_chan_size",
|
||||||
|
),
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -921,14 +934,26 @@ func (d *deprecator) warnWithAlias(newKey, oldKey string) {
|
||||||
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
|
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
|
||||||
viper.RegisterAlias(newKey, oldKey)
|
viper.RegisterAlias(newKey, oldKey)
|
||||||
if viper.IsSet(oldKey) {
|
if viper.IsSet(oldKey) {
|
||||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", oldKey, newKey, oldKey))
|
d.warns.Add(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.",
|
||||||
|
oldKey,
|
||||||
|
newKey,
|
||||||
|
oldKey,
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
|
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
|
||||||
func (d *deprecator) fatal(newKey, oldKey string) {
|
func (d *deprecator) fatal(oldKey string) {
|
||||||
if viper.IsSet(oldKey) {
|
if viper.IsSet(oldKey) {
|
||||||
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
d.fatals.Add(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"The %q configuration key has been removed. Please see the changelog for more details.",
|
||||||
|
oldKey,
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -936,7 +961,14 @@ func (d *deprecator) fatal(newKey, oldKey string) {
|
||||||
// If the new key is set, a warning is emitted instead.
|
// If the new key is set, a warning is emitted instead.
|
||||||
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
||||||
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
|
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
|
||||||
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
d.fatals.Add(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
|
||||||
|
oldKey,
|
||||||
|
newKey,
|
||||||
|
oldKey,
|
||||||
|
),
|
||||||
|
)
|
||||||
} else if viper.IsSet(oldKey) {
|
} else if viper.IsSet(oldKey) {
|
||||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
||||||
}
|
}
|
||||||
|
@ -945,14 +977,26 @@ func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
||||||
// warn deprecates and adds an option to log a warning if the oldKey is set.
|
// warn deprecates and adds an option to log a warning if the oldKey is set.
|
||||||
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
|
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
|
||||||
if viper.IsSet(oldKey) {
|
if viper.IsSet(oldKey) {
|
||||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
d.warns.Add(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
|
||||||
|
oldKey,
|
||||||
|
newKey,
|
||||||
|
oldKey,
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
|
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
|
||||||
func (d *deprecator) warn(oldKey string) {
|
func (d *deprecator) warn(oldKey string) {
|
||||||
if viper.IsSet(oldKey) {
|
if viper.IsSet(oldKey) {
|
||||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated and has been removed. Please see the changelog for more details.", oldKey))
|
d.warns.Add(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"The %q configuration key is deprecated and has been removed. Please see the changelog for more details.",
|
||||||
|
oldKey,
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,8 +42,7 @@ func TestReadConfig(t *testing.T) {
|
||||||
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
||||||
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
||||||
},
|
},
|
||||||
SearchDomains: []string{"test.com", "bar.com"},
|
SearchDomains: []string{"test.com", "bar.com"},
|
||||||
UserNameInMagicDNS: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -99,8 +98,7 @@ func TestReadConfig(t *testing.T) {
|
||||||
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
||||||
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
||||||
},
|
},
|
||||||
SearchDomains: []string{"test.com", "bar.com"},
|
SearchDomains: []string{"test.com", "bar.com"},
|
||||||
UserNameInMagicDNS: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -234,11 +232,10 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "unmarshal-dns-full-config",
|
name: "unmarshal-dns-full-config",
|
||||||
configEnv: map[string]string{
|
configEnv: map[string]string{
|
||||||
"HEADSCALE_DNS_MAGIC_DNS": "true",
|
"HEADSCALE_DNS_MAGIC_DNS": "true",
|
||||||
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
|
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
|
||||||
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
|
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
|
||||||
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
|
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
|
||||||
"HEADSCALE_DNS_USE_USERNAME_IN_MAGIC_DNS": "true",
|
|
||||||
|
|
||||||
// TODO(kradalby): Figure out how to pass these as env vars
|
// TODO(kradalby): Figure out how to pass these as env vars
|
||||||
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
|
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
|
||||||
|
@ -266,8 +263,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||||
ExtraRecords: []tailcfg.DNSRecord{
|
ExtraRecords: []tailcfg.DNSRecord{
|
||||||
// {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
// {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
||||||
},
|
},
|
||||||
SearchDomains: []string{"test.com", "bar.com"},
|
SearchDomains: []string{"test.com", "bar.com"},
|
||||||
UserNameInMagicDNS: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -253,7 +253,7 @@ func (node *Node) Proto() *v1.Node {
|
||||||
return nodeProto
|
return nodeProto
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
|
func (node *Node) GetFQDN(baseDomain string) (string, error) {
|
||||||
if node.GivenName == "" {
|
if node.GivenName == "" {
|
||||||
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
|
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
|
||||||
}
|
}
|
||||||
|
@ -268,19 +268,6 @@ func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.DNSUserNameInMagicDNS {
|
|
||||||
if node.User.Name == "" {
|
|
||||||
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname = fmt.Sprintf(
|
|
||||||
"%s.%s.%s",
|
|
||||||
node.GivenName,
|
|
||||||
node.User.Name,
|
|
||||||
baseDomain,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hostname) > MaxHostnameLength {
|
if len(hostname) > MaxHostnameLength {
|
||||||
return "", fmt.Errorf(
|
return "", fmt.Errorf(
|
||||||
"failed to create valid FQDN (%s): %w",
|
"failed to create valid FQDN (%s): %w",
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
@ -127,76 +129,10 @@ func TestNodeFQDN(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
node Node
|
node Node
|
||||||
cfg Config
|
|
||||||
domain string
|
domain string
|
||||||
want string
|
want string
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
|
||||||
name: "all-set-with-username",
|
|
||||||
node: Node{
|
|
||||||
GivenName: "test",
|
|
||||||
User: User{
|
|
||||||
Name: "user",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: true,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: true,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
|
||||||
want: "test.user.example.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no-given-name-with-username",
|
|
||||||
node: Node{
|
|
||||||
User: User{
|
|
||||||
Name: "user",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: true,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: true,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
|
||||||
wantErr: "failed to create valid FQDN: node has no given name",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no-user-name-with-username",
|
|
||||||
node: Node{
|
|
||||||
GivenName: "test",
|
|
||||||
User: User{},
|
|
||||||
},
|
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: true,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: true,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
|
||||||
wantErr: "failed to create valid FQDN: node user has no name",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no-magic-dns-with-username",
|
|
||||||
node: Node{
|
|
||||||
GivenName: "test",
|
|
||||||
User: User{
|
|
||||||
Name: "user",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: false,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: true,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
|
||||||
want: "test.user.example.com",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "no-dnsconfig-with-username",
|
name: "no-dnsconfig-with-username",
|
||||||
node: Node{
|
node: Node{
|
||||||
|
@ -216,12 +152,6 @@ func TestNodeFQDN(t *testing.T) {
|
||||||
Name: "user",
|
Name: "user",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: true,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: false,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
want: "test.example.com",
|
want: "test.example.com",
|
||||||
},
|
},
|
||||||
|
@ -232,46 +162,16 @@ func TestNodeFQDN(t *testing.T) {
|
||||||
Name: "user",
|
Name: "user",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: true,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: false,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
wantErr: "failed to create valid FQDN: node has no given name",
|
wantErr: "failed to create valid FQDN: node has no given name",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no-user-name",
|
name: "too-long-username",
|
||||||
node: Node{
|
node: Node{
|
||||||
GivenName: "test",
|
GivenName: strings.Repeat("a", 256),
|
||||||
User: User{},
|
|
||||||
},
|
},
|
||||||
cfg: Config{
|
domain: "example.com",
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
wantErr: fmt.Sprintf("failed to create valid FQDN (%s.example.com): hostname too long, cannot except 255 ASCII chars", strings.Repeat("a", 256)),
|
||||||
Proxied: true,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: false,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
|
||||||
want: "test.example.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no-magic-dns",
|
|
||||||
node: Node{
|
|
||||||
GivenName: "test",
|
|
||||||
User: User{
|
|
||||||
Name: "user",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
cfg: Config{
|
|
||||||
DNSConfig: &tailcfg.DNSConfig{
|
|
||||||
Proxied: false,
|
|
||||||
},
|
|
||||||
DNSUserNameInMagicDNS: false,
|
|
||||||
},
|
|
||||||
domain: "example.com",
|
|
||||||
want: "test.example.com",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no-dnsconfig",
|
name: "no-dnsconfig",
|
||||||
|
@ -288,7 +188,9 @@ func TestNodeFQDN(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got, err := tc.node.GetFQDN(&tc.cfg, tc.domain)
|
got, err := tc.node.GetFQDN(tc.domain)
|
||||||
|
|
||||||
|
t.Logf("GOT: %q, %q", got, tc.domain)
|
||||||
|
|
||||||
if (err != nil) && (err.Error() != tc.wantErr) {
|
if (err != nil) && (err.Error() != tc.wantErr) {
|
||||||
t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr)
|
t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
@ -10,25 +11,65 @@ import (
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type UserID uint64
|
||||||
|
|
||||||
// User is the way Headscale implements the concept of users in Tailscale
|
// User is the way Headscale implements the concept of users in Tailscale
|
||||||
//
|
//
|
||||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||||
// that contain our machines.
|
// that contain our machines.
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
|
||||||
|
// Username for the user, is used if email is empty
|
||||||
|
// Should not be used, please use Username().
|
||||||
Name string `gorm:"unique"`
|
Name string `gorm:"unique"`
|
||||||
|
|
||||||
|
// Typically the full name of the user
|
||||||
|
DisplayName string
|
||||||
|
|
||||||
|
// Email of the user
|
||||||
|
// Should not be used, please use Username().
|
||||||
|
Email string
|
||||||
|
|
||||||
|
// Unique identifier of the user from OIDC,
|
||||||
|
// comes from `sub` claim in the OIDC token
|
||||||
|
// and is used to lookup the user.
|
||||||
|
ProviderIdentifier string `gorm:"index"`
|
||||||
|
|
||||||
|
// Provider is the origin of the user account,
|
||||||
|
// same as RegistrationMethod, without authkey.
|
||||||
|
Provider string
|
||||||
|
|
||||||
|
ProfilePicURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Username is the main way to get the username of a user,
|
||||||
|
// it will return the email if it exists, the name if it exists,
|
||||||
|
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
|
||||||
|
// Email and OIDCIdentifier will be set when the user has headscale
|
||||||
|
// enabled with OIDC, which means that there is a domain involved which
|
||||||
|
// should be used throughout headscale, in information returned to the
|
||||||
|
// user and the Policy engine.
|
||||||
|
func (u *User) Username() string {
|
||||||
|
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||||
|
// it will return the Username.
|
||||||
|
func (u *User) DisplayNameOrUsername() string {
|
||||||
|
return cmp.Or(u.DisplayName, u.Username())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): See if we can fill in Gravatar here.
|
// TODO(kradalby): See if we can fill in Gravatar here.
|
||||||
func (u *User) profilePicURL() string {
|
func (u *User) profilePicURL() string {
|
||||||
return ""
|
return u.ProfilePicURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) TailscaleUser() *tailcfg.User {
|
func (u *User) TailscaleUser() *tailcfg.User {
|
||||||
user := tailcfg.User{
|
user := tailcfg.User{
|
||||||
ID: tailcfg.UserID(u.ID),
|
ID: tailcfg.UserID(u.ID),
|
||||||
LoginName: u.Name,
|
LoginName: u.Username(),
|
||||||
DisplayName: u.Name,
|
DisplayName: u.DisplayNameOrUsername(),
|
||||||
ProfilePicURL: u.profilePicURL(),
|
ProfilePicURL: u.profilePicURL(),
|
||||||
Logins: []tailcfg.LoginID{},
|
Logins: []tailcfg.LoginID{},
|
||||||
Created: u.CreatedAt,
|
Created: u.CreatedAt,
|
||||||
|
@ -41,9 +82,9 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
|
||||||
login := tailcfg.Login{
|
login := tailcfg.Login{
|
||||||
ID: tailcfg.LoginID(u.ID),
|
ID: tailcfg.LoginID(u.ID),
|
||||||
// TODO(kradalby): this should reflect registration method.
|
// TODO(kradalby): this should reflect registration method.
|
||||||
Provider: "",
|
Provider: u.Provider,
|
||||||
LoginName: u.Name,
|
LoginName: u.Username(),
|
||||||
DisplayName: u.Name,
|
DisplayName: u.DisplayNameOrUsername(),
|
||||||
ProfilePicURL: u.profilePicURL(),
|
ProfilePicURL: u.profilePicURL(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,8 +94,8 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
|
||||||
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
|
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
|
||||||
return tailcfg.UserProfile{
|
return tailcfg.UserProfile{
|
||||||
ID: tailcfg.UserID(u.ID),
|
ID: tailcfg.UserID(u.ID),
|
||||||
LoginName: u.Name,
|
LoginName: u.Username(),
|
||||||
DisplayName: u.Name,
|
DisplayName: u.DisplayNameOrUsername(),
|
||||||
ProfilePicURL: u.profilePicURL(),
|
ProfilePicURL: u.profilePicURL(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -66,3 +107,27 @@ func (n *User) Proto() *v1.User {
|
||||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OIDCClaims struct {
|
||||||
|
// Sub is the user's unique identifier at the provider.
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
|
||||||
|
// Name is the user's full name.
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Groups []string `json:"groups,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
EmailVerified bool `json:"email_verified,omitempty"`
|
||||||
|
ProfilePictureURL string `json:"picture,omitempty"`
|
||||||
|
Username string `json:"preferred_username,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromClaim overrides a User from OIDC claims.
|
||||||
|
// All fields will be updated, except for the ID.
|
||||||
|
func (u *User) FromClaim(claims *OIDCClaims) {
|
||||||
|
u.ProviderIdentifier = claims.Sub
|
||||||
|
u.DisplayName = claims.Name
|
||||||
|
u.Email = claims.Email
|
||||||
|
u.Name = claims.Username
|
||||||
|
u.ProfilePicURL = claims.ProfilePictureURL
|
||||||
|
u.Provider = util.RegisterMethodOIDC
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"tailscale.com/util/dnsname"
|
"tailscale.com/util/dnsname"
|
||||||
)
|
)
|
||||||
|
@ -25,38 +24,6 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||||
|
|
||||||
var ErrInvalidUserName = errors.New("invalid user name")
|
var ErrInvalidUserName = errors.New("invalid user name")
|
||||||
|
|
||||||
func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) {
|
|
||||||
strip := viper.GetBool("oidc.strip_email_domain")
|
|
||||||
|
|
||||||
return NormalizeToFQDNRules(name, strip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NormalizeToFQDNRules will replace forbidden chars in user
|
|
||||||
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
|
||||||
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
|
||||||
name = strings.ToLower(name)
|
|
||||||
name = strings.ReplaceAll(name, "'", "")
|
|
||||||
atIdx := strings.Index(name, "@")
|
|
||||||
if stripEmailDomain && atIdx > 0 {
|
|
||||||
name = name[:atIdx]
|
|
||||||
} else {
|
|
||||||
name = strings.ReplaceAll(name, "@", ".")
|
|
||||||
}
|
|
||||||
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
|
||||||
|
|
||||||
for _, elt := range strings.Split(name, ".") {
|
|
||||||
if len(elt) > LabelHostnameLength {
|
|
||||||
return "", fmt.Errorf(
|
|
||||||
"label %v is more than 63 chars: %w",
|
|
||||||
elt,
|
|
||||||
ErrInvalidUserName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return name, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckForFQDNRules(name string) error {
|
func CheckForFQDNRules(name string) error {
|
||||||
if len(name) > LabelHostnameLength {
|
if len(name) > LabelHostnameLength {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
|
|
|
@ -7,100 +7,6 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNormalizeToFQDNRules(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
name string
|
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normalize simple name",
|
|
||||||
args: args{
|
|
||||||
name: "normalize-simple.name",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "normalize-simple.name",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normalize an email",
|
|
||||||
args: args{
|
|
||||||
name: "foo.bar@example.com",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "foo.bar.example.com",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normalize an email domain should be removed",
|
|
||||||
args: args{
|
|
||||||
name: "foo.bar@example.com",
|
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
want: "foo.bar",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "strip enabled no email passed as argument",
|
|
||||||
args: args{
|
|
||||||
name: "not-email-and-strip-enabled",
|
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
want: "not-email-and-strip-enabled",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normalize complex email",
|
|
||||||
args: args{
|
|
||||||
name: "foo.bar+complex-email@example.com",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "foo.bar-complex-email.example.com",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user name with space",
|
|
||||||
args: args{
|
|
||||||
name: "name space",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "name-space",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with quote",
|
|
||||||
args: args{
|
|
||||||
name: "Jamie's iPhone 5",
|
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
|
||||||
want: "jamies-iphone-5",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf(
|
|
||||||
"NormalizeToFQDNRules() error = %v, wantErr %v",
|
|
||||||
err,
|
|
||||||
tt.wantErr,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckForFQDNRules(t *testing.T) {
|
func TestCheckForFQDNRules(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -62,7 +62,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
@ -121,7 +120,6 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
|
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
|
||||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
|
|
||||||
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
|
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,7 +274,6 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
||||||
),
|
),
|
||||||
ClientID: "superclient",
|
ClientID: "superclient",
|
||||||
ClientSecret: "supersecret",
|
ClientSecret: "supersecret",
|
||||||
StripEmaildomain: true,
|
|
||||||
OnlyStartIfOIDCIsAvailable: true,
|
OnlyStartIfOIDCIsAvailable: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue