Add and fix nlreturn (new line return)

This commit is contained in:
Kristoffer Dalby 2021-11-14 16:46:09 +01:00
parent d0ef850035
commit 89eb13c6cb
No known key found for this signature in database
GPG Key ID: 09F62DC067465735
25 changed files with 198 additions and 1 deletions

View File

@ -30,7 +30,6 @@ linters:
- stylecheck - stylecheck
- wrapcheck - wrapcheck
- paralleltest - paralleltest
- nlreturn
- ifshort - ifshort
- gomnd - gomnd
- goerr113 - goerr113

View File

@ -58,6 +58,7 @@ func (h *Headscale) LoadACLPolicy(path string) error {
return err return err
} }
h.aclRules = rules h.aclRules = rules
return nil return nil
} }
@ -77,6 +78,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, User %d", i, j) Msgf("Error parsing ACL %d, User %d", i, j)
return nil, err return nil, err
} }
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
@ -89,6 +91,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Port %d", i, j) Msgf("Error parsing ACL %d, Port %d", i, j)
return nil, err return nil, err
} }
destPorts = append(destPorts, dests...) destPorts = append(destPorts, dests...)
@ -147,6 +150,7 @@ func (h *Headscale) generateACLPolicyDestPorts(
dests = append(dests, pr) dests = append(dests, pr)
} }
} }
return dests, nil return dests, nil
} }
@ -169,6 +173,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
ips = append(ips, node.IPAddress) ips = append(ips, node.IPAddress)
} }
} }
return ips, nil return ips, nil
} }
@ -200,11 +205,13 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
for _, t := range hostinfo.RequestTags { for _, t := range hostinfo.RequestTags {
if s[4:] == t { if s[4:] == t {
ips = append(ips, m.IPAddress) ips = append(ips, m.IPAddress)
break break
} }
} }
} }
} }
return ips, nil return ips, nil
} }
@ -218,6 +225,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
for _, n := range nodes { for _, n := range nodes {
ips = append(ips, n.IPAddress) ips = append(ips, n.IPAddress)
} }
return ips, nil return ips, nil
} }
@ -272,5 +280,6 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
return nil, errorInvalidPortFormat return nil, errorInvalidPortFormat
} }
} }
return &ports, nil return &ports, nil
} }

View File

@ -65,6 +65,7 @@ func (h *Hosts) UnmarshalJSON(data []byte) error {
hosts[k] = prefix hosts[k] = prefix
} }
*h = hosts *h = hosts
return nil return nil
} }
@ -73,5 +74,6 @@ func (p ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
return true return true
} }
return false return false
} }

24
api.go
View File

