diff --git a/Dockerfile b/Dockerfile
index 8d53f6d9..ac807794 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,7 +8,7 @@ RUN go mod download
COPY . .
-RUN GGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale
+RUN CGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale
RUN strip /go/bin/headscale
RUN test -e /go/bin/headscale
diff --git a/Dockerfile.alpine b/Dockerfile.alpine
index 45fa171d..24d4e6f8 100644
--- a/Dockerfile.alpine
+++ b/Dockerfile.alpine
@@ -9,7 +9,7 @@ RUN go mod download
COPY . .
-RUN GGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale
+RUN CGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale
RUN strip /go/bin/headscale
RUN test -e /go/bin/headscale
diff --git a/Dockerfile.debug b/Dockerfile.debug
index 91fe2893..f053d720 100644
--- a/Dockerfile.debug
+++ b/Dockerfile.debug
@@ -8,7 +8,7 @@ RUN go mod download
COPY . .
-RUN GGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale
+RUN CGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale
RUN test -e /go/bin/headscale
# Debug image
diff --git a/README.md b/README.md
index 98738bc2..a0071208 100644
--- a/README.md
+++ b/README.md
@@ -415,6 +415,15 @@ make build
Carson Yang
+
+
+
+
+ kundel
+
+ |
+
+
@@ -422,8 +431,6 @@ make build
Felix Kronlage-Dammers
|
-
-
@@ -445,6 +452,13 @@ make build
Jamie Greeff
|
+
+
+
+
+ Jiang Zhu
+
+ |
@@ -452,6 +466,8 @@ make build
Jim Tittsler
|
+
+
@@ -466,8 +482,6 @@ make build
rcursaru
|
-
-
@@ -496,6 +510,8 @@ make build
Tanner
|
+
+
@@ -510,8 +526,6 @@ make build
The Gitter Badger
|
-
-
@@ -540,6 +554,8 @@ make build
Zakhar Bessarab
|
+
+
@@ -554,8 +570,6 @@ make build
derelm
|
-
-
@@ -584,6 +598,8 @@ make build
pernila
|
+
+
diff --git a/app.go b/app.go
index 054fd178..01528fb9 100644
--- a/app.go
+++ b/app.go
@@ -6,10 +6,8 @@ import (
"errors"
"fmt"
"io"
- "io/fs"
"net"
"net/http"
- "net/url"
"os"
"os/signal"
"sort"
@@ -42,7 +40,6 @@ import (
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"gorm.io/gorm"
- "inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
@@ -72,95 +69,9 @@ const (
EnforcedClientAuth = "enforced"
)
-// Config contains the initial Headscale configuration.
-type Config struct {
- ServerURL string
- Addr string
- MetricsAddr string
- GRPCAddr string
- GRPCAllowInsecure bool
- EphemeralNodeInactivityTimeout time.Duration
- IPPrefixes []netaddr.IPPrefix
- PrivateKeyPath string
- BaseDomain string
-
- DERP DERPConfig
-
- DBtype string
- DBpath string
- DBhost string
- DBport int
- DBname string
- DBuser string
- DBpass string
-
- TLSLetsEncryptListen string
- TLSLetsEncryptHostname string
- TLSLetsEncryptCacheDir string
- TLSLetsEncryptChallengeType string
-
- TLSCertPath string
- TLSKeyPath string
- TLSClientAuthMode tls.ClientAuthType
-
- ACMEURL string
- ACMEEmail string
-
- DNSConfig *tailcfg.DNSConfig
-
- UnixSocket string
- UnixSocketPermission fs.FileMode
-
- OIDC OIDCConfig
-
- LogTail LogTailConfig
-
- CLI CLIConfig
-
- ACL ACLConfig
-}
-
-type OIDCConfig struct {
- Issuer string
- ClientID string
- ClientSecret string
- Scope []string
- ExtraParams map[string]string
- AllowedDomains []string
- AllowedUsers []string
- StripEmaildomain bool
-}
-
-type DERPConfig struct {
- ServerEnabled bool
- ServerRegionID int
- ServerRegionCode string
- ServerRegionName string
- STUNAddr string
- URLs []url.URL
- Paths []string
- AutoUpdate bool
- UpdateFrequency time.Duration
-}
-
-type LogTailConfig struct {
- Enabled bool
-}
-
-type CLIConfig struct {
- Address string
- APIKey string
- Timeout time.Duration
- Insecure bool
-}
-
-type ACLConfig struct {
- PolicyPath string
-}
-
// Headscale represents the base app of the service.
type Headscale struct {
- cfg Config
+ cfg *Config
db *gorm.DB
dbString string
dbType string
@@ -204,7 +115,7 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) {
}
}
-func NewHeadscale(cfg Config) (*Headscale, error) {
+func NewHeadscale(cfg *Config) (*Headscale, error) {
privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create private key: %w", err)
@@ -778,7 +689,7 @@ func (h *Headscale) Serve() error {
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
var err error
- if h.cfg.TLSLetsEncryptHostname != "" {
+ if h.cfg.TLS.LetsEncrypt.Hostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().
Msg("Listening with TLS but ServerURL does not start with https://")
@@ -786,15 +697,15 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
- HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
- Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
+ HostPolicy: autocert.HostWhitelist(h.cfg.TLS.LetsEncrypt.Hostname),
+ Cache: autocert.DirCache(h.cfg.TLS.LetsEncrypt.CacheDir),
Client: &acme.Client{
DirectoryURL: h.cfg.ACMEURL,
},
Email: h.cfg.ACMEEmail,
}
- switch h.cfg.TLSLetsEncryptChallengeType {
+ switch h.cfg.TLS.LetsEncrypt.ChallengeType {
case "TLS-ALPN-01":
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale
@@ -808,7 +719,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
go func() {
log.Fatal().
Caller().
- Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
+ Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
@@ -817,7 +728,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
default:
return nil, errUnsupportedLetsEncryptChallengeType
}
- } else if h.cfg.TLSCertPath == "" {
+ } else if h.cfg.TLS.CertPath == "" {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
}
@@ -830,16 +741,16 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
log.Info().Msg(fmt.Sprintf(
"Client authentication (mTLS) is \"%s\". See the docs to learn about configuring this setting.",
- h.cfg.TLSClientAuthMode))
+ h.cfg.TLS.ClientAuthMode))
tlsConfig := &tls.Config{
- ClientAuth: h.cfg.TLSClientAuthMode,
+ ClientAuth: h.cfg.TLS.ClientAuthMode,
NextProtos: []string{"http/1.1"},
Certificates: make([]tls.Certificate, 1),
MinVersion: tls.VersionTLS12,
}
- tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
+ tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLS.CertPath, h.cfg.TLS.KeyPath)
return tlsConfig, err
}
diff --git a/app_test.go b/app_test.go
index 96036a1d..170db482 100644
--- a/app_test.go
+++ b/app_test.go
@@ -46,7 +46,7 @@ func (s *Suite) ResetDB(c *check.C) {
}
app = Headscale{
- cfg: cfg,
+ cfg: &cfg,
dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db",
}
diff --git a/cmd/headscale/cli/server.go b/cmd/headscale/cli/server.go
index c19580b9..a1d19600 100644
--- a/cmd/headscale/cli/server.go
+++ b/cmd/headscale/cli/server.go
@@ -16,12 +16,12 @@ var serveCmd = &cobra.Command{
return nil
},
Run: func(cmd *cobra.Command, args []string) {
- h, err := getHeadscaleApp()
+ app, err := getHeadscaleApp()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing")
}
- err = h.Serve()
+ err = app.Serve()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error starting server")
}
diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go
index af4391a3..f5c679c2 100644
--- a/cmd/headscale/cli/utils.go
+++ b/cmd/headscale/cli/utils.go
@@ -4,17 +4,11 @@ import (
"context"
"crypto/tls"
"encoding/json"
- "errors"
"fmt"
- "io/fs"
- "net/url"
"os"
"reflect"
- "strconv"
- "strings"
"time"
- "github.com/coreos/go-oidc/v3/oidc"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
@@ -23,380 +17,18 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"gopkg.in/yaml.v2"
- "inet.af/netaddr"
- "tailscale.com/tailcfg"
- "tailscale.com/types/dnstype"
)
const (
- PermissionFallback = 0o700
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
)
-func LoadConfig(path string) error {
- viper.SetConfigName("config")
- if path == "" {
- viper.AddConfigPath("/etc/headscale/")
- viper.AddConfigPath("$HOME/.headscale")
- viper.AddConfigPath(".")
- } else {
- // For testing
- viper.AddConfigPath(path)
- }
-
- viper.SetEnvPrefix("headscale")
- viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
- viper.AutomaticEnv()
-
- viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
- viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
- viper.SetDefault("tls_client_auth_mode", "relaxed")
-
- viper.SetDefault("log_level", "info")
-
- viper.SetDefault("dns_config", nil)
-
- viper.SetDefault("derp.server.enabled", false)
- viper.SetDefault("derp.server.stun.enabled", true)
-
- viper.SetDefault("unix_socket", "/var/run/headscale.sock")
- viper.SetDefault("unix_socket_permission", "0o770")
-
- viper.SetDefault("grpc_listen_addr", ":50443")
- viper.SetDefault("grpc_allow_insecure", false)
-
- viper.SetDefault("cli.timeout", "5s")
- viper.SetDefault("cli.insecure", false)
-
- viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
- viper.SetDefault("oidc.strip_email_domain", true)
-
- viper.SetDefault("logtail.enabled", false)
-
- if err := viper.ReadInConfig(); err != nil {
- return fmt.Errorf("fatal error reading config file: %w", err)
- }
-
- // Collect any validation errors and return them all at once
- var errorText string
- if (viper.GetString("tls_letsencrypt_hostname") != "") &&
- ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
- errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
- }
-
- if (viper.GetString("tls_letsencrypt_hostname") != "") &&
- (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") &&
- (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
- // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
- log.Warn().
- Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
- }
-
- if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") &&
- (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
- errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
- }
-
- if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
- !strings.HasPrefix(viper.GetString("server_url"), "https://") {
- errorText += "Fatal config error: server_url must start with https:// or http://\n"
- }
-
- _, authModeValid := headscale.LookupTLSClientAuthMode(
- viper.GetString("tls_client_auth_mode"),
- )
-
- if !authModeValid {
- errorText += fmt.Sprintf(
- "Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.",
- viper.GetString("tls_client_auth_mode"),
- headscale.DisabledClientAuth,
- headscale.RelaxedClientAuth,
- headscale.EnforcedClientAuth)
- }
-
- if errorText != "" {
- //nolint
- return errors.New(strings.TrimSuffix(errorText, "\n"))
- } else {
- return nil
- }
-}
-
-func GetDERPConfig() headscale.DERPConfig {
- serverEnabled := viper.GetBool("derp.server.enabled")
- serverRegionID := viper.GetInt("derp.server.region_id")
- serverRegionCode := viper.GetString("derp.server.region_code")
- serverRegionName := viper.GetString("derp.server.region_name")
- stunAddr := viper.GetString("derp.server.stun_listen_addr")
-
- if serverEnabled && stunAddr == "" {
- log.Fatal().
- Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
- }
-
- urlStrs := viper.GetStringSlice("derp.urls")
-
- urls := make([]url.URL, len(urlStrs))
- for index, urlStr := range urlStrs {
- urlAddr, err := url.Parse(urlStr)
- if err != nil {
- log.Error().
- Str("url", urlStr).
- Err(err).
- Msg("Failed to parse url, ignoring...")
- }
-
- urls[index] = *urlAddr
- }
-
- paths := viper.GetStringSlice("derp.paths")
-
- autoUpdate := viper.GetBool("derp.auto_update_enabled")
- updateFrequency := viper.GetDuration("derp.update_frequency")
-
- return headscale.DERPConfig{
- ServerEnabled: serverEnabled,
- ServerRegionID: serverRegionID,
- ServerRegionCode: serverRegionCode,
- ServerRegionName: serverRegionName,
- STUNAddr: stunAddr,
- URLs: urls,
- Paths: paths,
- AutoUpdate: autoUpdate,
- UpdateFrequency: updateFrequency,
- }
-}
-
-func GetLogTailConfig() headscale.LogTailConfig {
- enabled := viper.GetBool("logtail.enabled")
-
- return headscale.LogTailConfig{
- Enabled: enabled,
- }
-}
-
-func GetACLConfig() headscale.ACLConfig {
- policyPath := viper.GetString("acl_policy_path")
-
- return headscale.ACLConfig{
- PolicyPath: policyPath,
- }
-}
-
-func GetDNSConfig() (*tailcfg.DNSConfig, string) {
- if viper.IsSet("dns_config") {
- dnsConfig := &tailcfg.DNSConfig{}
-
- if viper.IsSet("dns_config.nameservers") {
- nameserversStr := viper.GetStringSlice("dns_config.nameservers")
-
- nameservers := make([]netaddr.IP, len(nameserversStr))
- resolvers := make([]dnstype.Resolver, len(nameserversStr))
-
- for index, nameserverStr := range nameserversStr {
- nameserver, err := netaddr.ParseIP(nameserverStr)
- if err != nil {
- log.Error().
- Str("func", "getDNSConfig").
- Err(err).
- Msgf("Could not parse nameserver IP: %s", nameserverStr)
- }
-
- nameservers[index] = nameserver
- resolvers[index] = dnstype.Resolver{
- Addr: nameserver.String(),
- }
- }
-
- dnsConfig.Nameservers = nameservers
- dnsConfig.Resolvers = resolvers
- }
-
- if viper.IsSet("dns_config.restricted_nameservers") {
- if len(dnsConfig.Nameservers) > 0 {
- dnsConfig.Routes = make(map[string][]dnstype.Resolver)
- restrictedDNS := viper.GetStringMapStringSlice(
- "dns_config.restricted_nameservers",
- )
- for domain, restrictedNameservers := range restrictedDNS {
- restrictedResolvers := make(
- []dnstype.Resolver,
- len(restrictedNameservers),
- )
- for index, nameserverStr := range restrictedNameservers {
- nameserver, err := netaddr.ParseIP(nameserverStr)
- if err != nil {
- log.Error().
- Str("func", "getDNSConfig").
- Err(err).
- Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
- }
- restrictedResolvers[index] = dnstype.Resolver{
- Addr: nameserver.String(),
- }
- }
- dnsConfig.Routes[domain] = restrictedResolvers
- }
- } else {
- log.Warn().
- Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.")
- }
- }
-
- if viper.IsSet("dns_config.domains") {
- dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
- }
-
- if viper.IsSet("dns_config.magic_dns") {
- magicDNS := viper.GetBool("dns_config.magic_dns")
- if len(dnsConfig.Nameservers) > 0 {
- dnsConfig.Proxied = magicDNS
- } else if magicDNS {
- log.Warn().
- Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
- }
- }
-
- var baseDomain string
- if viper.IsSet("dns_config.base_domain") {
- baseDomain = viper.GetString("dns_config.base_domain")
- } else {
- baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
- }
-
- return dnsConfig, baseDomain
- }
-
- return nil, ""
-}
-
-func GetHeadscaleConfig() headscale.Config {
- dnsConfig, baseDomain := GetDNSConfig()
- derpConfig := GetDERPConfig()
- logConfig := GetLogTailConfig()
-
- configuredPrefixes := viper.GetStringSlice("ip_prefixes")
- parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1)
-
- legacyPrefixField := viper.GetString("ip_prefix")
- if len(legacyPrefixField) > 0 {
- log.
- Warn().
- Msgf(
- "%s, %s",
- "use of 'ip_prefix' for configuration is deprecated",
- "please see 'ip_prefixes' in the shipped example.",
- )
- legacyPrefix, err := netaddr.ParseIPPrefix(legacyPrefixField)
- if err != nil {
- panic(fmt.Errorf("failed to parse ip_prefix: %w", err))
- }
- parsedPrefixes = append(parsedPrefixes, legacyPrefix)
- }
-
- for i, prefixInConfig := range configuredPrefixes {
- prefix, err := netaddr.ParseIPPrefix(prefixInConfig)
- if err != nil {
- panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
- }
- parsedPrefixes = append(parsedPrefixes, prefix)
- }
-
- prefixes := make([]netaddr.IPPrefix, 0, len(parsedPrefixes))
- {
- // dedup
- normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
- for i, p := range parsedPrefixes {
- normalized, _ := p.Range().Prefix()
- normalizedPrefixes[normalized.String()] = i
- }
-
- // convert back to list
- for _, i := range normalizedPrefixes {
- prefixes = append(prefixes, parsedPrefixes[i])
- }
- }
-
- if len(prefixes) < 1 {
- prefixes = append(prefixes, netaddr.MustParseIPPrefix("100.64.0.0/10"))
- log.Warn().
- Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
- }
-
- tlsClientAuthMode, _ := headscale.LookupTLSClientAuthMode(
- viper.GetString("tls_client_auth_mode"),
- )
-
- return headscale.Config{
- ServerURL: viper.GetString("server_url"),
- Addr: viper.GetString("listen_addr"),
- MetricsAddr: viper.GetString("metrics_listen_addr"),
- GRPCAddr: viper.GetString("grpc_listen_addr"),
- GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
-
- IPPrefixes: prefixes,
- PrivateKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("private_key_path")),
- BaseDomain: baseDomain,
-
- DERP: derpConfig,
-
- EphemeralNodeInactivityTimeout: viper.GetDuration(
- "ephemeral_node_inactivity_timeout",
- ),
-
- DBtype: viper.GetString("db_type"),
- DBpath: headscale.AbsolutePathFromConfigPath(viper.GetString("db_path")),
- DBhost: viper.GetString("db_host"),
- DBport: viper.GetInt("db_port"),
- DBname: viper.GetString("db_name"),
- DBuser: viper.GetString("db_user"),
- DBpass: viper.GetString("db_pass"),
-
- TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
- TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
- TLSLetsEncryptCacheDir: headscale.AbsolutePathFromConfigPath(
- viper.GetString("tls_letsencrypt_cache_dir"),
- ),
- TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
-
- TLSCertPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_cert_path")),
- TLSKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_key_path")),
- TLSClientAuthMode: tlsClientAuthMode,
-
- DNSConfig: dnsConfig,
-
- ACMEEmail: viper.GetString("acme_email"),
- ACMEURL: viper.GetString("acme_url"),
-
- UnixSocket: viper.GetString("unix_socket"),
- UnixSocketPermission: GetFileMode("unix_socket_permission"),
-
- OIDC: headscale.OIDCConfig{
- Issuer: viper.GetString("oidc.issuer"),
- ClientID: viper.GetString("oidc.client_id"),
- ClientSecret: viper.GetString("oidc.client_secret"),
- Scope: viper.GetStringSlice("oidc.scope"),
- ExtraParams: viper.GetStringMapString("oidc.extra_params"),
- AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
- AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
- StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
- },
-
- LogTail: logConfig,
-
- CLI: headscale.CLIConfig{
- Address: viper.GetString("cli.address"),
- APIKey: viper.GetString("cli.api_key"),
- Timeout: viper.GetDuration("cli.timeout"),
- Insecure: viper.GetBool("cli.insecure"),
- },
-
- ACL: GetACLConfig(),
- }
-}
-
func getHeadscaleApp() (*headscale.Headscale, error) {
+ cfg, err := headscale.GetHeadscaleConfig()
+ if err != nil {
+ return nil, fmt.Errorf("failed to load configuration while creating headscale instance: %w", err)
+ }
+
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
// to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s")
@@ -412,8 +44,6 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
return nil, err
}
- cfg := GetHeadscaleConfig()
-
app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
@@ -436,7 +66,13 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
- cfg := GetHeadscaleConfig()
+ cfg, err := headscale.GetHeadscaleConfig()
+ if err != nil {
+ log.Fatal().
+ Err(err).
+ Caller().
+ Msgf("Failed to load configuration")
+ }
log.Debug().
Dur("timeout", cfg.CLI.Timeout).
@@ -570,17 +206,6 @@ func (tokenAuth) RequireTransportSecurity() bool {
return true
}
-func GetFileMode(key string) fs.FileMode {
- modeStr := viper.GetString(key)
-
- mode, err := strconv.ParseUint(modeStr, headscale.Base8, headscale.BitSize64)
- if err != nil {
- return PermissionFallback
- }
-
- return fs.FileMode(mode)
-}
-
func contains[T string](ts []T, t T) bool {
for _, v := range ts {
if reflect.DeepEqual(v, t) {
diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go
index 600b186e..f5e28661 100644
--- a/cmd/headscale/headscale.go
+++ b/cmd/headscale/headscale.go
@@ -7,10 +7,10 @@ import (
"time"
"github.com/efekarakus/termcolor"
+ "github.com/juanfont/headscale"
"github.com/juanfont/headscale/cmd/headscale/cli"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
- "github.com/spf13/viper"
"github.com/tcnksm/go-latest"
)
@@ -43,19 +43,14 @@ func main() {
NoColor: !colors,
})
- if err := cli.LoadConfig(""); err != nil {
+ cfg, err := headscale.GetHeadscaleConfig()
+ if err != nil {
log.Fatal().Caller().Err(err)
}
machineOutput := cli.HasMachineOutputFlag()
- logLevel := viper.GetString("log_level")
- level, err := zerolog.ParseLevel(logLevel)
- if err != nil {
- zerolog.SetGlobalLevel(zerolog.DebugLevel)
- } else {
- zerolog.SetGlobalLevel(level)
- }
+ zerolog.SetGlobalLevel(cfg.LogLevel)
// If the user has requested a "machine" readable format,
// then disable login so the output remains valid.
@@ -63,7 +58,7 @@ func main() {
zerolog.SetGlobalLevel(zerolog.Disabled)
}
- if !viper.GetBool("disable_check_updates") && !machineOutput {
+ if !cfg.DisableUpdateCheck && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
cli.Version != "dev" {
githubTag := &latest.GithubTag{
diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go
index faf55f4c..9ca4a2c3 100644
--- a/cmd/headscale/headscale_test.go
+++ b/cmd/headscale/headscale_test.go
@@ -8,7 +8,7 @@ import (
"strings"
"testing"
- "github.com/juanfont/headscale/cmd/headscale/cli"
+ "github.com/juanfont/headscale"
"github.com/spf13/viper"
"gopkg.in/check.v1"
)
@@ -49,7 +49,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
}
// Load example config, it should load without validation errors
- err = cli.LoadConfig(tmpDir)
+ err = headscale.LoadConfig(tmpDir)
c.Assert(err, check.IsNil)
// Test that config file was interpreted correctly
@@ -63,7 +63,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
c.Assert(
- cli.GetFileMode("unix_socket_permission"),
+ headscale.GetFileMode("unix_socket_permission"),
check.Equals,
fs.FileMode(0o770),
)
@@ -92,10 +92,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
}
// Load example config, it should load without validation errors
- err = cli.LoadConfig(tmpDir)
+ err = headscale.LoadConfig(tmpDir)
c.Assert(err, check.IsNil)
- dnsConfig, baseDomain := cli.GetDNSConfig()
+ dnsConfig, baseDomain := headscale.GetDNSConfig()
c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
@@ -125,7 +125,7 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
writeConfig(c, tmpDir, configYaml)
// Check configuration validation errors (1)
- err = cli.LoadConfig(tmpDir)
+ err = headscale.LoadConfig(tmpDir)
c.Assert(err, check.NotNil)
// check.Matches can not handle multiline strings
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
@@ -150,6 +150,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
"---\nserver_url: \"http://127.0.0.1:8080\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"",
)
writeConfig(c, tmpDir, configYaml)
- err = cli.LoadConfig(tmpDir)
+ err = headscale.LoadConfig(tmpDir)
c.Assert(err, check.IsNil)
}
diff --git a/config.go b/config.go
new file mode 100644
index 00000000..909a48c4
--- /dev/null
+++ b/config.go
@@ -0,0 +1,504 @@
+package headscale
+
+import (
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io/fs"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/coreos/go-oidc/v3/oidc"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/viper"
+ "inet.af/netaddr"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/dnstype"
+)
+
+// Config contains the initial Headscale configuration.
+type Config struct {
+ ServerURL string
+ Addr string
+ MetricsAddr string
+ GRPCAddr string
+ GRPCAllowInsecure bool
+ EphemeralNodeInactivityTimeout time.Duration
+ IPPrefixes []netaddr.IPPrefix
+ PrivateKeyPath string
+ BaseDomain string
+ LogLevel zerolog.Level
+ DisableUpdateCheck bool
+
+ DERP DERPConfig
+
+ DBtype string
+ DBpath string
+ DBhost string
+ DBport int
+ DBname string
+ DBuser string
+ DBpass string
+
+ TLS TLSConfig
+
+ ACMEURL string
+ ACMEEmail string
+
+ DNSConfig *tailcfg.DNSConfig
+
+ UnixSocket string
+ UnixSocketPermission fs.FileMode
+
+ OIDC OIDCConfig
+
+ LogTail LogTailConfig
+
+ CLI CLIConfig
+
+ ACL ACLConfig
+}
+
+type TLSConfig struct {
+ CertPath string
+ KeyPath string
+ ClientAuthMode tls.ClientAuthType
+
+ LetsEncrypt LetsEncryptConfig
+}
+
+type LetsEncryptConfig struct {
+ Listen string
+ Hostname string
+ CacheDir string
+ ChallengeType string
+}
+
+type OIDCConfig struct {
+ Issuer string
+ ClientID string
+ ClientSecret string
+ Scope []string
+ ExtraParams map[string]string
+ AllowedDomains []string
+ AllowedUsers []string
+ StripEmaildomain bool
+}
+
+type DERPConfig struct {
+ ServerEnabled bool
+ ServerRegionID int
+ ServerRegionCode string
+ ServerRegionName string
+ STUNAddr string
+ URLs []url.URL
+ Paths []string
+ AutoUpdate bool
+ UpdateFrequency time.Duration
+}
+
+type LogTailConfig struct {
+ Enabled bool
+}
+
+type CLIConfig struct {
+ Address string
+ APIKey string
+ Timeout time.Duration
+ Insecure bool
+}
+
+type ACLConfig struct {
+ PolicyPath string
+}
+
+func LoadConfig(path string) error {
+ viper.SetConfigName("config")
+ if path == "" {
+ viper.AddConfigPath("/etc/headscale/")
+ viper.AddConfigPath("$HOME/.headscale")
+ viper.AddConfigPath(".")
+ } else {
+ // For testing
+ viper.AddConfigPath(path)
+ }
+
+ viper.SetEnvPrefix("headscale")
+ viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
+ viper.AutomaticEnv()
+
+ viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
+ viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
+ viper.SetDefault("tls_client_auth_mode", "relaxed")
+
+ viper.SetDefault("log_level", "info")
+
+ viper.SetDefault("dns_config", nil)
+
+ viper.SetDefault("derp.server.enabled", false)
+ viper.SetDefault("derp.server.stun.enabled", true)
+
+ viper.SetDefault("unix_socket", "/var/run/headscale.sock")
+ viper.SetDefault("unix_socket_permission", "0o770")
+
+ viper.SetDefault("grpc_listen_addr", ":50443")
+ viper.SetDefault("grpc_allow_insecure", false)
+
+ viper.SetDefault("cli.timeout", "5s")
+ viper.SetDefault("cli.insecure", false)
+
+ viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
+ viper.SetDefault("oidc.strip_email_domain", true)
+
+ viper.SetDefault("logtail.enabled", false)
+
+ if err := viper.ReadInConfig(); err != nil {
+ return fmt.Errorf("fatal error reading config file: %w", err)
+ }
+
+ // Collect any validation errors and return them all at once
+ var errorText string
+ if (viper.GetString("tls_letsencrypt_hostname") != "") &&
+ ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
+ errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
+ }
+
+ if (viper.GetString("tls_letsencrypt_hostname") != "") &&
+ (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") &&
+ (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
+ // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
+ log.Warn().
+ Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
+ }
+
+ if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") &&
+ (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
+ errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
+ }
+
+ if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
+ !strings.HasPrefix(viper.GetString("server_url"), "https://") {
+ errorText += "Fatal config error: server_url must start with https:// or http://\n"
+ }
+
+ _, authModeValid := LookupTLSClientAuthMode(
+ viper.GetString("tls_client_auth_mode"),
+ )
+
+ if !authModeValid {
+ errorText += fmt.Sprintf(
+ "Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.",
+ viper.GetString("tls_client_auth_mode"),
+ DisabledClientAuth,
+ RelaxedClientAuth,
+ EnforcedClientAuth)
+ }
+
+ if errorText != "" {
+ //nolint
+ return errors.New(strings.TrimSuffix(errorText, "\n"))
+ } else {
+ return nil
+ }
+}
+
+func GetTLSConfig() TLSConfig {
+ tlsClientAuthMode, _ := LookupTLSClientAuthMode(
+ viper.GetString("tls_client_auth_mode"),
+ )
+
+ return TLSConfig{
+ LetsEncrypt: LetsEncryptConfig{
+ Hostname: viper.GetString("tls_letsencrypt_hostname"),
+ Listen: viper.GetString("tls_letsencrypt_listen"),
+ CacheDir: AbsolutePathFromConfigPath(
+ viper.GetString("tls_letsencrypt_cache_dir"),
+ ),
+ ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
+ },
+ CertPath: AbsolutePathFromConfigPath(
+ viper.GetString("tls_cert_path"),
+ ),
+ KeyPath: AbsolutePathFromConfigPath(
+ viper.GetString("tls_key_path"),
+ ),
+ ClientAuthMode: tlsClientAuthMode,
+ }
+}
+
+func GetDERPConfig() DERPConfig {
+ serverEnabled := viper.GetBool("derp.server.enabled")
+ serverRegionID := viper.GetInt("derp.server.region_id")
+ serverRegionCode := viper.GetString("derp.server.region_code")
+ serverRegionName := viper.GetString("derp.server.region_name")
+ stunAddr := viper.GetString("derp.server.stun_listen_addr")
+
+ if serverEnabled && stunAddr == "" {
+ log.Fatal().
+ Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
+ }
+
+ urlStrs := viper.GetStringSlice("derp.urls")
+
+ urls := make([]url.URL, len(urlStrs))
+ for index, urlStr := range urlStrs {
+ urlAddr, err := url.Parse(urlStr)
+ if err != nil {
+ log.Error().
+ Str("url", urlStr).
+ Err(err).
+ Msg("Failed to parse url, ignoring...")
+ }
+
+ urls[index] = *urlAddr
+ }
+
+ paths := viper.GetStringSlice("derp.paths")
+
+ autoUpdate := viper.GetBool("derp.auto_update_enabled")
+ updateFrequency := viper.GetDuration("derp.update_frequency")
+
+ return DERPConfig{
+ ServerEnabled: serverEnabled,
+ ServerRegionID: serverRegionID,
+ ServerRegionCode: serverRegionCode,
+ ServerRegionName: serverRegionName,
+ STUNAddr: stunAddr,
+ URLs: urls,
+ Paths: paths,
+ AutoUpdate: autoUpdate,
+ UpdateFrequency: updateFrequency,
+ }
+}
+
+func GetLogTailConfig() LogTailConfig {
+ enabled := viper.GetBool("logtail.enabled")
+
+ return LogTailConfig{
+ Enabled: enabled,
+ }
+}
+
+func GetACLConfig() ACLConfig {
+ policyPath := viper.GetString("acl_policy_path")
+
+ return ACLConfig{
+ PolicyPath: policyPath,
+ }
+}
+
+func GetDNSConfig() (*tailcfg.DNSConfig, string) {
+ if viper.IsSet("dns_config") {
+ dnsConfig := &tailcfg.DNSConfig{}
+
+ if viper.IsSet("dns_config.nameservers") {
+ nameserversStr := viper.GetStringSlice("dns_config.nameservers")
+
+ nameservers := make([]netaddr.IP, len(nameserversStr))
+ resolvers := make([]dnstype.Resolver, len(nameserversStr))
+
+ for index, nameserverStr := range nameserversStr {
+ nameserver, err := netaddr.ParseIP(nameserverStr)
+ if err != nil {
+ log.Error().
+ Str("func", "getDNSConfig").
+ Err(err).
+ Msgf("Could not parse nameserver IP: %s", nameserverStr)
+ }
+
+ nameservers[index] = nameserver
+ resolvers[index] = dnstype.Resolver{
+ Addr: nameserver.String(),
+ }
+ }
+
+ dnsConfig.Nameservers = nameservers
+ dnsConfig.Resolvers = resolvers
+ }
+
+ if viper.IsSet("dns_config.restricted_nameservers") {
+ if len(dnsConfig.Nameservers) > 0 {
+ dnsConfig.Routes = make(map[string][]dnstype.Resolver)
+ restrictedDNS := viper.GetStringMapStringSlice(
+ "dns_config.restricted_nameservers",
+ )
+ for domain, restrictedNameservers := range restrictedDNS {
+ restrictedResolvers := make(
+ []dnstype.Resolver,
+ len(restrictedNameservers),
+ )
+ for index, nameserverStr := range restrictedNameservers {
+ nameserver, err := netaddr.ParseIP(nameserverStr)
+ if err != nil {
+ log.Error().
+ Str("func", "getDNSConfig").
+ Err(err).
+ Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
+ }
+ restrictedResolvers[index] = dnstype.Resolver{
+ Addr: nameserver.String(),
+ }
+ }
+ dnsConfig.Routes[domain] = restrictedResolvers
+ }
+ } else {
+ log.Warn().
+ Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.")
+ }
+ }
+
+ if viper.IsSet("dns_config.domains") {
+ dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
+ }
+
+ if viper.IsSet("dns_config.magic_dns") {
+ magicDNS := viper.GetBool("dns_config.magic_dns")
+ if len(dnsConfig.Nameservers) > 0 {
+ dnsConfig.Proxied = magicDNS
+ } else if magicDNS {
+ log.Warn().
+ Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
+ }
+ }
+
+ var baseDomain string
+ if viper.IsSet("dns_config.base_domain") {
+ baseDomain = viper.GetString("dns_config.base_domain")
+ } else {
+ baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
+ }
+
+ return dnsConfig, baseDomain
+ }
+
+ return nil, ""
+}
+
+func GetHeadscaleConfig() (*Config, error) {
+ err := LoadConfig("")
+ if err != nil {
+ return nil, err
+ }
+
+ dnsConfig, baseDomain := GetDNSConfig()
+ derpConfig := GetDERPConfig()
+ logConfig := GetLogTailConfig()
+
+ configuredPrefixes := viper.GetStringSlice("ip_prefixes")
+ parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1)
+
+ logLevelStr := viper.GetString("log_level")
+ logLevel, err := zerolog.ParseLevel(logLevelStr)
+ if err != nil {
+ logLevel = zerolog.DebugLevel
+ }
+
+ legacyPrefixField := viper.GetString("ip_prefix")
+ if len(legacyPrefixField) > 0 {
+ log.
+ Warn().
+ Msgf(
+ "%s, %s",
+ "use of 'ip_prefix' for configuration is deprecated",
+ "please see 'ip_prefixes' in the shipped example.",
+ )
+ legacyPrefix, err := netaddr.ParseIPPrefix(legacyPrefixField)
+ if err != nil {
+ panic(fmt.Errorf("failed to parse ip_prefix: %w", err))
+ }
+ parsedPrefixes = append(parsedPrefixes, legacyPrefix)
+ }
+
+ for i, prefixInConfig := range configuredPrefixes {
+ prefix, err := netaddr.ParseIPPrefix(prefixInConfig)
+ if err != nil {
+ panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
+ }
+ parsedPrefixes = append(parsedPrefixes, prefix)
+ }
+
+ prefixes := make([]netaddr.IPPrefix, 0, len(parsedPrefixes))
+ {
+ // dedup
+ normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
+ for i, p := range parsedPrefixes {
+ normalized, _ := p.Range().Prefix()
+ normalizedPrefixes[normalized.String()] = i
+ }
+
+ // convert back to list
+ for _, i := range normalizedPrefixes {
+ prefixes = append(prefixes, parsedPrefixes[i])
+ }
+ }
+
+ if len(prefixes) < 1 {
+ prefixes = append(prefixes, netaddr.MustParseIPPrefix("100.64.0.0/10"))
+ log.Warn().
+ Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
+ }
+
+ return &Config{
+ ServerURL: viper.GetString("server_url"),
+ Addr: viper.GetString("listen_addr"),
+ MetricsAddr: viper.GetString("metrics_listen_addr"),
+ GRPCAddr: viper.GetString("grpc_listen_addr"),
+ GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
+ DisableUpdateCheck: viper.GetBool("disable_check_updates"),
+ LogLevel: logLevel,
+
+ IPPrefixes: prefixes,
+ PrivateKeyPath: AbsolutePathFromConfigPath(
+ viper.GetString("private_key_path"),
+ ),
+ BaseDomain: baseDomain,
+
+ DERP: derpConfig,
+
+ EphemeralNodeInactivityTimeout: viper.GetDuration(
+ "ephemeral_node_inactivity_timeout",
+ ),
+
+ DBtype: viper.GetString("db_type"),
+ DBpath: AbsolutePathFromConfigPath(viper.GetString("db_path")),
+ DBhost: viper.GetString("db_host"),
+ DBport: viper.GetInt("db_port"),
+ DBname: viper.GetString("db_name"),
+ DBuser: viper.GetString("db_user"),
+ DBpass: viper.GetString("db_pass"),
+
+ TLS: GetTLSConfig(),
+
+ DNSConfig: dnsConfig,
+
+ ACMEEmail: viper.GetString("acme_email"),
+ ACMEURL: viper.GetString("acme_url"),
+
+ UnixSocket: viper.GetString("unix_socket"),
+ UnixSocketPermission: GetFileMode("unix_socket_permission"),
+
+ OIDC: OIDCConfig{
+ Issuer: viper.GetString("oidc.issuer"),
+ ClientID: viper.GetString("oidc.client_id"),
+ ClientSecret: viper.GetString("oidc.client_secret"),
+ Scope: viper.GetStringSlice("oidc.scope"),
+ ExtraParams: viper.GetStringMapString("oidc.extra_params"),
+ AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
+ AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
+ StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
+ },
+
+ LogTail: logConfig,
+
+ CLI: CLIConfig{
+ Address: viper.GetString("cli.address"),
+ APIKey: viper.GetString("cli.api_key"),
+ Timeout: viper.GetDuration("cli.timeout"),
+ Insecure: viper.GetBool("cli.insecure"),
+ },
+
+ ACL: GetACLConfig(),
+ }, nil
+}
diff --git a/machine_test.go b/machine_test.go
index 0fb3ed78..bde96057 100644
--- a/machine_test.go
+++ b/machine_test.go
@@ -821,7 +821,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
{
name: "simple machine name generation",
h: &Headscale{
- cfg: Config{
+ cfg: &Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
@@ -836,7 +836,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
{
name: "machine name with 53 chars",
h: &Headscale{
- cfg: Config{
+ cfg: &Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
@@ -851,7 +851,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
{
name: "machine name with 60 chars",
h: &Headscale{
- cfg: Config{
+ cfg: &Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
@@ -866,7 +866,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
{
name: "machine name with 63 chars",
h: &Headscale{
- cfg: Config{
+ cfg: &Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
@@ -881,7 +881,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
{
name: "machine name with 64 chars",
h: &Headscale{
- cfg: Config{
+ cfg: &Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
@@ -896,7 +896,7 @@ func TestHeadscale_GenerateGivenName(t *testing.T) {
{
name: "machine name with 73 chars",
h: &Headscale{
- cfg: Config{
+ cfg: &Config{
OIDC: OIDCConfig{
StripEmaildomain: true,
},
diff --git a/utils.go b/utils.go
index 6dddf4c5..8d9dec5b 100644
--- a/utils.go
+++ b/utils.go
@@ -11,10 +11,12 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
+ "io/fs"
"net"
"os"
"path/filepath"
"reflect"
+ "strconv"
"strings"
"github.com/rs/zerolog/log"
@@ -55,6 +57,8 @@ const (
// privateKey prefix.
privateHexPrefix = "privkey:"
+
+ PermissionFallback = 0o700
)
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
@@ -350,3 +354,14 @@ func AbsolutePathFromConfigPath(path string) string {
return path
}
+
+func GetFileMode(key string) fs.FileMode {
+ modeStr := viper.GetString(key)
+
+ mode, err := strconv.ParseUint(modeStr, Base8, BitSize64)
+ if err != nil {
+ return PermissionFallback
+ }
+
+ return fs.FileMode(mode)
+}
|