diff --git a/.gitignore b/.gitignore index f6e506bc..1662d7f2 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ dist/ /headscale config.json config.yaml +config*.yaml derp.yaml *.hujson *.key diff --git a/CHANGELOG.md b/CHANGELOG.md index 76982608..91aed9ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,8 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792) - Fix for registering nodes using preauthkeys when running on a postgres database in a non-UTC timezone. [#764](https://github.com/juanfont/headscale/issues/764) - Make sure integration tests cover postgres for all scenarios +- CLI commands (all except `serve`) only requires minimal configuration, no more errors or warnings from unset settings [#2109](https://github.com/juanfont/headscale/pull/2109) +- CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109) ## 0.22.3 (2023-05-12) diff --git a/cmd/headscale/cli/api_key.go b/cmd/headscale/cli/api_key.go index 372ec390..bd839b7b 100644 --- a/cmd/headscale/cli/api_key.go +++ b/cmd/headscale/cli/api_key.go @@ -54,7 +54,7 @@ var listAPIKeys = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -67,14 +67,10 @@ var listAPIKeys = &cobra.Command{ fmt.Sprintf("Error getting the list of keys: %s", err), output, ) - - return } if output != "" { SuccessOutput(response.GetApiKeys(), "", output) - - return } tableData := pterm.TableData{ @@ -102,8 +98,6 @@ var listAPIKeys = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } @@ -119,9 +113,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - log.Trace(). - Msg("Preparing to create ApiKey") - request := &v1.CreateApiKeyRequest{} durationStr, _ := cmd.Flags().GetString("expiration") @@ -133,19 +124,13 @@ If you loose a key, create a new one and revoke (expire) the old one.`, fmt.Sprintf("Could not parse duration: %s\n", err), output, ) - - return } expiration := time.Now().UTC().Add(time.Duration(duration)) - log.Trace(). - Dur("expiration", time.Duration(duration)). - Msg("expiration has been set") - request.Expiration = timestamppb.New(expiration) - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -156,8 +141,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`, fmt.Sprintf("Cannot create Api Key: %s\n", err), output, ) - - return } SuccessOutput(response.GetApiKey(), response.GetApiKey(), output) @@ -178,11 +161,9 @@ var expireAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Error getting prefix from CLI flag: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -197,8 +178,6 @@ var expireAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot expire Api Key: %s\n", err), output, ) - - return } SuccessOutput(response, "Key expired", output) @@ -219,11 +198,9 @@ var deleteAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Error getting prefix from CLI flag: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -238,8 +215,6 @@ var deleteAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot delete Api Key: %s\n", err), output, ) - - return } SuccessOutput(response, "Key deleted", output) diff --git a/cmd/headscale/cli/configtest.go b/cmd/headscale/cli/configtest.go index 72744a7b..d469885b 100644 --- a/cmd/headscale/cli/configtest.go +++ b/cmd/headscale/cli/configtest.go @@ -14,7 +14,7 @@ var configTestCmd = &cobra.Command{ Short: "Test the configuration.", Long: "Run a test of the configuration and exit.", Run: func(cmd *cobra.Command, args []string) { - _, err := getHeadscaleApp() + _, err := newHeadscaleServerWithConfig() if err != nil { log.Fatal().Caller().Err(err).Msg("Error initializing") } diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 054fc07f..72cde32d 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -64,11 +64,9 @@ var createNodeCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -79,8 +77,6 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting node from flag: %s", err), output, ) - - return } machineKey, err := cmd.Flags().GetString("key") @@ -90,8 +86,6 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting key from flag: %s", err), output, ) - - return } var mkey key.MachinePublic @@ -102,8 +96,6 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Failed to parse machine key from flag: %s", err), output, ) - - return } routes, err := cmd.Flags().GetStringSlice("route") @@ -113,8 +105,6 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting routes from flag: %s", err), output, ) - - return } request := &v1.DebugCreateNodeRequest{ @@ -131,8 +121,6 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()), output, ) - - return } SuccessOutput(response.GetNode(), "Node created", output) diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 4de7b969..b9e97a33 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -116,11 +116,9 @@ var registerNodeCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -131,8 +129,6 @@ var registerNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting node key from flag: %s", err), output, ) - - return } request := &v1.RegisterNodeRequest{ @@ -150,8 +146,6 @@ var registerNodeCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput( @@ -169,17 +163,13 @@ var listNodesCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } showTags, err := cmd.Flags().GetBool("tags") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -194,21 +184,15 @@ var listNodesCmd = &cobra.Command{ fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response.GetNodes(), "", output) - - return } tableData, err := nodesToPtables(user, showTags, response.GetNodes()) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - - return } err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() @@ -218,8 +202,6 @@ var listNodesCmd = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } @@ -243,7 +225,7 @@ var expireNodeCmd = &cobra.Command{ return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -286,7 +268,7 @@ var renameNodeCmd = &cobra.Command{ return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -335,7 +317,7 @@ var deleteNodeCmd = &cobra.Command{ return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -435,7 +417,7 @@ var moveNodeCmd = &cobra.Command{ return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -508,7 +490,7 @@ be assigned to nodes.`, return } if confirm { - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -681,7 +663,7 @@ var tagCmd = &cobra.Command{ Aliases: []string{"tags", "t"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 00c4566d..d1349b5a 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -1,6 +1,7 @@ package cli import ( + "fmt" "io" "os" @@ -30,7 +31,8 @@ var getPolicy = &cobra.Command{ Short: "Print the current ACL Policy", Aliases: []string{"show", "view", "fetch"}, Run: func(cmd *cobra.Command, args []string) { - ctx, client, conn, cancel := getHeadscaleCLIClient() + output, _ := cmd.Flags().GetString("output") + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -38,13 +40,13 @@ var getPolicy = &cobra.Command{ response, err := client.GetPolicy(ctx, request) if err != nil { - log.Fatal().Err(err).Msg("Failed to get the policy") - - return + ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output) } // TODO(pallabpain): Maybe print this better? - SuccessOutput("", response.GetPolicy(), "hujson") + // This does not pass output as we dont support yaml, json or json-line + // output for this command. It is HuJSON already. + SuccessOutput("", response.GetPolicy(), "") }, } @@ -56,33 +58,28 @@ var setPolicy = &cobra.Command{ This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`, Aliases: []string{"put", "update"}, Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") policyPath, _ := cmd.Flags().GetString("file") f, err := os.Open(policyPath) if err != nil { - log.Fatal().Err(err).Msg("Error opening the policy file") - - return + ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output) } defer f.Close() policyBytes, err := io.ReadAll(f) if err != nil { - log.Fatal().Err(err).Msg("Error reading the policy file") - - return + ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) } request := &v1.SetPolicyRequest{Policy: string(policyBytes)} - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() if _, err := client.SetPolicy(ctx, request); err != nil { - log.Fatal().Err(err).Msg("Failed to set ACL Policy") - - return + ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) } SuccessOutput(nil, "Policy updated.", "") diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index cc3b1b76..0074e029 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -60,11 +60,9 @@ var listPreAuthKeys = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -85,8 +83,6 @@ var listPreAuthKeys = &cobra.Command{ if output != "" { SuccessOutput(response.GetPreAuthKeys(), "", output) - - return } tableData := pterm.TableData{ @@ -134,8 +130,6 @@ var listPreAuthKeys = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } @@ -150,20 +144,12 @@ var createPreAuthKeyCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } reusable, _ := cmd.Flags().GetBool("reusable") ephemeral, _ := cmd.Flags().GetBool("ephemeral") tags, _ := cmd.Flags().GetStringSlice("tags") - log.Trace(). - Bool("reusable", reusable). - Bool("ephemeral", ephemeral). - Str("user", user). - Msg("Preparing to create preauthkey") - request := &v1.CreatePreAuthKeyRequest{ User: user, Reusable: reusable, @@ -180,8 +166,6 @@ var createPreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Could not parse duration: %s\n", err), output, ) - - return } expiration := time.Now().UTC().Add(time.Duration(duration)) @@ -192,7 +176,7 @@ var createPreAuthKeyCmd = &cobra.Command{ request.Expiration = timestamppb.New(expiration) - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -203,8 +187,6 @@ var createPreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output, ) - - return } SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output) @@ -227,11 +209,9 @@ var expirePreAuthKeyCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -247,8 +227,6 @@ var expirePreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output, ) - - return } SuccessOutput(response, "Key expired", output) diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index b0d9500e..7bac79ce 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -9,6 +9,7 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/tcnksm/go-latest" ) @@ -49,11 +50,6 @@ func initConfig() { } } - cfg, err := types.GetHeadscaleConfig() - if err != nil { - log.Fatal().Err(err).Msg("Failed to read headscale configuration") - } - machineOutput := HasMachineOutputFlag() // If the user has requested a "node" readable format, @@ -62,11 +58,13 @@ func initConfig() { zerolog.SetGlobalLevel(zerolog.Disabled) } - if cfg.Log.Format == types.JSONLogFormat { - log.Logger = log.Output(os.Stdout) - } + // logFormat := viper.GetString("log.format") + // if logFormat == types.JSONLogFormat { + // log.Logger = log.Output(os.Stdout) + // } - if !cfg.DisableUpdateCheck && !machineOutput { + disableUpdateCheck := viper.GetBool("disable_check_updates") + if !disableUpdateCheck && !machineOutput { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && Version != "dev" { githubTag := &latest.GithubTag{ diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 86ef295c..96227b31 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -64,11 +64,9 @@ var listRoutesCmd = &cobra.Command{ fmt.Sprintf("Error getting machine id from flag: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -82,14 +80,10 @@ var listRoutesCmd = &cobra.Command{ fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response.GetRoutes(), "", output) - - return } routes = response.GetRoutes() @@ -103,14 +97,10 @@ var listRoutesCmd = &cobra.Command{ fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response.GetRoutes(), "", output) - - return } routes = response.GetRoutes() @@ -119,8 +109,6 @@ var listRoutesCmd = &cobra.Command{ tableData := routesToPtables(routes) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - - return } err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() @@ -130,8 +118,6 @@ var listRoutesCmd = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } @@ -150,11 +136,9 @@ var enableRouteCmd = &cobra.Command{ fmt.Sprintf("Error getting machine id from flag: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -167,14 +151,10 @@ var enableRouteCmd = &cobra.Command{ fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response, "", output) - - return } }, } @@ -193,11 +173,9 @@ var disableRouteCmd = &cobra.Command{ fmt.Sprintf("Error getting machine id from flag: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -210,14 +188,10 @@ var disableRouteCmd = &cobra.Command{ fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response, "", output) - - return } }, } @@ -236,11 +210,9 @@ var deleteRouteCmd = &cobra.Command{ fmt.Sprintf("Error getting machine id from flag: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -253,14 +225,10 @@ var deleteRouteCmd = &cobra.Command{ fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response, "", output) - - return } }, } diff --git a/cmd/headscale/cli/server.go b/cmd/headscale/cli/serve.go similarity index 92% rename from cmd/headscale/cli/server.go rename to cmd/headscale/cli/serve.go index a1d19600..9f0fa35e 100644 --- a/cmd/headscale/cli/server.go +++ b/cmd/headscale/cli/serve.go @@ -16,7 +16,7 @@ var serveCmd = &cobra.Command{ return nil }, Run: func(cmd *cobra.Command, args []string) { - app, err := getHeadscaleApp() + app, err := newHeadscaleServerWithConfig() if err != nil { log.Fatal().Caller().Err(err).Msg("Error initializing") } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index e6463d6f..d04d7568 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -44,7 +44,7 @@ var createUserCmd = &cobra.Command{ userName := args[0] - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -63,8 +63,6 @@ var createUserCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput(response.GetUser(), "User created", output) @@ -91,7 +89,7 @@ var destroyUserCmd = &cobra.Command{ Name: userName, } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -102,8 +100,6 @@ var destroyUserCmd = &cobra.Command{ fmt.Sprintf("Error: %s", status.Convert(err).Message()), output, ) - - return } confirm := false @@ -134,8 +130,6 @@ var destroyUserCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput(response, "User destroyed", output) } else { @@ -151,7 +145,7 @@ var listUsersCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -164,14 +158,10 @@ var listUsersCmd = &cobra.Command{ fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()), output, ) - - return } if output != "" { SuccessOutput(response.GetUsers(), "", output) - - return } tableData := pterm.TableData{{"ID", "Name", "Created"}} @@ -192,8 +182,6 @@ var listUsersCmd = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } @@ -213,7 +201,7 @@ var renameUserCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -232,8 +220,6 @@ var renameUserCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput(response.GetUser(), "User renamed", output) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 409e3dc4..ff1137be 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -23,8 +23,8 @@ const ( SocketWritePermissions = 0o666 ) -func getHeadscaleApp() (*hscontrol.Headscale, error) { - cfg, err := types.GetHeadscaleConfig() +func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) { + cfg, err := types.LoadServerConfig() if err != nil { return nil, fmt.Errorf( "failed to load configuration while creating headscale instance: %w", @@ -40,8 +40,8 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) { return app, nil } -func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { - cfg, err := types.GetHeadscaleConfig() +func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { + cfg, err := types.LoadCLIConfig() if err != nil { log.Fatal(). Err(err). @@ -130,7 +130,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc. return ctx, client, conn, cancel } -func SuccessOutput(result interface{}, override string, outputFormat string) { +func output(result interface{}, override string, outputFormat string) string { var jsonBytes []byte var err error switch outputFormat { @@ -151,21 +151,26 @@ func SuccessOutput(result interface{}, override string, outputFormat string) { } default: // nolint - fmt.Println(override) - - return + return override } - // nolint - fmt.Println(string(jsonBytes)) + return string(jsonBytes) } +// SuccessOutput prints the result to stdout and exits with status code 0. +func SuccessOutput(result interface{}, override string, outputFormat string) { + fmt.Println(output(result, override, outputFormat)) + os.Exit(0) +} + +// ErrorOutput prints an error message to stderr and exits with status code 1. func ErrorOutput(errResult error, override string, outputFormat string) { type errOutput struct { Error string `json:"error"` } - SuccessOutput(errOutput{errResult.Error()}, override, outputFormat) + fmt.Fprintf(os.Stderr, "%s\n", output(errOutput{errResult.Error()}, override, outputFormat)) + os.Exit(1) } func HasMachineOutputFlag() bool { diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 580caf17..00c4a276 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -4,7 +4,6 @@ import ( "io/fs" "os" "path/filepath" - "strings" "testing" "github.com/juanfont/headscale/hscontrol/types" @@ -113,60 +112,3 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false) } - -func writeConfig(c *check.C, tmpDir string, configYaml []byte) { - // Populate a custom config file - configFile := filepath.Join(tmpDir, "config.yaml") - err := os.WriteFile(configFile, configYaml, 0o600) - if err != nil { - c.Fatalf("Couldn't write file %s", configFile) - } -} - -func (*Suite) TestTLSConfigValidation(c *check.C) { - tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - c.Fatal(err) - } - // defer os.RemoveAll(tmpDir) - configYaml := []byte(`--- -tls_letsencrypt_hostname: example.com -tls_letsencrypt_challenge_type: "" -tls_cert_path: abc.pem -noise: - private_key_path: noise_private.key`) - writeConfig(c, tmpDir, configYaml) - - // Check configuration validation errors (1) - err = types.LoadConfig(tmpDir, false) - c.Assert(err, check.NotNil) - // check.Matches can not handle multiline strings - tmp := strings.ReplaceAll(err.Error(), "\n", "***") - c.Assert( - tmp, - check.Matches, - ".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*", - ) - c.Assert( - tmp, - check.Matches, - ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*", - ) - c.Assert( - tmp, - check.Matches, - ".*Fatal config error: server_url must start with https:// or http://.*", - ) - - // Check configuration validation errors (2) - configYaml = []byte(`--- -noise: - private_key_path: noise_private.key -server_url: http://127.0.0.1:8080 -tls_letsencrypt_hostname: example.com -tls_letsencrypt_challenge_type: TLS-ALPN-01 -`) - writeConfig(c, tmpDir, configYaml) - err = types.LoadConfig(tmpDir, false) - c.Assert(err, check.IsNil) -} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 83048bec..3f985d98 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -684,7 +684,7 @@ func (api headscaleV1APIServer) GetPolicy( case types.PolicyModeDB: p, err := api.h.db.GetPolicy() if err != nil { - return nil, err + return nil, fmt.Errorf("loading ACL from database: %w", err) } return &v1.GetPolicyResponse{ @@ -696,20 +696,20 @@ func (api headscaleV1APIServer) GetPolicy( absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path) f, err := os.Open(absPath) if err != nil { - return nil, err + return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err) } defer f.Close() b, err := io.ReadAll(f) if err != nil { - return nil, err + return nil, fmt.Errorf("reading policy from file: %w", err) } return &v1.GetPolicyResponse{Policy: string(b)}, nil } - return nil, nil + return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode) } func (api headscaleV1APIServer) SetPolicy( diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 0b7d63b7..8767077e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -212,6 +212,12 @@ type Tuning struct { NodeMapSessionBufferedChanSize int } +// LoadConfig prepares and loads the Headscale configuration into Viper. +// This means it sets the default values, reads the configuration file and +// environment variables, and handles deprecated configuration options. +// It has to be called before LoadServerConfig and LoadCLIConfig. +// The configuration is not validated and the caller should check for errors +// using a validation function. func LoadConfig(path string, isFile bool) error { if isFile { viper.SetConfigFile(path) @@ -284,14 +290,14 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) - if IsCLIConfigured() { - return nil - } - if err := viper.ReadInConfig(); err != nil { return fmt.Errorf("fatal error reading config file: %w", err) } + return nil +} + +func validateServerConfig() error { depr := deprecator{ warns: make(set.Set[string]), fatals: make(set.Set[string]), @@ -360,12 +366,12 @@ func LoadConfig(path string, isFile bool) error { if errorText != "" { // nolint return errors.New(strings.TrimSuffix(errorText, "\n")) - } else { - return nil } + + return nil } -func GetTLSConfig() TLSConfig { +func tlsConfig() TLSConfig { return TLSConfig{ LetsEncrypt: LetsEncryptConfig{ Hostname: viper.GetString("tls_letsencrypt_hostname"), @@ -384,7 +390,7 @@ func GetTLSConfig() TLSConfig { } } -func GetDERPConfig() DERPConfig { +func derpConfig() DERPConfig { serverEnabled := viper.GetBool("derp.server.enabled") serverRegionID := viper.GetInt("derp.server.region_id") serverRegionCode := viper.GetString("derp.server.region_code") @@ -445,7 +451,7 @@ func GetDERPConfig() DERPConfig { } } -func GetLogTailConfig() LogTailConfig { +func logtailConfig() LogTailConfig { enabled := viper.GetBool("logtail.enabled") return LogTailConfig{ @@ -453,7 +459,7 @@ func GetLogTailConfig() LogTailConfig { } } -func GetPolicyConfig() PolicyConfig { +func policyConfig() PolicyConfig { policyPath := viper.GetString("policy.path") policyMode := viper.GetString("policy.mode") @@ -463,7 +469,7 @@ func GetPolicyConfig() PolicyConfig { } } -func GetLogConfig() LogConfig { +func logConfig() LogConfig { logLevelStr := viper.GetString("log.level") logLevel, err := zerolog.ParseLevel(logLevelStr) if err != nil { @@ -473,9 +479,9 @@ func GetLogConfig() LogConfig { logFormatOpt := viper.GetString("log.format") var logFormat string switch logFormatOpt { - case "json": + case JSONLogFormat: logFormat = JSONLogFormat - case "text": + case TextLogFormat: logFormat = TextLogFormat case "": logFormat = TextLogFormat @@ -491,7 +497,7 @@ func GetLogConfig() LogConfig { } } -func GetDatabaseConfig() DatabaseConfig { +func databaseConfig() DatabaseConfig { debug := viper.GetBool("database.debug") type_ := viper.GetString("database.type") @@ -543,7 +549,7 @@ func GetDatabaseConfig() DatabaseConfig { } } -func DNS() (DNSConfig, error) { +func dns() (DNSConfig, error) { var dns DNSConfig // TODO: Use this instead of manually getting settings when @@ -575,12 +581,12 @@ func DNS() (DNSConfig, error) { return dns, nil } -// GlobalResolvers returns the global DNS resolvers +// globalResolvers returns the global DNS resolvers // defined in the config file. // If a nameserver is a valid IP, it will be used as a regular resolver. // If a nameserver is a valid URL, it will be used as a DoH resolver. // If a nameserver is neither a valid URL nor a valid IP, it will be ignored. -func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver { +func (d *DNSConfig) globalResolvers() []*dnstype.Resolver { var resolvers []*dnstype.Resolver for _, nsStr := range d.Nameservers.Global { @@ -613,11 +619,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver { return resolvers } -// SplitResolvers returns a map of domain to DNS resolvers. +// splitResolvers returns a map of domain to DNS resolvers. // If a nameserver is a valid IP, it will be used as a regular resolver. // If a nameserver is a valid URL, it will be used as a DoH resolver. // If a nameserver is neither a valid URL nor a valid IP, it will be ignored. -func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver { +func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver { routes := make(map[string][]*dnstype.Resolver) for domain, nameservers := range d.Nameservers.Split { var resolvers []*dnstype.Resolver @@ -653,7 +659,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver { return routes } -func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { +func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { cfg := tailcfg.DNSConfig{} if dns.BaseDomain == "" && dns.MagicDNS { @@ -662,9 +668,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { cfg.Proxied = dns.MagicDNS cfg.ExtraRecords = dns.ExtraRecords - cfg.Resolvers = dns.GlobalResolvers() + cfg.Resolvers = dns.globalResolvers() - routes := dns.SplitResolvers() + routes := dns.splitResolvers() cfg.Routes = routes if dns.BaseDomain != "" { cfg.Domains = []string{dns.BaseDomain} @@ -674,7 +680,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { return &cfg } -func PrefixV4() (*netip.Prefix, error) { +func prefixV4() (*netip.Prefix, error) { prefixV4Str := viper.GetString("prefixes.v4") if prefixV4Str == "" { @@ -698,7 +704,7 @@ func PrefixV4() (*netip.Prefix, error) { return &prefixV4, nil } -func PrefixV6() (*netip.Prefix, error) { +func prefixV6() (*netip.Prefix, error) { prefixV6Str := viper.GetString("prefixes.v6") if prefixV6Str == "" { @@ -723,27 +729,37 @@ func PrefixV6() (*netip.Prefix, error) { return &prefixV6, nil } -func GetHeadscaleConfig() (*Config, error) { - if IsCLIConfigured() { - return &Config{ - CLI: CLIConfig{ - Address: viper.GetString("cli.address"), - APIKey: viper.GetString("cli.api_key"), - Timeout: viper.GetDuration("cli.timeout"), - Insecure: viper.GetBool("cli.insecure"), - }, - }, nil +// LoadCLIConfig returns the needed configuration for the CLI client +// of Headscale to connect to a Headscale server. +func LoadCLIConfig() (*Config, error) { + return &Config{ + DisableUpdateCheck: viper.GetBool("disable_check_updates"), + UnixSocket: viper.GetString("unix_socket"), + CLI: CLIConfig{ + Address: viper.GetString("cli.address"), + APIKey: viper.GetString("cli.api_key"), + Timeout: viper.GetDuration("cli.timeout"), + Insecure: viper.GetBool("cli.insecure"), + }, + }, nil +} + +// LoadServerConfig returns the full Headscale configuration to +// host a Headscale server. This is called as part of `headscale serve`. +func LoadServerConfig() (*Config, error) { + if err := validateServerConfig(); err != nil { + return nil, err } - logConfig := GetLogConfig() + logConfig := logConfig() zerolog.SetGlobalLevel(logConfig.Level) - prefix4, err := PrefixV4() + prefix4, err := prefixV4() if err != nil { return nil, err } - prefix6, err := PrefixV6() + prefix6, err := prefixV6() if err != nil { return nil, err } @@ -763,13 +779,13 @@ func GetHeadscaleConfig() (*Config, error) { return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom) } - dnsConfig, err := DNS() + dnsConfig, err := dns() if err != nil { return nil, err } - derpConfig := GetDERPConfig() - logTailConfig := GetLogTailConfig() + derpConfig := derpConfig() + logTailConfig := logtailConfig() randomizeClientPort := viper.GetBool("randomize_client_port") oidcClientSecret := viper.GetString("oidc.client_secret") @@ -806,7 +822,7 @@ func GetHeadscaleConfig() (*Config, error) { MetricsAddr: viper.GetString("metrics_listen_addr"), GRPCAddr: viper.GetString("grpc_listen_addr"), GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), - DisableUpdateCheck: viper.GetBool("disable_check_updates"), + DisableUpdateCheck: false, PrefixV4: prefix4, PrefixV6: prefix6, @@ -823,11 +839,11 @@ func GetHeadscaleConfig() (*Config, error) { "ephemeral_node_inactivity_timeout", ), - Database: GetDatabaseConfig(), + Database: databaseConfig(), - TLS: GetTLSConfig(), + TLS: tlsConfig(), - DNSConfig: DNSToTailcfgDNS(dnsConfig), + DNSConfig: dnsToTailcfgDNS(dnsConfig), DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS, ACMEEmail: viper.GetString("acme_email"), @@ -870,7 +886,7 @@ func GetHeadscaleConfig() (*Config, error) { LogTail: logTailConfig, RandomizeClientPort: randomizeClientPort, - Policy: GetPolicyConfig(), + Policy: policyConfig(), CLI: CLIConfig{ Address: viper.GetString("cli.address"), @@ -890,10 +906,6 @@ func GetHeadscaleConfig() (*Config, error) { }, nil } -func IsCLIConfigured() bool { - return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != "" -} - type deprecator struct { warns set.Set[string] fatals set.Set[string] diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 2b36e45c..e6e8d6c2 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -1,6 +1,8 @@ package types import ( + "os" + "path/filepath" "testing" "github.com/google/go-cmp/cmp" @@ -22,7 +24,7 @@ func TestReadConfig(t *testing.T) { name: "unmarshal-dns-full-config", configPath: "testdata/dns_full.yaml", setup: func(t *testing.T) (any, error) { - dns, err := DNS() + dns, err := dns() if err != nil { return nil, err } @@ -48,12 +50,12 @@ func TestReadConfig(t *testing.T) { name: "dns-to-tailcfg.DNSConfig", configPath: "testdata/dns_full.yaml", setup: func(t *testing.T) (any, error) { - dns, err := DNS() + dns, err := dns() if err != nil { return nil, err } - return DNSToTailcfgDNS(dns), nil + return dnsToTailcfgDNS(dns), nil }, want: &tailcfg.DNSConfig{ Proxied: true, @@ -79,7 +81,7 @@ func TestReadConfig(t *testing.T) { name: "unmarshal-dns-full-no-magic", configPath: "testdata/dns_full_no_magic.yaml", setup: func(t *testing.T) (any, error) { - dns, err := DNS() + dns, err := dns() if err != nil { return nil, err } @@ -105,12 +107,12 @@ func TestReadConfig(t *testing.T) { name: "dns-to-tailcfg.DNSConfig", configPath: "testdata/dns_full_no_magic.yaml", setup: func(t *testing.T) (any, error) { - dns, err := DNS() + dns, err := dns() if err != nil { return nil, err } - return DNSToTailcfgDNS(dns), nil + return dnsToTailcfgDNS(dns), nil }, want: &tailcfg.DNSConfig{ Proxied: false, @@ -136,7 +138,7 @@ func TestReadConfig(t *testing.T) { name: "base-domain-in-server-url-err", configPath: "testdata/base-domain-in-server-url.yaml", setup: func(t *testing.T) (any, error) { - return GetHeadscaleConfig() + return LoadServerConfig() }, want: nil, wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.", @@ -145,7 +147,7 @@ func TestReadConfig(t *testing.T) { name: "base-domain-not-in-server-url", configPath: "testdata/base-domain-not-in-server-url.yaml", setup: func(t *testing.T) (any, error) { - cfg, err := GetHeadscaleConfig() + cfg, err := LoadServerConfig() if err != nil { return nil, err } @@ -165,7 +167,7 @@ func TestReadConfig(t *testing.T) { name: "policy-path-is-loaded", configPath: "testdata/policy-path-is-loaded.yaml", setup: func(t *testing.T) (any, error) { - cfg, err := GetHeadscaleConfig() + cfg, err := LoadServerConfig() if err != nil { return nil, err } @@ -245,7 +247,7 @@ func TestReadConfigFromEnv(t *testing.T) { setup: func(t *testing.T) (any, error) { t.Logf("all settings: %#v", viper.AllSettings()) - dns, err := DNS() + dns, err := dns() if err != nil { return nil, err } @@ -289,3 +291,49 @@ func TestReadConfigFromEnv(t *testing.T) { }) } } + +func TestTLSConfigValidation(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "headscale") + if err != nil { + t.Fatal(err) + } + // defer os.RemoveAll(tmpDir) + configYaml := []byte(`--- +tls_letsencrypt_hostname: example.com +tls_letsencrypt_challenge_type: "" +tls_cert_path: abc.pem +noise: + private_key_path: noise_private.key`) + + // Populate a custom config file + configFilePath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configFilePath, configYaml, 0o600) + if err != nil { + t.Fatalf("Couldn't write file %s", configFilePath) + } + + // Check configuration validation errors (1) + err = LoadConfig(tmpDir, false) + assert.NoError(t, err) + + err = validateServerConfig() + assert.Error(t, err) + assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both") + assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are") + assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://") + + // Check configuration validation errors (2) + configYaml = []byte(`--- +noise: + private_key_path: noise_private.key +server_url: http://127.0.0.1:8080 +tls_letsencrypt_hostname: example.com +tls_letsencrypt_challenge_type: TLS-ALPN-01 +`) + err = os.WriteFile(configFilePath, configYaml, 0o600) + if err != nil { + t.Fatalf("Couldn't write file %s", configFilePath) + } + err = LoadConfig(tmpDir, false) + assert.NoError(t, err) +} diff --git a/integration/cli_test.go b/integration/cli_test.go index 9e7d179f..fd7a8c1b 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "sort" + "strings" "testing" "time" @@ -735,13 +736,7 @@ func TestNodeTagCommand(t *testing.T) { assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) - // try to set a wrong tag and retrieve the error - type errOutput struct { - Error string `json:"error"` - } - var errorOutput errOutput - err = executeAndUnmarshal( - headscale, + _, err = headscale.Execute( []string{ "headscale", "nodes", @@ -750,10 +745,8 @@ func TestNodeTagCommand(t *testing.T) { "-t", "wrong-tag", "--output", "json", }, - &errorOutput, ) - assert.Nil(t, err) - assert.Contains(t, errorOutput.Error, "tag must start with the string 'tag:'") + assert.ErrorContains(t, err, "tag must start with the string 'tag:'") // Test list all nodes after added seconds resultMachines := make([]*v1.Node, len(machineKeys)) @@ -1398,18 +1391,17 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5") // Test failure for too long names - result, err := headscale.Execute( + _, err = headscale.Execute( []string{ "headscale", "nodes", "rename", "--identifier", fmt.Sprintf("%d", listAll[4].GetId()), - "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890", + strings.Repeat("t", 64), }, ) - assert.Nil(t, err) - assert.Contains(t, result, "not be over 63 chars") + assert.ErrorContains(t, err, "not be over 63 chars") var listAllAfterRenameAttempt []v1.Node err = executeAndUnmarshal( @@ -1536,7 +1528,7 @@ func TestNodeMoveCommand(t *testing.T) { assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") - moveToNonExistingNSResult, err := headscale.Execute( + _, err = headscale.Execute( []string{ "headscale", "nodes", @@ -1549,11 +1541,9 @@ func TestNodeMoveCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) - - assert.Contains( + assert.ErrorContains( t, - moveToNonExistingNSResult, + err, "user not found", ) assert.Equal(t, node.GetUser().GetName(), "new-user")