mirror of
https://github.com/juanfont/headscale.git
synced 2025-11-09 21:49:39 -05:00
bunch of qol (#2748)
This commit is contained in:
@@ -24,6 +24,7 @@ type Batcher interface {
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
|
||||
}
|
||||
|
||||
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
|
||||
|
||||
@@ -489,3 +489,7 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
nc.updateCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
return b.mapper.debugMapResponses()
|
||||
}
|
||||
|
||||
@@ -237,7 +237,6 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
@@ -247,12 +246,16 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
|
||||
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.nodeID)
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeDebugMapResponse(b.resp, node)
|
||||
}
|
||||
|
||||
return b.resp, nil
|
||||
|
||||
@@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) {
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
|
||||
|
||||
// Test basic builder creation
|
||||
assert.NotNil(t, builder)
|
||||
assert.Equal(t, nodeID, builder.nodeID)
|
||||
@@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(42)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer)
|
||||
|
||||
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) {
|
||||
ServerURL: "https://test.example.com",
|
||||
BaseDomain: domain,
|
||||
}
|
||||
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDomain()
|
||||
|
||||
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
|
||||
value, isSet := builder.resp.CollectServices.Get()
|
||||
assert.True(t, isSet)
|
||||
assert.False(t, value)
|
||||
@@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||
|
||||
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
logTailEnabled bool
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "LogTail enabled",
|
||||
name: "LogTail enabled",
|
||||
logTailEnabled: true,
|
||||
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||
},
|
||||
{
|
||||
name: "LogTail disabled",
|
||||
name: "LogTail disabled",
|
||||
logTailEnabled: false,
|
||||
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &types.Config{
|
||||
@@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugConfig()
|
||||
|
||||
|
||||
require.NotNil(t, builder.resp.Debug)
|
||||
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
|
||||
assert.False(t, builder.hasErrors())
|
||||
@@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
changes := []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 123,
|
||||
NodeID: 123,
|
||||
DERPRegion: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 456,
|
||||
NodeID: 456,
|
||||
DERPRegion: 2,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changes)
|
||||
|
||||
|
||||
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(123)
|
||||
removedID2 := types.NodeID(456)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1, removedID2)
|
||||
|
||||
|
||||
expected := []tailcfg.NodeID{
|
||||
removedID1.NodeID(),
|
||||
removedID2.NodeID(),
|
||||
@@ -197,25 +197,25 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
// Simulate an error in the builder
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
|
||||
|
||||
// All subsequent calls should continue to work and accumulate errors
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 1)
|
||||
assert.Equal(t, assert.AnError, result.errs[0])
|
||||
|
||||
|
||||
// Build should return the error
|
||||
data, err := result.Build("none")
|
||||
data, err := result.Build()
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(99)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
@@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(100)
|
||||
removedID2 := types.NodeID(200)
|
||||
|
||||
|
||||
// Test calling WithPeersRemoved multiple times
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1).
|
||||
WithPeersRemoved(removedID2)
|
||||
|
||||
|
||||
// Second call should overwrite the first
|
||||
expected := []tailcfg.NodeID{removedID2.NodeID()}
|
||||
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||
@@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch([]*tailcfg.PeerChange{})
|
||||
|
||||
|
||||
assert.Empty(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(nil)
|
||||
|
||||
|
||||
assert.Nil(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
@@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
||||
// Create a builder and add multiple errors
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(nil) // This should be ignored
|
||||
|
||||
|
||||
// All subsequent calls should continue to work
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 2) // nil error should be ignored
|
||||
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build("none")
|
||||
data, err := result.Build()
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -154,7 +155,7 @@ func (m *mapper) fullMapResponse(
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peers).
|
||||
Build(messages...)
|
||||
Build()
|
||||
}
|
||||
|
||||
func (m *mapper) derpMapResponse(
|
||||
@@ -207,36 +208,15 @@ func (m *mapper) peerRemovedResponse(
|
||||
|
||||
func writeDebugMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
nodeID types.NodeID,
|
||||
messages ...string,
|
||||
node *types.Node,
|
||||
) {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
body, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID))
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -246,7 +226,7 @@ func writeDebugMapResponse(
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s.json", now, responseType),
|
||||
fmt.Sprintf("%s.json", now),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
@@ -279,3 +259,62 @@ func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
type routeFilterFunc func(id types.NodeID) []netip.Prefix
|
||||
|
||||
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
if debugDumpMapResponsePath == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
nodes, err := os.ReadDir(debugDumpMapResponsePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[types.NodeID][]tailcfg.MapResponse)
|
||||
for _, node := range nodes {
|
||||
if !node.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(nodeIDu)
|
||||
|
||||
files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
slices.SortStableFunc(files, func(a, b fs.DirEntry) int {
|
||||
return strings.Compare(a.Name(), b.Name())
|
||||
})
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading file %s", file.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
var resp tailcfg.MapResponse
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
result[nodeID] = append(result[nodeID], resp)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user