@ -30,6 +30,7 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
mKeyStr := c.Query("key") mKeyStr := c.Query("key")
if mKeyStr == "" { if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params") c.String(http.StatusBadRequest, "Wrong params")
return return
} }
@ -66,6 +67,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Msg("Cannot parse machine key") Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!") c.String(http.StatusInternalServerError, "Sad!")
return return
} }
req := tailcfg.RegisterRequest{} req := tailcfg.RegisterRequest{}
@ -77,6 +79,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Msg("Cannot decode message") Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!") c.String(http.StatusInternalServerError, "Very sad!")
return return
} }
@ -96,6 +99,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Msg("Could not create row") Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).
Inc() Inc()
return return
} }
m = &newMachine m = &newMachine
@ -103,6 +107,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
if !m.Registered && req.Auth.AuthKey != "" { if !m.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, *m) h.handleAuthKey(c, h.db, mKey, req, *m)
return return
} }
@ -131,9 +136,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
return return
} }
@ -158,11 +165,13 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).
Inc() Inc()
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).
Inc() Inc()
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
return return
} }
@ -199,11 +208,13 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).
Inc() Inc()
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).
Inc() Inc()
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
return return
} }
@ -225,9 +236,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!") c.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
return return
} }
@ -259,6 +272,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
@ -279,6 +293,7 @@ func (h *Headscale) getMapResponse(
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Cannot convert to node") Msg("Cannot convert to node")
return nil, err return nil, err
} }
@ -288,6 +303,7 @@ func (h *Headscale) getMapResponse(
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return nil, err return nil, err
} }
@ -299,6 +315,7 @@ func (h *Headscale) getMapResponse(
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Failed to convert peers to Tailscale nodes") Msg("Failed to convert peers to Tailscale nodes")
return nil, err return nil, err
} }
@ -313,6 +330,7 @@ func (h *Headscale) getMapResponse(
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Err(err). Err(err).
Msg("Failed generate the DNSConfig") Msg("Failed generate the DNSConfig")
return nil, err return nil, err
} }
@ -353,6 +371,7 @@ func (h *Headscale) getMapResponse(
data := make([]byte, 4) data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...) data = append(data, respBody...)
return data, nil return data, nil
} }
@ -383,6 +402,7 @@ func (h *Headscale) getMapKeepAliveResponse(
data := make([]byte, 4) data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...) data = append(data, respBody...)
return data, nil return data, nil
} }
@ -416,6 +436,7 @@ func (h *Headscale) handleAuthKey(
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc() Inc()
return return
} }
c.Data(401, "application/json; charset=utf-8", respBody) c.Data(401, "application/json; charset=utf-8", respBody)
@ -425,6 +446,7 @@ func (h *Headscale) handleAuthKey(
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc() Inc()
return return
} }
@ -440,6 +462,7 @@ func (h *Headscale) handleAuthKey(
Msg("Failed to find an available IP") Msg("Failed to find an available IP")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc() Inc()
return return
} }
log.Info(). log.Info().
@ -471,6 +494,7 @@ func (h *Headscale) handleAuthKey(
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc() Inc()
c.String(http.StatusInternalServerError, "Extremely sad!") c.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name). machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).

5
app.go
View File

