From 94b30abf56ae09d82a1541bbc3d19557914f9b27 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 9 Feb 2024 07:27:00 +0100 Subject: [PATCH] Restructure database config (#1700) --- CHANGELOG.md | 22 +++++---- cmd/headscale/headscale_test.go | 6 ++- config-example.yaml | 31 ++++++------ hscontrol/app.go | 47 +++--------------- hscontrol/db/db.go | 56 ++++++++++++++------- hscontrol/db/routes_test.go | 11 +++-- hscontrol/db/suite_test.go | 12 +++-- hscontrol/suite_test.go | 8 ++- hscontrol/types/common.go | 10 +++- hscontrol/types/config.go | 87 ++++++++++++++++++++++++++------- integration/hsic/config.go | 4 +- 11 files changed, 180 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7908eaa..3adb23be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,16 +34,18 @@ after improving the test harness as part of adopting [#1460](https://github.com/ ### Changes -Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) -Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) -SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) -State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) -Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) -Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) -Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) -Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) -Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) -Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565) +- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) +- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) +- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) +- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) +- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) +- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) +- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) +- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) +- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) +- Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565) +- Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700) + - Old structure is now considered deprecated and will be removed in the future. ## 0.22.3 (2023-05-12) diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 897e2537..d73d30b5 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -58,8 +58,10 @@ func (*Suite) TestConfigFileLoading(c *check.C) { c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") + c.Assert(viper.GetString("database.type"), check.Equals, "sqlite") + c.Assert(viper.GetString("database.sqlite.path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") @@ -101,7 +103,7 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") + c.Assert(viper.GetString("db_type"), check.Equals, "sqlite") c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") diff --git a/config-example.yaml b/config-example.yaml index 96a654a4..8e4373fc 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -138,24 +138,25 @@ ephemeral_node_inactivity_timeout: 30m # In case of doubts, do not touch the default 10s. node_update_check_interval: 10s -# SQLite config -db_type: sqlite3 +database: + type: sqlite -# For production: -db_path: /var/lib/headscale/db.sqlite + # SQLite config + sqlite: + path: /var/lib/headscale/db.sqlite -# # Postgres config -# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. -# db_type: postgres -# db_host: localhost -# db_port: 5432 -# db_name: headscale -# db_user: foo -# db_pass: bar + # # Postgres config + # postgres: + # # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. + # host: localhost + # port: 5432 + # name: headscale + # user: foo + # pass: bar -# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need -# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. -# db_ssl: false + # # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need + # # in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. + # ssl: false ### TLS configuration # diff --git a/hscontrol/app.go b/hscontrol/app.go index 91d53263..78b72bf5 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -12,7 +12,6 @@ import ( "os" "os/signal" "runtime" - "strconv" "strings" "sync" "syscall" @@ -118,37 +117,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.DBtype { - case db.Postgres: - dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.DBhost, - cfg.DBname, - cfg.DBuser, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl) - } - - if cfg.DBport != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.DBport) - } - - if cfg.DBpass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.DBpass) - } - case db.Sqlite: - dbString = cfg.DBpath - default: - return nil, errUnsupportedDatabase - } - registrationCache := cache.New( registerCacheExpiration, registerCacheCleanup, @@ -156,8 +124,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app := Headscale{ cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, @@ -165,9 +131,8 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } database, err := db.NewHeadscaleDatabase( - cfg.DBtype, - dbString, - app.dbDebug, + cfg.Database, + app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { @@ -755,14 +720,16 @@ func (h *Headscale) Serve() error { var tailsqlContext context.Context if tailsqlEnabled { - if h.cfg.DBtype != db.Sqlite { - log.Fatal().Str("type", h.cfg.DBtype).Msgf("tailsql only support %q", db.Sqlite) + if h.cfg.Database.Type != types.DatabaseSqlite { + log.Fatal(). + Str("type", h.cfg.Database.Type). + Msgf("tailsql only support %q", types.DatabaseSqlite) } if tailsqlTSKey == "" { log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") } tailsqlContext = context.Background() - go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.DBpath) + go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) } // Handle common process-killing signals so we can gracefully shut down: diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index df7b0a4c..fe77dda8 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -6,11 +6,13 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -19,11 +21,6 @@ import ( "gorm.io/gorm/logger" ) -const ( - Postgres = "postgres" - Sqlite = "sqlite3" -) - var errDatabaseNotSupported = errors.New("database type not supported") // KV is a key-value store in a psql table. For future use... @@ -43,12 +40,12 @@ type HSDatabase struct { // TODO(kradalby): assemble this struct from toptions or something typed // rather than arguments. func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, + cfg types.DatabaseConfig, + notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) + dbConn, err := openDB(cfg) if err != nil { return nil, err } @@ -62,7 +59,7 @@ func NewHeadscaleDatabase( { ID: "202312101416", Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { + if cfg.Type == types.DatabasePostgres { tx.Exec(`create extension if not exists "uuid-ossp";`) } @@ -321,20 +318,20 @@ func NewHeadscaleDatabase( return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") +func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { + // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface - if debug { + if cfg.Debug { dbLogger = logger.Default } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { - case Sqlite: + switch cfg.Type { + case types.DatabaseSqlite: db, err := gorm.Open( - sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), + sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, @@ -353,8 +350,31 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return db, err - case Postgres: - return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ + case types.DatabasePostgres: + dbString := fmt.Sprintf( + "host=%s dbname=%s user=%s", + cfg.Postgres.Host, + cfg.Postgres.Name, + cfg.Postgres.User, + ) + + if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { + if !sslEnabled { + dbString += " sslmode=disable" + } + } else { + dbString += fmt.Sprintf(" sslmode=%s", cfg.Postgres.Ssl) + } + + if cfg.Postgres.Port != 0 { + dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port) + } + + if cfg.Postgres.Pass != "" { + dbString += fmt.Sprintf(" password=%s", cfg.Postgres.Pass) + } + + return gorm.Open(postgres.Open(dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger, }) @@ -362,7 +382,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + cfg.Type, errDatabaseNotSupported, ) } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 3b544aa7..5d6281e8 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" @@ -654,9 +655,13 @@ func TestFailoverRoute(t *testing.T) { assert.NoError(t, err) db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index d4b11b14..e176e4b2 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,6 +6,8 @@ import ( "os" "testing" + "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" ) @@ -44,9 +46,13 @@ func (s *Suite) ResetDB(c *check.C) { log.Printf("database path: %s", tmpDir+"/headscale_test.db") db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/suite_test.go b/hscontrol/suite_test.go index 82bdc797..3f0cc428 100644 --- a/hscontrol/suite_test.go +++ b/hscontrol/suite_test.go @@ -41,8 +41,12 @@ func (s *Suite) ResetDB(c *check.C) { } cfg := types.Config{ NoisePrivateKeyPath: tmpDir + "/noise_private.key", - DBtype: "sqlite3", - DBpath: tmpDir + "/headscale_test.db", + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, IPPrefixes: []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d45f9d4c..ceeceea0 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -12,7 +12,11 @@ import ( "tailscale.com/tailcfg" ) -const SelfUpdateIdentifier = "self-update" +const ( + SelfUpdateIdentifier = "self-update" + DatabasePostgres = "postgres" + DatabaseSqlite = "sqlite3" +) var ErrCannotParsePrefix = errors.New("cannot parse prefix") @@ -154,7 +158,9 @@ func (su *StateUpdate) Valid() bool { } case StateSelfUpdate: if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 { - panic("Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node") + panic( + "Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node", + ) } case StateDERPUpdated: if su.DERPMap == nil { diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index d9d58301..d83b21f7 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -46,16 +46,9 @@ type Config struct { Log LogConfig DisableUpdateCheck bool - DERP DERPConfig + Database DatabaseConfig - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - DBssl string + DERP DERPConfig TLS TLSConfig @@ -77,6 +70,28 @@ type Config struct { ACL ACLConfig } +type SqliteConfig struct { + Path string +} + +type PostgresConfig struct { + Host string + Port int + Name string + User string + Pass string + Ssl string +} + +type DatabaseConfig struct { + // Type sets the database type, either "sqlite3" or "postgres" + Type string + Debug bool + + Sqlite SqliteConfig + Postgres PostgresConfig +} + type TLSConfig struct { CertPath string KeyPath string @@ -161,6 +176,19 @@ func LoadConfig(path string, isFile bool) error { viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() + viper.RegisterAlias("db_type", "database.type") + + // SQLite aliases + viper.RegisterAlias("db_path", "database.sqlite.path") + + // Postgres aliases + viper.RegisterAlias("db_host", "database.postgres.host") + viper.RegisterAlias("db_port", "database.postgres.port") + viper.RegisterAlias("db_name", "database.postgres.name") + viper.RegisterAlias("db_user", "database.postgres.user") + viper.RegisterAlias("db_pass", "database.postgres.pass") + viper.RegisterAlias("db_ssl", "database.postgres.ssl") + viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType) @@ -184,6 +212,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("cli.insecure", false) viper.SetDefault("db_ssl", false) + viper.SetDefault("database.postgres.ssl", false) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.strip_email_domain", true) @@ -389,6 +418,37 @@ func GetLogConfig() LogConfig { } } +func GetDatabaseConfig() DatabaseConfig { + debug := viper.GetBool("database.debug") + + type_ := viper.GetString("database.type") + + switch type_ { + case DatabaseSqlite, DatabasePostgres: + break + case "sqlite": + type_ = "sqlite3" + default: + log.Fatal().Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_) + } + + return DatabaseConfig{ + Type: type_, + Debug: debug, + Sqlite: SqliteConfig{ + Path: util.AbsolutePathFromConfigPath(viper.GetString("database.sqlite.path")), + }, + Postgres: PostgresConfig{ + Host: viper.GetString("database.postgres.host"), + Port: viper.GetInt("database.postgres.port"), + Name: viper.GetString("database.postgres.name"), + User: viper.GetString("database.postgres.user"), + Pass: viper.GetString("database.postgres.pass"), + Ssl: viper.GetString("database.postgres.ssl"), + }, + } +} + func GetDNSConfig() (*tailcfg.DNSConfig, string) { if viper.IsSet("dns_config") { dnsConfig := &tailcfg.DNSConfig{} @@ -617,14 +677,7 @@ func GetHeadscaleConfig() (*Config, error) { "node_update_check_interval", ), - DBtype: viper.GetString("db_type"), - DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - DBssl: viper.GetString("db_ssl"), + Database: GetDatabaseConfig(), TLS: GetTLSConfig(), diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 00c1770c..819b108f 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -110,8 +110,8 @@ func DefaultConfigEnv() map[string]string { return map[string]string{ "HEADSCALE_LOG_LEVEL": "trace", "HEADSCALE_ACL_POLICY_PATH": "", - "HEADSCALE_DB_TYPE": "sqlite3", - "HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3", + "HEADSCALE_DATABASE_TYPE": "sqlite", + "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s", "HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10",