integration: replace time.Sleep with assert.EventuallyWithT (#2680)

Kristoffer Dalby
2025-07-10 23:38:55 +02:00
committed by GitHub
parent b904276f2b
commit c6d7b512bd
73 changed files with 584 additions and 573 deletions
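
The hunks reproduced below are mostly the mechanical lint and refactor fixes that rode along with this change (errors.As conversions, import ordering, comment punctuation, and similar); the integration-test files that actually swap time.Sleep for assert.EventuallyWithT are among the 73 changed files but are not shown here. As a rough, self-contained sketch of that pattern — the helper, state string, and timeouts are illustrative placeholders, not headscale's actual test code:

```go
package integration

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// backendState is a hypothetical stand-in for whatever the real tests poll
// (e.g. a tailscale client's status); it is not a headscale helper.
func backendState() (string, error) { return "Running", nil }

func TestNodeComesUp(t *testing.T) {
	// Before: a fixed sleep, then a single check that may race the client.
	//
	//	time.Sleep(10 * time.Second)
	//	state, err := backendState()
	//	require.NoError(t, err)
	//	require.Equal(t, "Running", state)

	// After: poll every 500ms for up to 30s; the assertions collected in
	// the callback only fail the test if they still do not pass when the
	// deadline expires.
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		state, err := backendState()
		assert.NoError(c, err)
		assert.Equal(c, "Running", state)
	}, 30*time.Second, 500*time.Millisecond)
}
```

assert.EventuallyWithT re-runs the callback on every tick and only reports the collected failures if the condition still has not passed by the waitFor deadline, which removes the flakiness of waiting a fixed amount of time.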

View File

@@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode(
return nil, nil
}
}
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
@@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
@@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey(
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
if perr, ok := err.(types.PAKError); ok {
var perr types.PAKError
if errors.As(err, &perr) {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
}
return nil, err
}

View File

@@ -1,11 +1,10 @@
package capver
import (
"slices"
"sort"
"strings"
"slices"
xmaps "golang.org/x/exp/maps"
"tailscale.com/tailcfg"
"tailscale.com/util/set"

View File

@@ -1,6 +1,6 @@
package capver
//Generated DO NOT EDIT
// Generated DO NOT EDIT
import "tailscale.com/tailcfg"
@@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.82.5": 115,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
}

View File

@@ -764,13 +764,13 @@ AND auth_key_id NOT IN (
// Drop all indexes first to avoid conflicts
indexesToDrop := []string{
"idx_users_deleted_at",
"idx_provider_identifier",
"idx_provider_identifier",
"idx_name_provider_identifier",
"idx_name_no_provider_identifier",
"idx_api_keys_prefix",
"idx_policies_deleted_at",
}
for _, index := range indexesToDrop {
_ = tx.Exec("DROP INDEX IF EXISTS " + index).Error
}
@@ -927,6 +927,7 @@ AND auth_key_id NOT IN (
}
log.Info().Msg("Schema recreation completed successfully")
return nil
},
Rollback: func(db *gorm.DB) error { return nil },

View File

@@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
Avoid: false,
Nodes: []*tailcfg.DERPNode{
{
Name: fmt.Sprintf("%d", d.cfg.ServerRegionID),
Name: strconv.Itoa(d.cfg.ServerRegionID),
RegionID: d.cfg.ServerRegionID,
HostName: host,
DERPPort: port,

View File

@@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() {
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()))
if err != nil {
log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
continue

View File

@@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode(
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().

View File

@@ -32,7 +32,7 @@ const (
reservedResponseHeaderSize = 4
)
// httpError logs an error and sends an HTTP error response with the given
// httpError logs an error and sends an HTTP error response with the given.
func httpError(w http.ResponseWriter, err error) {
var herr HTTPError
if errors.As(err, &herr) {
@@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest(
resp := &tailcfg.DERPAdmitClientResponse{
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
}
return json.NewEncoder(writer).Encode(resp)
}

View File

@@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {

View File

@@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
}
}
// mockState is a mock implementation that provides the required methods
// mockState is a mock implementation that provides the required methods.
type mockState struct {
polMan policy.PolicyManager
derpMap *tailcfg.DERPMap
@@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
}
}
}
return filtered, nil
}
// Return all peers except the node itself
@@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
filtered = append(filtered, peer)
}
}
return filtered, nil
}
@@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
}
}
}
return filtered, nil
}
return m.nodes, nil
}

View File

@@ -11,7 +11,7 @@ import (
"tailscale.com/types/views"
)
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node types.NodeView, tag string) bool
}

View File

@@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
}
n, err := r.ResponseWriter.Write(b)
r.written += int64(n)
return n, err
}

View File

@@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
n.b = b
go b.doWork()
return n
}
@@ -72,7 +73,7 @@ func (n *Notifier) Close() {
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes a channel and panic recovers if already closed
// safeCloseChannel closes a channel and panic recovers if already closed.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() {
if r := recover(); r != nil {
@@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
}
@@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return false
}
// LikelyConnectedMap returns a thread safe map of connected nodes
// LikelyConnectedMap returns a thread safe map of connected nodes.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
}

View File

@@ -1,17 +1,15 @@
package notifier
import (
"context"
"fmt"
"math/rand"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"
"slices"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) {
defer n.RemoveNode(1, ch)
for _, u := range tt.updates {
n.NotifyAll(context.Background(), u)
n.NotifyAll(t.Context(), u)
}
n.b.flush()
@@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) {
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
@@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) {
for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
if routineID%3 == 0 {
switch routineID % 3 {
case 0:
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
} else if routineID%3 == 1 {
case 1:
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
} else {
default:
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}

View File

@@ -84,11 +84,8 @@ func NewAuthProviderOIDC(
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, RegistrationInfo](
@@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr, _ := vars["registration_id"]
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
@@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
oauth2Token, err := a.getOauth2Token(req.Context(), code, state)
if err != nil {
httpError(writer, err)
return
@@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
@@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
if err != nil {
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
}
return oauth2Token, err
}

View File

@@ -2,9 +2,8 @@ package matcher
import (
"net/netip"
"strings"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
@@ -28,6 +27,7 @@ func (m Match) DebugString() string {
for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}
return sb.String()
}
@@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
for _, rule := range rules {
matches = append(matches, MatchFromFilterRule(rule))
}
return matches
}

View File

@@ -4,7 +4,6 @@ import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"

View File

@@ -5,7 +5,6 @@ import (
"slices"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo"
@@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
// AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy.
// It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
if pm == nil {
return false

View File

@@ -7,9 +7,8 @@ import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) {
}
}
}
func TestReduceRoutes(t *testing.T) {
type args struct {
node *types.Node

View File

@@ -13,9 +13,7 @@ import (
"tailscale.com/types/views"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
var ErrInvalidAction = errors.New("invalid action")
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
@@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
}
@@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

View File

@@ -4,19 +4,17 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"slices"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
"tailscale.com/types/views"
"tailscale.com/util/deephash"
)
type PolicyManager struct {
@@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter, pm.matchers
}
@@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
@@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
@@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// cannot just lookup in the prefix map and have to check
// if there is a "parent" prefix available.
for prefix, approveAddrs := range pm.autoApproveMap {
// Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {

View File

@@ -1,10 +1,10 @@
package v2
import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"

View File

@@ -6,9 +6,9 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
"slices"
"strconv"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
// Check if it's the wildcard port range
if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
return json.Marshal(fmt.Sprintf("%s:*", alias))
return json.Marshal(alias + ":*")
}
// Otherwise, format as "alias:ports"
var ports []string
for _, port := range a.Ports {
if port.First == port.Last {
ports = append(ports, fmt.Sprintf("%d", port.First))
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
} else {
ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
}
@@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
if err := u.Validate(); err != nil {
return err
}
return nil
}
@@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
return buildIPSetMultiErr(&ips, errs)
}
// Group is a special string which is always prefixed with `group:`
// Group is a special string which is always prefixed with `group:`.
type Group string
func (g Group) Validate() error {
@@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
if err := g.Validate(); err != nil {
return err
}
return nil
}
@@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
return buildIPSetMultiErr(&ips, errs)
}
// Tag is a special string which is always prefixed with `tag:`
// Tag is a special string which is always prefixed with `tag:`.
type Tag string
func (t Tag) Validate() error {
@@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
if err := t.Validate(); err != nil {
return err
}
return nil
}
@@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
if err := h.Validate(); err != nil {
return err
}
return nil
}
@@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
}
*p = Prefix(addrPref)
return nil
}
@@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
return err
}
*p = Prefix(pref)
return nil
}
@@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
if err := p.Validate(); err != nil {
return err
}
return nil
}
@@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
}
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
// AutoGroup is a special string which is always prefixed with `autogroup:`.
type AutoGroup string
const (
@@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
if err := ag.Validate(); err != nil {
return err
}
return nil
}
@@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
if err := ve.Alias.Validate(); err != nil {
if err := ve.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
@@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Alias = ptr
return nil
}
@@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
@@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
return ips, multierr.New(append(errs, err)...)
}
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
func unmarshalPointer[T any](
b []byte,
parseFunc func(string) (T, error),
@@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
for i, autoApprover := range autoApprovers {
(*aa)[i] = autoApprover.AutoApprover
}
return nil
}
@@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.AutoApprover = ptr
return nil
}
@@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Owner = ptr
return nil
}
@@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
for i, owner := range owners {
(*o)[i] = owner.Owner
}
return nil
}
@@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
case isGroup(s):
return ptr.To(Group(s)), nil
}
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
- user (containing an "@")
- group (starting with "group:")
@@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
(*g)[group] = usernames
}
return nil
}
@@ -1252,7 +1269,7 @@ type Policy struct {
// We use the default JSON marshalling behavior provided by the Go runtime.
var (
// TODO(kradalby): Add these checks for tagOwners and autoApprovers
// TODO(kradalby): Add these checks for tagOwners and autoApprovers.
autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
@@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSrc, *src) {
@@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHSrc, *src) {
@@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
}
if dst.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHDst, *dst) {
@@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
for _, acl := range p.ACLs {
for _, src := range acl.Sources {
switch src.(type) {
switch src := src.(type) {
case *Host:
h := src.(*Host)
h := src
if !p.Hosts.exist(*h) {
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
}
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
}
for _, src := range ssh.Sources {
switch src.(type) {
switch src := src.(type) {
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
}
}
for _, dst := range ssh.Destinations {
switch dst.(type) {
switch dst := dst.(type) {
case *AutoGroup:
ag := dst.(*AutoGroup)
ag := dst
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
continue
@@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
continue
}
case *Tag:
tagOwner := dst.(*Tag)
tagOwner := dst
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
for _, tagOwners := range p.TagOwners {
for _, tagOwner := range tagOwners {
switch tagOwner.(type) {
switch tagOwner := tagOwner.(type) {
case *Group:
g := tagOwner.(*Group)
g := tagOwner
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
@@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
}
for _, approver := range p.AutoApprovers.ExitNode {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
}
p.validated = true
return nil
}
@@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}
@@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}

View File

@@ -5,13 +5,13 @@ import (
"net/netip"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
@@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
// Marshal the policy to JSON
marshalled, err := json.MarshalIndent(policy, "", " ")
require.NoError(t, err)
// Make sure all expected fields are present in the JSON
jsonString := string(marshalled)
assert.Contains(t, jsonString, "group:example")
@@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
assert.Contains(t, jsonString, "accept")
assert.Contains(t, jsonString, "tcp")
assert.Contains(t, jsonString, "80")
// Unmarshal back to verify round trip
var roundTripped Policy
err = json.Unmarshal(marshalled, &roundTripped)
require.NoError(t, err)
// Compare the original and round-tripped policies
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
cmpopts.EquateEmpty(),
)
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
}
@@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
},
}
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
)
// For round-trip testing, we'll normalize the policies before comparing
for _, tt := range tests {
@@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
}
return // Skip the rest of the test if we expected an error
}
@@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
if err != nil {
t.Fatalf("round-trip unmarshalling: %v", err)
}
// Add EquateEmpty to handle nil vs empty maps/slices
roundTripCmps := append(cmps,
roundTripCmps := append(cmps,
cmpopts.EquateEmpty(),
cmpopts.IgnoreUnexported(Policy{}),
)
@@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
builder.AddPrefix(mp(p))
}
ipSet, _ := builder.IPSet()
return ipSet
}

View File

@@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
expected []tailcfg.PortRange
err string
}{
{"80", []tailcfg.PortRange{{80, 80}}, ""},
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"},

View File

@@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
}
tsaddr.SortPrefixes(routes)
return routes
}

View File

@@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
@@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
@@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey(
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout),
// If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired
log.Debug().
Time("requested_expiry", regReq.Expiry).

View File

@@ -2,6 +2,7 @@ package hscontrol
import (
"context"
"errors"
"fmt"
"net/http"
"os"
@@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
certDomains := tsNode.CertDomains()
if len(certDomains) == 0 {
return fmt.Errorf("no cert domains available for HTTPS")
return errors.New("no cert domains available for HTTPS")
}
base := "https://" + certDomains[0]
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
logf("TailSQL started")
<-ctx.Done()
logf("TailSQL shutting down...")
return tsNode.Close()
}

View File

@@ -62,7 +62,7 @@ func Apple(url string) *elem.Element {
),
elem.Pre(nil,
elem.Code(nil,
elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)),
elem.Text("tailscale login --login-server "+url),
),
),
headerTwo("GUI"),
@@ -143,10 +143,7 @@ func Apple(url string) *elem.Element {
elem.Code(
nil,
elem.Text(
fmt.Sprintf(
`defaults write io.tailscale.ipn.macos ControlURL %s`,
url,
),
"defaults write io.tailscale.ipn.macos ControlURL "+url,
),
),
),
@@ -155,10 +152,7 @@ func Apple(url string) *elem.Element {
elem.Code(
nil,
elem.Text(
fmt.Sprintf(
`defaults write io.tailscale.ipn.macsys ControlURL %s`,
url,
),
"defaults write io.tailscale.ipn.macsys ControlURL "+url,
),
),
),

View File

@@ -1,8 +1,6 @@
package templates
import (
"fmt"
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
)
@@ -31,7 +29,7 @@ func Windows(url string) *elem.Element {
),
elem.Pre(nil,
elem.Code(nil,
elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)),
elem.Text("tailscale login --login-server "+url),
),
),
),

View File

@@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID {
if err != nil {
panic(err)
}
return rid
}

View File

@@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error {
log.Warn().Msg("No config file found, using defaults")
return nil
}
return fmt.Errorf("fatal error reading config file: %w", err)
}
@@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) {
}
if prefix4 == nil && prefix6 == nil {
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
}
allocStr := viper.GetString("prefixes.allocation")
@@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range len(baseDomainParts) {
for i := range baseDomainParts {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
return nil
}

View File

@@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) {
assert.Equal(t, "trace", viper.GetString("log.level"))
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
return nil, nil
},
want: nil,

View File

@@ -28,8 +28,10 @@ var (
ErrNodeUserHasNoName = errors.New("node user has no name")
)
type NodeID uint64
type NodeIDs []NodeID
type (
NodeID uint64
NodeIDs []NodeID
)
func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
@@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
return true
}
}
return false
}
@@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
@@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool {
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag)
}
@@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string {
sb.WriteString(node.DebugString())
sb.WriteString("\n")
}
return sb.String()
}
@@ -590,6 +594,7 @@ func (node Node) DebugString() string {
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
sb.WriteString("\n")
return sb.String()
}
@@ -689,7 +694,7 @@ func (v NodeView) Tags() []string {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (v NodeView) IsTagged() bool {
if !v.Valid() {
return false
@@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
// GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() {
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid")
return "", errors.New("failed to create valid FQDN: node view is invalid")
}
return v.ж.GetFQDN(baseDomain)
}
@@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string {
}
return v.ж.IPsAsString()
}

View File

@@ -2,7 +2,6 @@ package types
import (
"fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip"
"strings"
"testing"
@@ -10,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/key"

View File

@@ -11,7 +11,7 @@ import (
type PAKError string
func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {

View File

@@ -1,6 +1,7 @@
package types
import (
"errors"
"testing"
"time"
@@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(PAKError)
var httpErr PAKError
ok := errors.As(err, &httpErr)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {

View File

@@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string {
// - Remove empty path segments
// - For non-URL identifiers, it joins non-empty segments with a single slash
// - Returns empty string for identifiers with only slashes
// - Normalize URL schemes to lowercase
// - Normalize URL schemes to lowercase.
func CleanIdentifier(identifier string) string {
if identifier == "" {
return identifier
@@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, part)
}
}
if len(cleanParts) == 0 {
u.Path = ""
} else {
@@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string {
}
// Ensure scheme is lowercase
u.Scheme = strings.ToLower(u.Scheme)
return u.String()
}
@@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string {
if len(cleanParts) == 0 {
return ""
}
return strings.Join(cleanParts, "/")
}

View File

@@ -1,4 +1,6 @@
package types
var Version = "dev"
var GitCommitHash = "dev"
var (
Version = "dev"
GitCommitHash = "dev"
)

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net/netip"
"regexp"
"strconv"
"strings"
"unicode"
@@ -21,8 +22,10 @@ const (
LabelHostnameLength = 63
)
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var (
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
var ErrInvalidUserName = errors.New("invalid user name")
@@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
}
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".")
@@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
return dnsname.ToFQDN(prefix + ".ip6.arpa")
}
var fqdns []dnsname.FQDN

View File

@@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
"rowsAffected": rowsAffected,
}
if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) {
if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) {
l.Logger.Error().Err(err).Fields(fields).Msgf("")
return
}

View File

@@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet {
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
})

View File

@@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
}
type TraceroutePath struct {
// Hop is the current jump in the total traceroute.
Hop int
// Hop is the current jump in the total traceroute.
Hop int
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// IP is the IP address of the jump
IP netip.Addr
// IP is the IP address of the jump
IP netip.Addr
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
}
type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// IP is the IP address of the target
IP netip.Addr
// IP is the IP address of the target
IP netip.Addr
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Success indicates if the traceroute was successful.
Success bool
// Success indicates if the traceroute was successful.
Success bool
// Err contains an error if the traceroute was not successful.
Err error
// Err contains an error if the traceroute was not successful.
Err error
}
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 {
@@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
}
// Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])