@ -297,6 +297,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", p.Addr.String()).
Msg("Retrieving metadata is failed") Msg("Retrieving metadata is failed")
return ctx, status.Errorf( return ctx, status.Errorf(
codes.InvalidArgument, codes.InvalidArgument,
"Retrieving metadata is failed", "Retrieving metadata is failed",
@ -309,6 +310,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", p.Addr.String()).
Msg("Authorization token is not supplied") Msg("Authorization token is not supplied")
return ctx, status.Errorf( return ctx, status.Errorf(
codes.Unauthenticated, codes.Unauthenticated,
"Authorization token is not supplied", "Authorization token is not supplied",
@ -322,6 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
Caller(). Caller().
Str("client_address", p.Addr.String()). Str("client_address", p.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`) Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error( return ctx, status.Error(
codes.Unauthenticated, codes.Unauthenticated,
`missing "Bearer " prefix in "Authorization" header`, `missing "Bearer " prefix in "Authorization" header`,
@ -392,6 +395,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
return nil return nil
} }
return os.Remove(h.cfg.UnixSocket) return os.Remove(h.cfg.UnixSocket)
} }
@ -568,6 +572,7 @@ func (h *Headscale) Serve() error {
if tlsConfig != nil { if tlsConfig != nil {
g.Go(func() error { g.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig) tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl) return httpServer.Serve(tlsl)
}) })
} else { } else {

View File

@ -77,6 +77,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple index template"), []byte("Could not render Apple index template"),
) )
return return
} }
@ -97,6 +98,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Failed to create UUID"), []byte("Failed to create UUID"),
) )
return return
} }
@ -111,6 +113,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Failed to create UUID"), []byte("Failed to create UUID"),
) )
return return
} }
@ -133,6 +136,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"), []byte("Could not render Apple macOS template"),
) )
return return
} }
case "ios": case "ios":
@ -146,6 +150,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"), []byte("Could not render Apple iOS template"),
) )
return return
} }
default: default:
@ -154,6 +159,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"), []byte("Invalid platform, only ios and macos is supported"),
) )
return return
} }
@ -174,6 +180,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Apple platform template"), []byte("Could not render Apple platform template"),
) )
return return
} }

View File

@ -48,6 +48,7 @@ var createNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -62,6 +63,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node from flag: %s", err), fmt.Sprintf("Error getting node from flag: %s", err),
output, output,
) )
return return
} }
@ -72,6 +74,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting key from flag: %s", err), fmt.Sprintf("Error getting key from flag: %s", err),
output, output,
) )
return return
} }
@ -82,6 +85,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting routes from flag: %s", err), fmt.Sprintf("Error getting routes from flag: %s", err),
output, output,
) )
return return
} }
@ -99,6 +103,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }

View File

@ -31,6 +31,7 @@ var createNamespaceCmd = &cobra.Command{
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("Missing parameters") return fmt.Errorf("Missing parameters")
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -57,6 +58,7 @@ var createNamespaceCmd = &cobra.Command{
), ),
output, output,
) )
return return
} }
@ -71,6 +73,7 @@ var destroyNamespaceCmd = &cobra.Command{
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("Missing parameters") return fmt.Errorf("Missing parameters")
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -93,6 +96,7 @@ var destroyNamespaceCmd = &cobra.Command{
fmt.Sprintf("Error: %s", status.Convert(err).Message()), fmt.Sprintf("Error: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }
@ -124,6 +128,7 @@ var destroyNamespaceCmd = &cobra.Command{
), ),
output, output,
) )
return return
} }
SuccessOutput(response, "Namespace destroyed", output) SuccessOutput(response, "Namespace destroyed", output)
@ -152,11 +157,13 @@ var listNamespacesCmd = &cobra.Command{
fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Namespaces, "", output) SuccessOutput(response.Namespaces, "", output)
return return
} }
@ -178,6 +185,7 @@ var listNamespacesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return return
} }
}, },
@ -190,6 +198,7 @@ var renameNamespaceCmd = &cobra.Command{
if len(args) < 2 { if len(args) < 2 {
return fmt.Errorf("Missing parameters") return fmt.Errorf("Missing parameters")
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -214,6 +223,7 @@ var renameNamespaceCmd = &cobra.Command{
), ),
output, output,
) )
return return
} }

View File

@ -77,6 +77,7 @@ var registerNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -91,6 +92,7 @@ var registerNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting machine key from flag: %s", err), fmt.Sprintf("Error getting machine key from flag: %s", err),
output, output,
) )
return return
} }
@ -109,6 +111,7 @@ var registerNodeCmd = &cobra.Command{
), ),
output, output,
) )
return return
} }
@ -124,6 +127,7 @@ var listNodesCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -142,17 +146,20 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Machines, "", output) SuccessOutput(response.Machines, "", output)
return return
} }
d, err := nodesToPtables(namespace, response.Machines) d, err := nodesToPtables(namespace, response.Machines)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
@ -163,6 +170,7 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return return
} }
}, },
@ -181,6 +189,7 @@ var deleteNodeCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err), fmt.Sprintf("Error converting ID to integer: %s", err),
output, output,
) )
return return
} }
@ -202,6 +211,7 @@ var deleteNodeCmd = &cobra.Command{
), ),
output, output,
) )
return return
} }
@ -228,6 +238,7 @@ var deleteNodeCmd = &cobra.Command{
response, err := client.DeleteMachine(ctx, deleteRequest) response, err := client.DeleteMachine(ctx, deleteRequest)
if output != "" { if output != "" {
SuccessOutput(response, "", output) SuccessOutput(response, "", output)
return return
} }
if err != nil { if err != nil {
@ -239,6 +250,7 @@ var deleteNodeCmd = &cobra.Command{
), ),
output, output,
) )
return return
} }
SuccessOutput( SuccessOutput(
@ -260,6 +272,7 @@ func sharingWorker(
namespaceStr, err := cmd.Flags().GetString("namespace") namespaceStr, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return "", nil, nil, err return "", nil, nil, err
} }
@ -270,6 +283,7 @@ func sharingWorker(
id, err := cmd.Flags().GetInt("identifier") id, err := cmd.Flags().GetInt("identifier")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
return "", nil, nil, err return "", nil, nil, err
} }
@ -284,6 +298,7 @@ func sharingWorker(
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output, output,
) )
return "", nil, nil, err return "", nil, nil, err
} }
@ -298,6 +313,7 @@ func sharingWorker(
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output, output,
) )
return "", nil, nil, err return "", nil, nil, err
} }
@ -315,6 +331,7 @@ var shareMachineCmd = &cobra.Command{
fmt.Sprintf("Failed to fetch namespace or machine: %s", err), fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output, output,
) )
return return
} }
@ -334,6 +351,7 @@ var shareMachineCmd = &cobra.Command{
fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }
@ -352,6 +370,7 @@ var unshareMachineCmd = &cobra.Command{
fmt.Sprintf("Failed to fetch namespace or machine: %s", err), fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output, output,
) )
return return
} }
@ -371,6 +390,7 @@ var unshareMachineCmd = &cobra.Command{
fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }
@ -442,5 +462,6 @@ func nodesToPtables(
}, },
) )
} }
return d, nil return d, nil
} }

View File

@ -44,6 +44,7 @@ var listPreAuthKeys = &cobra.Command{
n, err := cmd.Flags().GetString("namespace") n, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -62,11 +63,13 @@ var listPreAuthKeys = &cobra.Command{
fmt.Sprintf("Error getting the list of keys: %s", err), fmt.Sprintf("Error getting the list of keys: %s", err),
output, output,
) )
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.PreAuthKeys, "", output) SuccessOutput(response.PreAuthKeys, "", output)
return return
} }
@ -104,6 +107,7 @@ var listPreAuthKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return return
} }
}, },
@ -118,6 +122,7 @@ var createPreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -156,6 +161,7 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output, output,
) )
return return
} }
@ -170,6 +176,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
if len(args) < 1 { if len(args) < 1 {
return fmt.Errorf("missing parameters") return fmt.Errorf("missing parameters")
} }
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -177,6 +184,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace") namespace, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
return return
} }
@ -196,6 +204,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output, output,
) )
return return
} }

View File

@ -52,6 +52,7 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err), fmt.Sprintf("Error getting machine id from flag: %s", err),
output, output,
) )
return return
} }
@ -70,17 +71,20 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output, output,
) )
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Routes, "", output) SuccessOutput(response.Routes, "", output)
return return
} }
d := routesToPtables(response.Routes) d := routesToPtables(response.Routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
@ -91,6 +95,7 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return return
} }
}, },
@ -113,6 +118,7 @@ omit the route you do not want to enable.
fmt.Sprintf("Error getting machine id from flag: %s", err), fmt.Sprintf("Error getting machine id from flag: %s", err),
output, output,
) )
return return
} }
@ -123,6 +129,7 @@ omit the route you do not want to enable.
fmt.Sprintf("Error getting routes from flag: %s", err), fmt.Sprintf("Error getting routes from flag: %s", err),
output, output,
) )
return return
} }
@ -145,17 +152,20 @@ omit the route you do not want to enable.
), ),
output, output,
) )
return return
} }
if output != "" { if output != "" {
SuccessOutput(response.Routes, "", output) SuccessOutput(response.Routes, "", output)
return return
} }
d := routesToPtables(response.Routes) d := routesToPtables(response.Routes)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return return
} }
@ -166,6 +176,7 @@ omit the route you do not want to enable.
fmt.Sprintf("Failed to render pterm table: %s", err), fmt.Sprintf("Failed to render pterm table: %s", err),
output, output,
) )
return return
} }
}, },
@ -180,6 +191,7 @@ func routesToPtables(routes *v1.Routes) pterm.TableData {
d = append(d, []string{route, strconv.FormatBool(enabled)}) d = append(d, []string{route, strconv.FormatBool(enabled)})
} }
return d return d
} }

View File

@ -213,6 +213,7 @@ func absPath(path string) string {
path = filepath.Join(dir, path) path = filepath.Join(dir, path)
} }
} }
return path return path
} }
@ -310,6 +311,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
viper.GetString("ephemeral_node_inactivity_timeout"), viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout, minInactivityTimeout,
) )
return nil, err return nil, err
} }
@ -415,6 +417,7 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
} }
default: default:
fmt.Println(override) fmt.Println(override)
return return
} }
@ -435,6 +438,7 @@ func HasMachineOutputFlag() bool {
return true return true
} }
} }
return false return false
} }

4
db.go
View File

@ -50,6 +50,7 @@ func (h *Headscale) initDB() error {
} }
err = h.setValue("db_version", dbVersion) err = h.setValue("db_version", dbVersion)
return err return err
} }
@ -93,6 +94,7 @@ func (h *Headscale) getValue(key string) (string, error) {
) { ) {
return "", errors.New("not found") return "", errors.New("not found")
} }
return row.Value, nil return row.Value, nil
} }
@ -106,9 +108,11 @@ func (h *Headscale) setValue(key string, value string) error {
_, err := h.getValue(key) _, err := h.getValue(key)
if err == nil { if err == nil {
h.db.Model(&kv).Where("key = ?", key).Update("value", value) h.db.Model(&kv).Where("key = ?", key).Update("value", value)
return nil return nil
} }
h.db.Create(kv) h.db.Create(kv)
return nil return nil
} }

View File

@ -27,6 +27,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
return nil, err return nil, err
} }
err = yaml.Unmarshal(b, &derpMap) err = yaml.Unmarshal(b, &derpMap)
return &derpMap, err return &derpMap, err
} }
@ -56,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
var derpMap tailcfg.DERPMap var derpMap tailcfg.DERPMap
err = json.Unmarshal(body, &derpMap) err = json.Unmarshal(body, &derpMap)
return &derpMap, err return &derpMap, err
} }
@ -94,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("path", path). Str("path", path).
Err(err). Err(err).
Msg("Could not load DERP map from path") Msg("Could not load DERP map from path")
break break
} }
@ -112,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("url", addr.String()). Str("url", addr.String()).
Err(err). Err(err).
Msg("Could not load DERP map from path") Msg("Could not load DERP map from path")
break break
} }

2
dns.go
View File

@ -69,6 +69,7 @@ func generateMagicDNSRootDomains(
} }
fqdns = append(fqdns, fqdn) fqdns = append(fqdns, fqdn)
} }
return fqdns, nil return fqdns, nil
} }
@ -99,5 +100,6 @@ func getMapResponseDNSConfig(
} else { } else {
dnsConfig = dnsConfigOrig dnsConfig = dnsConfigOrig
} }
return dnsConfig, nil return dnsConfig, nil
} }

View File

@ -18,6 +18,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "64.100.in-addr.arpa." { if domain == "64.100.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -27,6 +28,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "100.100.in-addr.arpa." { if domain == "100.100.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -36,6 +38,7 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "127.100.in-addr.arpa." { if domain == "127.100.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -51,6 +54,7 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." { if domain == "0.16.172.in-addr.arpa." {
found = true found = true
break break
} }
} }
@ -60,6 +64,7 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
for _, domain := range domains { for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." { if domain == "255.16.172.in-addr.arpa." {
found = true found = true
break break
} }
} }

View File

@ -106,6 +106,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err return Machines{}, err
} }
@ -115,6 +116,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
Caller(). Caller().
Str("machine", m.Name). Str("machine", m.Name).
Msgf("Found direct machines: %s", machines.String()) Msgf("Found direct machines: %s", machines.String())
return machines, nil return machines, nil
} }
@ -142,6 +144,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
Caller(). Caller().
Str("machine", m.Name). Str("machine", m.Name).
Msgf("Found shared peers: %s", peers.String()) Msgf("Found shared peers: %s", peers.String())
return peers, nil return peers, nil
} }
@ -175,6 +178,7 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
Caller(). Caller().
Str("machine", m.Name). Str("machine", m.Name).
Msgf("Found peers we are shared with: %s", peers.String()) Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil return peers, nil
} }
@ -185,6 +189,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
@ -194,6 +199,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
@ -203,6 +209,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
@ -224,6 +231,7 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
return machines, nil return machines, nil
} }
@ -239,6 +247,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return &m, nil return &m, nil
} }
} }
return nil, fmt.Errorf("machine not found") return nil, fmt.Errorf("machine not found")
} }
@ -248,6 +257,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil { if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
return &m, nil return &m, nil
} }
@ -257,6 +267,7 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
return nil, result.Error return nil, result.Error
} }
return &m, nil return &m, nil
} }
@ -266,6 +277,7 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil { if result := h.db.Find(m).First(&m); result.Error != nil {
return result.Error return result.Error
} }
return nil return nil
} }
@ -314,6 +326,7 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return nil, err return nil, err
} }
} }
return &hostinfo, nil return &hostinfo, nil
} }
@ -348,6 +361,7 @@ func (h *Headscale) isOutdated(m *Machine) bool {
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", lastChange). Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name) Msgf("Checking if %s is missing updates", m.Name)
return m.LastSuccessfulUpdate.Before(lastChange) return m.LastSuccessfulUpdate.Before(lastChange)
} }
@ -429,6 +443,7 @@ func (m Machine) toNode(
Caller(). Caller().
Str("ip", m.IPAddress). Str("ip", m.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
return nil, err return nil, err
} }
addrs = append(addrs, ip) // missing the ipv6 ? addrs = append(addrs, ip) // missing the ipv6 ?
@ -530,6 +545,7 @@ func (m Machine) toNode(
MachineAuthorized: m.Registered, MachineAuthorized: m.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing}, Capabilities: []string{tailcfg.CapabilityFileSharing},
} }
return &n, nil return &n, nil
} }
@ -613,6 +629,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
Err(err). Err(err).
Str("machine", m.Name). Str("machine", m.Name).
Msg("Could not find IP for the new machine") Msg("Could not find IP for the new machine")
return nil, err return nil, err
} }
@ -642,6 +659,7 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return hostInfo.RoutableIPs, nil return hostInfo.RoutableIPs, nil
} }
@ -685,6 +703,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
return true return true
} }
} }
return false return false
} }

View File

@ -42,8 +42,10 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
Str("func", "CreateNamespace"). Str("func", "CreateNamespace").
Err(err). Err(err).
Msg("Could not create row") Msg("Could not create row")
return nil, err return nil, err
} }
return &n, nil return &n, nil
} }
@ -119,6 +121,7 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
) { ) {
return nil, errorNamespaceNotFound return nil, errorNamespaceNotFound
} }
return &n, nil return &n, nil
} }
@ -128,6 +131,7 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
if err := h.db.Find(&namespaces).Error; err != nil { if err := h.db.Find(&namespaces).Error; err != nil {
return nil, err return nil, err
} }
return namespaces, nil return namespaces, nil
} }
@ -142,6 +146,7 @@ func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
return machines, nil return machines, nil
} }
@ -166,6 +171,7 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
} }
machines = append(machines, *machine) machines = append(machines, *machine)
} }
return machines, nil return machines, nil
} }
@ -177,6 +183,7 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
} }
m.NamespaceID = n.ID m.NamespaceID = n.ID
h.db.Save(&m) h.db.Save(&m)
return nil return nil
} }
@ -196,6 +203,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
names := []string{} names := []string{}
@ -208,6 +216,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
@ -218,8 +227,10 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
Str("func", "RequestMapUpdates"). Str("func", "RequestMapUpdates").
Err(err). Err(err).
Msg("Could not marshal namespaces_pending_updates") Msg("Could not marshal namespaces_pending_updates")
return err return err
} }
return h.setValue("namespaces_pending_updates", string(data)) return h.setValue("namespaces_pending_updates", string(data))
} }
@ -255,6 +266,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Str("func", "checkForNamespacesPendingUpdates"). Str("func", "checkForNamespacesPendingUpdates").
Err(err). Err(err).
Msg("Could not save to KV") Msg("Could not save to KV")
return return
} }
} }
@ -270,6 +282,7 @@ func (n *Namespace) toUser() *tailcfg.User {
Logins: []tailcfg.LoginID{}, Logins: []tailcfg.LoginID{},
Created: time.Time{}, Created: time.Time{},
} }
return &u return &u
} }
@ -281,6 +294,7 @@ func (n *Namespace) toLogin() *tailcfg.Login {
ProfilePicURL: "", ProfilePicURL: "",
Domain: "headscale.net", Domain: "headscale.net",
} }
return &l return &l
} }
@ -300,6 +314,7 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile
DisplayName: namespace.Name, DisplayName: namespace.Name,
}) })
} }
return profiles return profiles
} }

View File

@ -199,6 +199,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
for _, up := range userProfiles { for _, up := range userProfiles {
if up.DisplayName == n1.Name { if up.DisplayName == n1.Name {
found = true found = true
break break
} }
} }
@ -208,6 +209,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
for _, up := range userProfiles { for _, up := range userProfiles {
if up.DisplayName == n2.Name { if up.DisplayName == n2.Name {
found = true found = true
break break
} }
} }

13
oidc.go
View File

@ -32,6 +32,7 @@ func (h *Headscale) initOIDC() error {
if err != nil { if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err return err
} }
@ -62,6 +63,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
mKeyStr := c.Param("mkey") mKeyStr := c.Param("mkey")
if mKeyStr == "" { if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params") c.String(http.StatusBadRequest, "Wrong params")
return return
} }
@ -70,6 +72,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
if err != nil { if err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return return
} }
@ -95,12 +98,14 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if code == "" || state == "" { if code == "" || state == "" {
c.String(http.StatusBadRequest, "Wrong params") c.String(http.StatusBadRequest, "Wrong params")
return return
} }
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token") c.String(http.StatusBadRequest, "Could not exchange code for token")
return return
} }
@ -109,6 +114,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK { if !rawIDTokenOK {
c.String(http.StatusBadRequest, "Could not extract ID Token") c.String(http.StatusBadRequest, "Could not extract ID Token")
return return
} }
@ -117,6 +123,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil { if err != nil {
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return return
} }
@ -134,6 +141,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
http.StatusBadRequest, http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err), fmt.Sprintf("Failed to decode id token claims: %s", err),
) )
return return
} }
@ -144,6 +152,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
c.String(http.StatusBadRequest, "state has expired") c.String(http.StatusBadRequest, "state has expired")
return return
} }
mKeyStr, mKeyOK := mKeyIf.(string) mKeyStr, mKeyOK := mKeyIf.(string)
@ -151,6 +160,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyOK { if !mKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get machine key from cache")
c.String(http.StatusInternalServerError, "could not get machine key from cache") c.String(http.StatusInternalServerError, "could not get machine key from cache")
return return
} }
@ -162,6 +172,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get machine info from database", "could not get machine info from database",
) )
return return
} }
@ -183,6 +194,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
http.StatusInternalServerError, http.StatusInternalServerError,
"could not create new namespace", "could not create new namespace",
) )
return return
} }
} }
@ -193,6 +205,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get an IP from the pool", "could not get an IP from the pool",
) )
return return
} }

14
poll.go
View File

@ -38,6 +38,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Err(err). Err(err).
Msg("Cannot parse client key") Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "") c.String(http.StatusBadRequest, "")
return return
} }
req := tailcfg.MapRequest{} req := tailcfg.MapRequest{}
@ -48,6 +49,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
c.String(http.StatusBadRequest, "") c.String(http.StatusBadRequest, "")
return return
} }
@ -58,6 +60,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "") c.String(http.StatusUnauthorized, "")
return return
} }
log.Error(). log.Error().
@ -101,6 +104,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(") c.String(http.StatusInternalServerError, ":(")
return return
} }
@ -124,6 +128,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("machine", m.Name). Str("machine", m.Name).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", data) c.Data(200, "application/json; charset=utf-8", data)
return return
} }
@ -161,6 +166,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update"). updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").
Inc() Inc()
go func() { updateChan <- struct{}{} }() go func() { updateChan <- struct{}{} }()
return return
} else if req.OmitPeers && req.Stream { } else if req.OmitPeers && req.Stream {
log.Warn(). log.Warn().
@ -168,6 +174,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("machine", m.Name). Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it") Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "") c.String(http.StatusBadRequest, "")
return return
} }
@ -248,6 +255,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
return false return false
} }
log.Trace(). log.Trace().
@ -282,6 +290,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData") Msg("Machine entry in database updated successfully after sending pollData")
return true return true
case data := <-keepAliveChan: case data := <-keepAliveChan:
@ -299,6 +308,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot write keep alive message") Msg("Cannot write keep alive message")
return false return false
} }
log.Trace(). log.Trace().
@ -328,6 +338,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive") Msg("Machine updated successfully after sending keep alive")
return true return true
case <-updateChan: case <-updateChan:
@ -364,6 +375,7 @@ func (h *Headscale) PollNetMapStream(
Msg("Could not write the map response") Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed"). updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").
Inc() Inc()
return false return false
} }
log.Trace(). log.Trace().
@ -405,6 +417,7 @@ func (h *Headscale) PollNetMapStream(
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("%s is up to date", m.Name) Msgf("%s is up to date", m.Name)
} }
return true return true
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
@ -485,6 +498,7 @@ func (h *Headscale) scheduledPollWorker(
Str("func", "keepAlive"). Str("func", "keepAlive").
Err(err). Err(err).
Msg("Error generating the keep alive msg") Msg("Error generating the keep alive msg")
return return
} }

View File

@ -75,6 +75,7 @@ func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error)
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
return nil, err return nil, err
} }
return keys, nil return keys, nil
} }
@ -107,6 +108,7 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err return err
} }
return nil return nil
} }
@ -147,6 +149,7 @@ func (h *Headscale) generateKey() (string, error) {
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
return "", err return "", err
} }
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }

View File

@ -24,6 +24,7 @@ func (h *Headscale) GetAdvertisedNodeRoutes(
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &hostInfo.RoutableIPs, nil return &hostInfo.RoutableIPs, nil
} }
@ -84,6 +85,7 @@ func (h *Headscale) IsNodeRouteEnabled(
return true return true
} }
} }
return false return false
} }

View File

@ -57,6 +57,7 @@ func SwaggerUI(c *gin.Context) {
"text/html; charset=utf-8", "text/html; charset=utf-8",
[]byte("Could not render Swagger"), []byte("Could not render Swagger"),
) )
return return
} }

View File

@ -48,6 +48,7 @@ func decodeMsg(
if err := json.Unmarshal(decrypted, v); err != nil { if err := json.Unmarshal(decrypted, v); err != nil {
return fmt.Errorf("response: %v", err) return fmt.Errorf("response: %v", err)
} }
return nil return nil
} }
@ -64,6 +65,7 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte,
if !ok { if !ok {
return nil, fmt.Errorf("cannot decrypt response") return nil, fmt.Errorf("cannot decrypt response")
} }
return decrypted, nil return decrypted, nil
} }
@ -83,6 +85,7 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
} }
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri) msg := box.Seal(nonce[:], b, &nonce, pub, pri)
return msg, nil return msg, nil
} }
@ -108,12 +111,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipRaw := ip.As4() ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 { if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next() ip = ip.Next()
continue continue
} }
if ip.IsZero() && if ip.IsZero() &&
ip.IsLoopback() { ip.IsLoopback() {
ip = ip.Next() ip = ip.Next()
continue continue
} }
@ -174,6 +179,7 @@ func tailMapResponseToString(resp tailcfg.MapResponse) string {
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer var d net.Dialer
return d.DialContext(ctx, "unix", addr) return d.DialContext(ctx, "unix", addr)
} }