mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-08 05:04:50 -05:00
integration: replace time.Sleep with assert.EventuallyWithT (#2680)
This commit is contained in:
@@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode(
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
@@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
|
||||
node, changed, err := h.state.HandleNodeFromPreAuthKey(
|
||||
regReq,
|
||||
machineKey,
|
||||
@@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
if perr, ok := err.(types.PAKError); ok {
|
||||
var perr types.PAKError
|
||||
if errors.As(err, &perr) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package capver
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"slices"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/set"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package capver
|
||||
|
||||
//Generated DO NOT EDIT
|
||||
// Generated DO NOT EDIT
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
@@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.82.5": 115,
|
||||
}
|
||||
|
||||
|
||||
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
|
||||
87: "v1.60.0",
|
||||
88: "v1.62.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
87: "v1.60.0",
|
||||
88: "v1.62.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
}
|
||||
|
||||
@@ -764,13 +764,13 @@ AND auth_key_id NOT IN (
|
||||
// Drop all indexes first to avoid conflicts
|
||||
indexesToDrop := []string{
|
||||
"idx_users_deleted_at",
|
||||
"idx_provider_identifier",
|
||||
"idx_provider_identifier",
|
||||
"idx_name_provider_identifier",
|
||||
"idx_name_no_provider_identifier",
|
||||
"idx_api_keys_prefix",
|
||||
"idx_policies_deleted_at",
|
||||
}
|
||||
|
||||
|
||||
for _, index := range indexesToDrop {
|
||||
_ = tx.Exec("DROP INDEX IF EXISTS " + index).Error
|
||||
}
|
||||
@@ -927,6 +927,7 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
log.Info().Msg("Schema recreation completed successfully")
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
|
||||
@@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
Avoid: false,
|
||||
Nodes: []*tailcfg.DERPNode{
|
||||
{
|
||||
Name: fmt.Sprintf("%d", d.cfg.ServerRegionID),
|
||||
Name: strconv.Itoa(d.cfg.ServerRegionID),
|
||||
RegionID: d.cfg.ServerRegionID,
|
||||
HostName: host,
|
||||
DERPPort: port,
|
||||
|
||||
@@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() {
|
||||
|
||||
return struct{}{}, nil
|
||||
}, backoff.WithBackOff(backoff.NewExponentialBackOff()))
|
||||
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
|
||||
continue
|
||||
|
||||
@@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
|
||||
log.Trace().
|
||||
|
||||
@@ -32,7 +32,7 @@ const (
|
||||
reservedResponseHeaderSize = 4
|
||||
)
|
||||
|
||||
// httpError logs an error and sends an HTTP error response with the given
|
||||
// httpError logs an error and sends an HTTP error response with the given.
|
||||
func httpError(w http.ResponseWriter, err error) {
|
||||
var herr HTTPError
|
||||
if errors.As(err, &herr) {
|
||||
@@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest(
|
||||
resp := &tailcfg.DERPAdmitClientResponse{
|
||||
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
|
||||
}
|
||||
|
||||
return json.NewEncoder(writer).Encode(resp)
|
||||
}
|
||||
|
||||
|
||||
@@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as parameter
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes, err := m.state.ListNodes(nodeIDs...)
|
||||
if err != nil {
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// mockState is a mock implementation that provides the required methods
|
||||
// mockState is a mock implementation that provides the required methods.
|
||||
type mockState struct {
|
||||
polMan policy.PolicyManager
|
||||
derpMap *tailcfg.DERPMap
|
||||
@@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
// Return all peers except the node itself
|
||||
@@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
@@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
return m.nodes, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
|
||||
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
|
||||
type NodeCanHaveTagChecker interface {
|
||||
NodeCanHaveTag(node types.NodeView, tag string) bool
|
||||
}
|
||||
|
||||
@@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
|
||||
}
|
||||
n, err := r.ResponseWriter.Write(b)
|
||||
r.written += int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
|
||||
n.b = b
|
||||
|
||||
go b.doWork()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -72,7 +73,7 @@ func (n *Notifier) Close() {
|
||||
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
// safeCloseChannel closes a channel and panic recovers if already closed
|
||||
// safeCloseChannel closes a channel and panic recovers if already closed.
|
||||
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// LikelyConnectedMap returns a thread safe map of connected nodes
|
||||
// LikelyConnectedMap returns a thread safe map of connected nodes.
|
||||
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
||||
return n.connected
|
||||
}
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) {
|
||||
defer n.RemoveNode(1, ch)
|
||||
|
||||
for _, u := range tt.updates {
|
||||
n.NotifyAll(context.Background(), u)
|
||||
n.NotifyAll(t.Context(), u)
|
||||
}
|
||||
|
||||
n.b.flush()
|
||||
@@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) {
|
||||
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
|
||||
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
|
||||
// close a channel that was already closed, which can happen when a node changes
|
||||
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
|
||||
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
|
||||
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
|
||||
// mock config for the notifier
|
||||
cfg := &types.Config{
|
||||
@@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) {
|
||||
for range iterations {
|
||||
// Simulate race by having some goroutines check IsLikelyConnected
|
||||
// while others add/remove the node
|
||||
if routineID%3 == 0 {
|
||||
switch routineID % 3 {
|
||||
case 0:
|
||||
// This goroutine checks connection status
|
||||
isConnected := notifier.IsLikelyConnected(nodeID)
|
||||
if isConnected != true && isConnected != false {
|
||||
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
|
||||
}
|
||||
} else if routineID%3 == 1 {
|
||||
case 1:
|
||||
// This goroutine removes the node
|
||||
notifier.RemoveNode(nodeID, updateChan)
|
||||
} else {
|
||||
default:
|
||||
// This goroutine adds the node back
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
}
|
||||
|
||||
@@ -84,11 +84,8 @@ func NewAuthProviderOIDC(
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf(
|
||||
"%s/oidc/callback",
|
||||
strings.TrimSuffix(serverURL, "/"),
|
||||
),
|
||||
Scopes: cfg.Scope,
|
||||
RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
|
||||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[string, RegistrationInfo](
|
||||
@@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
registrationIdStr, _ := vars["registration_id"]
|
||||
registrationIdStr := vars["registration_id"]
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
@@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
oauth2Token, err := a.getOauth2Token(req.Context(), code, state)
|
||||
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
@@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
|
||||
}
|
||||
|
||||
return oauth2Token, err
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ package matcher
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
@@ -28,6 +27,7 @@ func (m Match) DebugString() string {
|
||||
for _, prefix := range m.dests.Prefixes() {
|
||||
sb.WriteString(" " + prefix.String() + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
|
||||
for _, rule := range rules {
|
||||
matches = append(matches, MatchFromFilterRule(rule))
|
||||
}
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"slices"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/samber/lo"
|
||||
@@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
// AutoApproveRoutes approves any route that can be autoapproved from
|
||||
// the nodes perspective according to the given policy.
|
||||
// It reports true if any routes were approved.
|
||||
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
|
||||
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
|
||||
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
|
||||
@@ -7,9 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReduceRoutes(t *testing.T) {
|
||||
type args struct {
|
||||
node *types.Node
|
||||
|
||||
@@ -13,9 +13,7 @@ import (
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
)
|
||||
var ErrInvalidAction = errors.New("invalid action")
|
||||
|
||||
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
@@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Alias.Resolve(pol, users, nodes)
|
||||
ips, err := dest.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
@@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
for _, pref := range ips.Prefixes() {
|
||||
out = append(out, pref.String())
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -4,19 +4,17 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
type PolicyManager struct {
|
||||
@@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.filter, pm.matchers
|
||||
}
|
||||
|
||||
@@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.users = users
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
@@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
@@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
||||
// cannot just lookup in the prefix map and have to check
|
||||
// if there is a "parent" prefix available.
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
|
||||
// Check if prefix is larger (so containing) and then overlaps
|
||||
// the route to see if the node can approve a subset of an autoapprover
|
||||
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// Check if it's the wildcard port range
|
||||
if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
|
||||
return json.Marshal(fmt.Sprintf("%s:*", alias))
|
||||
return json.Marshal(alias + ":*")
|
||||
}
|
||||
|
||||
// Otherwise, format as "alias:ports"
|
||||
var ports []string
|
||||
for _, port := range a.Ports {
|
||||
if port.First == port.Last {
|
||||
ports = append(ports, fmt.Sprintf("%d", port.First))
|
||||
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
|
||||
} else {
|
||||
ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
|
||||
}
|
||||
@@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
|
||||
if err := u.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
|
||||
return buildIPSetMultiErr(&ips, errs)
|
||||
}
|
||||
|
||||
// Group is a special string which is always prefixed with `group:`
|
||||
// Group is a special string which is always prefixed with `group:`.
|
||||
type Group string
|
||||
|
||||
func (g Group) Validate() error {
|
||||
@@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
|
||||
if err := g.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
|
||||
return buildIPSetMultiErr(&ips, errs)
|
||||
}
|
||||
|
||||
// Tag is a special string which is always prefixed with `tag:`
|
||||
// Tag is a special string which is always prefixed with `tag:`.
|
||||
type Tag string
|
||||
|
||||
func (t Tag) Validate() error {
|
||||
@@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
|
||||
if err := t.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
|
||||
if err := h.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
|
||||
}
|
||||
|
||||
*p = Prefix(addrPref)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
|
||||
return err
|
||||
}
|
||||
*p = Prefix(pref)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
|
||||
}
|
||||
}
|
||||
|
||||
// AutoGroup is a special string which is always prefixed with `autogroup:`
|
||||
// AutoGroup is a special string which is always prefixed with `autogroup:`.
|
||||
type AutoGroup string
|
||||
|
||||
const (
|
||||
@@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
|
||||
if err := ag.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ve.Alias.Validate(); err != nil {
|
||||
if err := ve.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("type %T not supported", vs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
|
||||
return err
|
||||
}
|
||||
ve.Alias = ptr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
|
||||
for i, alias := range aliases {
|
||||
(*a)[i] = alias.Alias
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
|
||||
return ips, multierr.New(append(errs, err)...)
|
||||
}
|
||||
|
||||
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer
|
||||
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
|
||||
func unmarshalPointer[T any](
|
||||
b []byte,
|
||||
parseFunc func(string) (T, error),
|
||||
@@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
|
||||
for i, autoApprover := range autoApprovers {
|
||||
(*aa)[i] = autoApprover.AutoApprover
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
|
||||
return err
|
||||
}
|
||||
ve.AutoApprover = ptr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
|
||||
return err
|
||||
}
|
||||
ve.Owner = ptr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
|
||||
for i, owner := range owners {
|
||||
(*o)[i] = owner.Owner
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
|
||||
case isGroup(s):
|
||||
return ptr.To(Group(s)), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
|
||||
- user (containing an "@")
|
||||
- group (starting with "group:")
|
||||
@@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
||||
|
||||
(*g)[group] = usernames
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1252,7 +1269,7 @@ type Policy struct {
|
||||
// We use the default JSON marshalling behavior provided by the Go runtime.
|
||||
|
||||
var (
|
||||
// TODO(kradalby): Add these checks for tagOwners and autoApprovers
|
||||
// TODO(kradalby): Add these checks for tagOwners and autoApprovers.
|
||||
autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
|
||||
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
|
||||
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
|
||||
@@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
|
||||
}
|
||||
|
||||
if src.Is(AutoGroupInternet) {
|
||||
return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
|
||||
return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
|
||||
}
|
||||
|
||||
if !slices.Contains(autogroupForSrc, *src) {
|
||||
@@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
|
||||
}
|
||||
|
||||
if src.Is(AutoGroupInternet) {
|
||||
return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
|
||||
return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
|
||||
}
|
||||
|
||||
if !slices.Contains(autogroupForSSHSrc, *src) {
|
||||
@@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
|
||||
}
|
||||
|
||||
if dst.Is(AutoGroupInternet) {
|
||||
return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
|
||||
return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
|
||||
}
|
||||
|
||||
if !slices.Contains(autogroupForSSHDst, *dst) {
|
||||
@@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
|
||||
|
||||
for _, acl := range p.ACLs {
|
||||
for _, src := range acl.Sources {
|
||||
switch src.(type) {
|
||||
switch src := src.(type) {
|
||||
case *Host:
|
||||
h := src.(*Host)
|
||||
h := src
|
||||
if !p.Hosts.exist(*h) {
|
||||
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
|
||||
}
|
||||
case *AutoGroup:
|
||||
ag := src.(*AutoGroup)
|
||||
ag := src
|
||||
|
||||
if err := validateAutogroupSupported(ag); err != nil {
|
||||
errs = append(errs, err)
|
||||
@@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
|
||||
continue
|
||||
}
|
||||
case *Group:
|
||||
g := src.(*Group)
|
||||
g := src
|
||||
if err := p.Groups.Contains(g); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
case *Tag:
|
||||
tagOwner := src.(*Tag)
|
||||
tagOwner := src
|
||||
if err := p.TagOwners.Contains(tagOwner); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
@@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
|
||||
}
|
||||
|
||||
for _, src := range ssh.Sources {
|
||||
switch src.(type) {
|
||||
switch src := src.(type) {
|
||||
case *AutoGroup:
|
||||
ag := src.(*AutoGroup)
|
||||
ag := src
|
||||
|
||||
if err := validateAutogroupSupported(ag); err != nil {
|
||||
errs = append(errs, err)
|
||||
@@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
|
||||
continue
|
||||
}
|
||||
case *Group:
|
||||
g := src.(*Group)
|
||||
g := src
|
||||
if err := p.Groups.Contains(g); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
case *Tag:
|
||||
tagOwner := src.(*Tag)
|
||||
tagOwner := src
|
||||
if err := p.TagOwners.Contains(tagOwner); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, dst := range ssh.Destinations {
|
||||
switch dst.(type) {
|
||||
switch dst := dst.(type) {
|
||||
case *AutoGroup:
|
||||
ag := dst.(*AutoGroup)
|
||||
ag := dst
|
||||
if err := validateAutogroupSupported(ag); err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
@@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
|
||||
continue
|
||||
}
|
||||
case *Tag:
|
||||
tagOwner := dst.(*Tag)
|
||||
tagOwner := dst
|
||||
if err := p.TagOwners.Contains(tagOwner); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
@@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
|
||||
|
||||
for _, tagOwners := range p.TagOwners {
|
||||
for _, tagOwner := range tagOwners {
|
||||
switch tagOwner.(type) {
|
||||
switch tagOwner := tagOwner.(type) {
|
||||
case *Group:
|
||||
g := tagOwner.(*Group)
|
||||
g := tagOwner
|
||||
if err := p.Groups.Contains(g); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
@@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
|
||||
|
||||
for _, approvers := range p.AutoApprovers.Routes {
|
||||
for _, approver := range approvers {
|
||||
switch approver.(type) {
|
||||
switch approver := approver.(type) {
|
||||
case *Group:
|
||||
g := approver.(*Group)
|
||||
g := approver
|
||||
if err := p.Groups.Contains(g); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
case *Tag:
|
||||
tagOwner := approver.(*Tag)
|
||||
tagOwner := approver
|
||||
if err := p.TagOwners.Contains(tagOwner); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
@@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
|
||||
}
|
||||
|
||||
for _, approver := range p.AutoApprovers.ExitNode {
|
||||
switch approver.(type) {
|
||||
switch approver := approver.(type) {
|
||||
case *Group:
|
||||
g := approver.(*Group)
|
||||
g := approver
|
||||
if err := p.Groups.Contains(g); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
case *Tag:
|
||||
tagOwner := approver.(*Tag)
|
||||
tagOwner := approver
|
||||
if err := p.TagOwners.Contains(tagOwner); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
@@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
|
||||
}
|
||||
|
||||
p.validated = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/prometheus/common/model"
|
||||
"time"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go4.org/netipx"
|
||||
@@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
|
||||
// Marshal the policy to JSON
|
||||
marshalled, err := json.MarshalIndent(policy, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Make sure all expected fields are present in the JSON
|
||||
jsonString := string(marshalled)
|
||||
assert.Contains(t, jsonString, "group:example")
|
||||
@@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
|
||||
assert.Contains(t, jsonString, "accept")
|
||||
assert.Contains(t, jsonString, "tcp")
|
||||
assert.Contains(t, jsonString, "80")
|
||||
|
||||
|
||||
// Unmarshal back to verify round trip
|
||||
var roundTripped Policy
|
||||
err = json.Unmarshal(marshalled, &roundTripped)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Compare the original and round-tripped policies
|
||||
cmps := append(util.Comparers,
|
||||
cmps := append(util.Comparers,
|
||||
cmp.Comparer(func(x, y Prefix) bool {
|
||||
return x == y
|
||||
}),
|
||||
cmpopts.IgnoreUnexported(Policy{}),
|
||||
cmpopts.EquateEmpty(),
|
||||
)
|
||||
|
||||
|
||||
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
|
||||
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
|
||||
}
|
||||
@@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
cmps := append(util.Comparers,
|
||||
cmps := append(util.Comparers,
|
||||
cmp.Comparer(func(x, y Prefix) bool {
|
||||
return x == y
|
||||
}),
|
||||
cmpopts.IgnoreUnexported(Policy{}),
|
||||
)
|
||||
|
||||
|
||||
// For round-trip testing, we'll normalize the policies before comparing
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
} else if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
|
||||
}
|
||||
|
||||
return // Skip the rest of the test if we expected an error
|
||||
}
|
||||
|
||||
@@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("round-trip unmarshalling: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Add EquateEmpty to handle nil vs empty maps/slices
|
||||
roundTripCmps := append(cmps,
|
||||
roundTripCmps := append(cmps,
|
||||
cmpopts.EquateEmpty(),
|
||||
cmpopts.IgnoreUnexported(Policy{}),
|
||||
)
|
||||
@@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
|
||||
builder.AddPrefix(mp(p))
|
||||
}
|
||||
ipSet, _ := builder.IPSet()
|
||||
|
||||
return ipSet
|
||||
}
|
||||
|
||||
|
||||
@@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{80, 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
|
||||
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
|
||||
@@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
|
||||
}
|
||||
|
||||
tsaddr.SortPrefixes(routes)
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
|
||||
@@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
|
||||
if err != nil {
|
||||
return types.NodeView{}, err
|
||||
}
|
||||
|
||||
return node.View(), nil
|
||||
}
|
||||
|
||||
@@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er
|
||||
if err != nil {
|
||||
return types.NodeView{}, err
|
||||
}
|
||||
|
||||
return node.View(), nil
|
||||
}
|
||||
|
||||
@@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
} else if !regReq.Expiry.IsZero() {
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// don't set expiry so the node won't be considered expired
|
||||
log.Debug().
|
||||
Time("requested_expiry", regReq.Expiry).
|
||||
|
||||
@@ -2,6 +2,7 @@ package hscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
||||
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
|
||||
certDomains := tsNode.CertDomains()
|
||||
if len(certDomains) == 0 {
|
||||
return fmt.Errorf("no cert domains available for HTTPS")
|
||||
return errors.New("no cert domains available for HTTPS")
|
||||
}
|
||||
base := "https://" + certDomains[0]
|
||||
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
||||
logf("TailSQL started")
|
||||
<-ctx.Done()
|
||||
logf("TailSQL shutting down...")
|
||||
|
||||
return tsNode.Close()
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func Apple(url string) *elem.Element {
|
||||
),
|
||||
elem.Pre(nil,
|
||||
elem.Code(nil,
|
||||
elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)),
|
||||
elem.Text("tailscale login --login-server "+url),
|
||||
),
|
||||
),
|
||||
headerTwo("GUI"),
|
||||
@@ -143,10 +143,7 @@ func Apple(url string) *elem.Element {
|
||||
elem.Code(
|
||||
nil,
|
||||
elem.Text(
|
||||
fmt.Sprintf(
|
||||
`defaults write io.tailscale.ipn.macos ControlURL %s`,
|
||||
url,
|
||||
),
|
||||
"defaults write io.tailscale.ipn.macos ControlURL "+url,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -155,10 +152,7 @@ func Apple(url string) *elem.Element {
|
||||
elem.Code(
|
||||
nil,
|
||||
elem.Text(
|
||||
fmt.Sprintf(
|
||||
`defaults write io.tailscale.ipn.macsys ControlURL %s`,
|
||||
url,
|
||||
),
|
||||
"defaults write io.tailscale.ipn.macsys ControlURL "+url,
|
||||
),
|
||||
),
|
||||
),
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package templates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/chasefleming/elem-go"
|
||||
"github.com/chasefleming/elem-go/attrs"
|
||||
)
|
||||
@@ -31,7 +29,7 @@ func Windows(url string) *elem.Element {
|
||||
),
|
||||
elem.Pre(nil,
|
||||
elem.Code(nil,
|
||||
elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)),
|
||||
elem.Text("tailscale login --login-server "+url),
|
||||
),
|
||||
),
|
||||
),
|
||||
|
||||
@@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return rid
|
||||
}
|
||||
|
||||
|
||||
@@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error {
|
||||
log.Warn().Msg("No config file found, using defaults")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("fatal error reading config file: %w", err)
|
||||
}
|
||||
|
||||
@@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) {
|
||||
}
|
||||
|
||||
if prefix4 == nil && prefix6 == nil {
|
||||
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
}
|
||||
|
||||
allocStr := viper.GetString("prefixes.allocation")
|
||||
@@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
|
||||
|
||||
s := len(serverDomainParts)
|
||||
b := len(baseDomainParts)
|
||||
for i := range len(baseDomainParts) {
|
||||
for i := range baseDomainParts {
|
||||
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||
assert.Equal(t, "trace", viper.GetString("log.level"))
|
||||
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
|
||||
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
want: nil,
|
||||
|
||||
@@ -28,8 +28,10 @@ var (
|
||||
ErrNodeUserHasNoName = errors.New("node user has no name")
|
||||
)
|
||||
|
||||
type NodeID uint64
|
||||
type NodeIDs []NodeID
|
||||
type (
|
||||
NodeID uint64
|
||||
NodeIDs []NodeID
|
||||
)
|
||||
|
||||
func (n NodeIDs) Len() int { return len(n) }
|
||||
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
|
||||
@@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
|
||||
// and therefore should not be treated as a
|
||||
// user owned device.
|
||||
// Currently, this function only handles tags set
|
||||
// via CLI ("forced tags" and preauthkeys)
|
||||
// via CLI ("forced tags" and preauthkeys).
|
||||
func (node *Node) IsTagged() bool {
|
||||
if len(node.ForcedTags) > 0 {
|
||||
return true
|
||||
@@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool {
|
||||
|
||||
// HasTag reports if a node has a given tag.
|
||||
// Currently, this function only handles tags set
|
||||
// via CLI ("forced tags" and preauthkeys)
|
||||
// via CLI ("forced tags" and preauthkeys).
|
||||
func (node *Node) HasTag(tag string) bool {
|
||||
return slices.Contains(node.Tags(), tag)
|
||||
}
|
||||
@@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string {
|
||||
sb.WriteString(node.DebugString())
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -590,6 +594,7 @@ func (node Node) DebugString() string {
|
||||
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
|
||||
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
|
||||
sb.WriteString("\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -689,7 +694,7 @@ func (v NodeView) Tags() []string {
|
||||
// and therefore should not be treated as a
|
||||
// user owned device.
|
||||
// Currently, this function only handles tags set
|
||||
// via CLI ("forced tags" and preauthkeys)
|
||||
// via CLI ("forced tags" and preauthkeys).
|
||||
func (v NodeView) IsTagged() bool {
|
||||
if !v.Valid() {
|
||||
return false
|
||||
@@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
|
||||
// GetFQDN returns the fully qualified domain name for the node.
|
||||
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
|
||||
if !v.Valid() {
|
||||
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid")
|
||||
return "", errors.New("failed to create valid FQDN: node view is invalid")
|
||||
}
|
||||
return v.ж.GetFQDN(baseDomain)
|
||||
}
|
||||
@@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string {
|
||||
}
|
||||
return v.ж.IPsAsString()
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
type PAKError string
|
||||
|
||||
func (e PAKError) Error() string { return string(e) }
|
||||
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) }
|
||||
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
|
||||
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||
type PreAuthKey struct {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
} else {
|
||||
httpErr, ok := err.(PAKError)
|
||||
var httpErr PAKError
|
||||
ok := errors.As(err, &httpErr)
|
||||
if !ok {
|
||||
t.Errorf("expected HTTPError but got %T", err)
|
||||
} else {
|
||||
|
||||
@@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string {
|
||||
// - Remove empty path segments
|
||||
// - For non-URL identifiers, it joins non-empty segments with a single slash
|
||||
// - Returns empty string for identifiers with only slashes
|
||||
// - Normalize URL schemes to lowercase
|
||||
// - Normalize URL schemes to lowercase.
|
||||
func CleanIdentifier(identifier string) string {
|
||||
if identifier == "" {
|
||||
return identifier
|
||||
@@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string {
|
||||
cleanParts = append(cleanParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if len(cleanParts) == 0 {
|
||||
u.Path = ""
|
||||
} else {
|
||||
@@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string {
|
||||
}
|
||||
// Ensure scheme is lowercase
|
||||
u.Scheme = strings.ToLower(u.Scheme)
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
@@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string {
|
||||
if len(cleanParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.Join(cleanParts, "/")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
package types
|
||||
|
||||
var Version = "dev"
|
||||
var GitCommitHash = "dev"
|
||||
var (
|
||||
Version = "dev"
|
||||
GitCommitHash = "dev"
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
@@ -21,8 +22,10 @@ const (
|
||||
LabelHostnameLength = 63
|
||||
)
|
||||
|
||||
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
var (
|
||||
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
)
|
||||
|
||||
var ErrInvalidUserName = errors.New("invalid user name")
|
||||
|
||||
@@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
|
||||
rdnsSlice := []string{}
|
||||
for i := lastOctet - 1; i >= 0; i-- {
|
||||
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
|
||||
rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
|
||||
}
|
||||
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
|
||||
rdnsBase := strings.Join(rdnsSlice, ".")
|
||||
@@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
|
||||
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
|
||||
|
||||
return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
|
||||
return dnsname.ToFQDN(prefix + ".ip6.arpa")
|
||||
}
|
||||
|
||||
var fqdns []dnsname.FQDN
|
||||
|
||||
@@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
|
||||
"rowsAffected": rowsAffected,
|
||||
}
|
||||
|
||||
if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) {
|
||||
if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) {
|
||||
l.Logger.Error().Err(err).Fields(fields).Msgf("")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet {
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
|
||||
|
||||
theInternetSet, _ := internetBuilder.IPSet()
|
||||
|
||||
return theInternetSet
|
||||
})
|
||||
|
||||
@@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
|
||||
}
|
||||
|
||||
type TraceroutePath struct {
|
||||
// Hop is the current jump in the total traceroute.
|
||||
Hop int
|
||||
// Hop is the current jump in the total traceroute.
|
||||
Hop int
|
||||
|
||||
// Hostname is the resolved hostname or IP address identifying the jump
|
||||
Hostname string
|
||||
// Hostname is the resolved hostname or IP address identifying the jump
|
||||
Hostname string
|
||||
|
||||
// IP is the IP address of the jump
|
||||
IP netip.Addr
|
||||
// IP is the IP address of the jump
|
||||
IP netip.Addr
|
||||
|
||||
// Latencies is a list of the latencies for this jump
|
||||
Latencies []time.Duration
|
||||
// Latencies is a list of the latencies for this jump
|
||||
Latencies []time.Duration
|
||||
}
|
||||
|
||||
type Traceroute struct {
|
||||
// Hostname is the resolved hostname or IP address identifying the target
|
||||
Hostname string
|
||||
// Hostname is the resolved hostname or IP address identifying the target
|
||||
Hostname string
|
||||
|
||||
// IP is the IP address of the target
|
||||
IP netip.Addr
|
||||
// IP is the IP address of the target
|
||||
IP netip.Addr
|
||||
|
||||
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
|
||||
Route []TraceroutePath
|
||||
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
|
||||
Route []TraceroutePath
|
||||
|
||||
// Success indicates if the traceroute was successful.
|
||||
Success bool
|
||||
// Success indicates if the traceroute was successful.
|
||||
Success bool
|
||||
|
||||
// Err contains an error if the traceroute was not successful.
|
||||
Err error
|
||||
// Err contains an error if the traceroute was not successful.
|
||||
Err error
|
||||
}
|
||||
|
||||
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
|
||||
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
|
||||
func ParseTraceroute(output string) (Traceroute, error) {
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
if len(lines) < 1 {
|
||||
@@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
||||
}
|
||||
|
||||
// Parse each hop line
|
||||
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
|
||||
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
|
||||
|
||||
for i := 1; i < len(lines); i++ {
|
||||
matches := hopRegex.FindStringSubmatch(lines[i])
|
||||
|
||||
Reference in New Issue
Block a user