From d604663121b7be049d73a6cf47e2e8c0f36456bd Mon Sep 17 00:00:00 2001
From: seiuneko <seiunekosl@gmail.com>
Date: Tue, 3 Dec 2024 18:54:09 +0800
Subject: [PATCH] test: fix TestDERPVerifyEndpoint

- `tailscale debug derp` use random node private key
---
 integration/derp_verify_endpoint_test.go | 87 +++++++++++++++---------
 integration/tailscale.go                 |  2 +
 integration/tsic/tsic.go                 | 29 ++++++++
 3 files changed, 86 insertions(+), 32 deletions(-)

diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go
index adad5b6a..c27e6d0b 100644
--- a/integration/derp_verify_endpoint_test.go
+++ b/integration/derp_verify_endpoint_test.go
@@ -1,11 +1,10 @@
 package integration
 
 import (
-	"encoding/json"
+	"context"
 	"fmt"
 	"net"
 	"strconv"
-	"strings"
 	"testing"
 
 	"github.com/juanfont/headscale/hscontrol/util"
@@ -13,7 +12,11 @@ import (
 	"github.com/juanfont/headscale/integration/hsic"
 	"github.com/juanfont/headscale/integration/integrationutil"
 	"github.com/juanfont/headscale/integration/tsic"
+	"tailscale.com/derp"
+	"tailscale.com/derp/derphttp"
+	"tailscale.com/net/netmon"
 	"tailscale.com/tailcfg"
+	"tailscale.com/types/key"
 )
 
 func TestDERPVerifyEndpoint(t *testing.T) {
@@ -45,23 +48,24 @@ func TestDERPVerifyEndpoint(t *testing.T) {
 	)
 	assertNoErr(t, err)
 
+	derpRegion := tailcfg.DERPRegion{
+		RegionCode: "test-derpverify",
+		RegionName: "TestDerpVerify",
+		Nodes: []*tailcfg.DERPNode{
+			{
+				Name:             "TestDerpVerify",
+				RegionID:         900,
+				HostName:         derper.GetHostname(),
+				STUNPort:         derper.GetSTUNPort(),
+				STUNOnly:         false,
+				DERPPort:         derper.GetDERPPort(),
+				InsecureForTests: true,
+			},
+		},
+	}
 	derpMap := tailcfg.DERPMap{
 		Regions: map[int]*tailcfg.DERPRegion{
-			900: {
-				RegionID:   900,
-				RegionCode: "test-derpverify",
-				RegionName: "TestDerpVerify",
-				Nodes: []*tailcfg.DERPNode{
-					{
-						Name:     "TestDerpVerify",
-						RegionID: 900,
-						HostName: derper.GetHostname(),
-						STUNPort: derper.GetSTUNPort(),
-						STUNOnly: false,
-						DERPPort: derper.GetDERPPort(),
-					},
-				},
-			},
+			900: &derpRegion,
 		},
 	}
 
@@ -76,21 +80,40 @@ func TestDERPVerifyEndpoint(t *testing.T) {
 	allClients, err := scenario.ListTailscaleClients()
 	assertNoErrListClients(t, err)
 
-	for _, client := range allClients {
-		report, err := client.DebugDERPRegion("test-derpverify")
-		assertNoErr(t, err)
-		successful := false
-		for _, line := range report.Info {
-			if strings.Contains(line, "Successfully established a DERP connection with node") {
-				successful = true
+	fakeKey := key.NewNode()
+	DERPVerify(t, fakeKey, derpRegion, false)
 
-				break
-			}
-		}
-		if !successful {
-			stJSON, err := json.Marshal(report)
-			assertNoErr(t, err)
-			t.Errorf("Client %s could not establish a DERP connection: %s", client.Hostname(), string(stJSON))
-		}
+	for _, client := range allClients {
+		nodeKey, err := client.GetNodePrivateKey()
+		assertNoErr(t, err)
+		DERPVerify(t, *nodeKey, derpRegion, true)
+	}
+}
+
+func DERPVerify(
+	t *testing.T,
+	nodeKey key.NodePrivate,
+	region tailcfg.DERPRegion,
+	expectSuccess bool,
+) {
+	IntegrationSkip(t)
+
+	c := derphttp.NewRegionClient(nodeKey, t.Logf, netmon.NewStatic(), func() *tailcfg.DERPRegion {
+		return &region
+	})
+	var result error
+	if err := c.Connect(context.Background()); err != nil {
+		result = fmt.Errorf("client Connect: %w", err)
+	}
+	if m, err := c.Recv(); err != nil {
+		result = fmt.Errorf("client first Recv: %w", err)
+	} else if v, ok := m.(derp.ServerInfoMessage); !ok {
+		result = fmt.Errorf("client first Recv was unexpected type %T", v)
+	}
+
+	if expectSuccess && result != nil {
+		t.Fatalf("DERP verify failed unexpectedly for client %s. Expected success but got error: %v", nodeKey.Public(), result)
+	} else if !expectSuccess && result == nil {
+		t.Fatalf("DERP verify succeeded unexpectedly for client %s. Expected failure but it succeeded.", nodeKey.Public())
 	}
 }
diff --git a/integration/tailscale.go b/integration/tailscale.go
index da9b8754..de8fdb16 100644
--- a/integration/tailscale.go
+++ b/integration/tailscale.go
@@ -9,6 +9,7 @@ import (
 	"github.com/juanfont/headscale/integration/tsic"
 	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/netcheck"
+	"tailscale.com/types/key"
 	"tailscale.com/types/netmap"
 )
 
@@ -31,6 +32,7 @@ type TailscaleClient interface {
 	Status(...bool) (*ipnstate.Status, error)
 	Netmap() (*netmap.NetworkMap, error)
 	DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error)
+	GetNodePrivateKey() (*key.NodePrivate, error)
 	Netcheck() (*netcheck.Report, error)
 	WaitForNeedsLogin() error
 	WaitForRunning() error
diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go
index e63a7b6e..5781cd37 100644
--- a/integration/tsic/tsic.go
+++ b/integration/tsic/tsic.go
@@ -24,7 +24,10 @@ import (
 	"github.com/ory/dockertest/v3/docker"
 	"tailscale.com/ipn"
 	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/ipn/store/mem"
 	"tailscale.com/net/netcheck"
+	"tailscale.com/paths"
+	"tailscale.com/types/key"
 	"tailscale.com/types/netmap"
 )
 
@@ -1149,3 +1152,29 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
 
 	return out.Bytes(), nil
 }
+
+func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
+	state, err := t.ReadFile(paths.DefaultTailscaledStateFile())
+	if err != nil {
+		return nil, fmt.Errorf("failed to read state file: %w", err)
+	}
+	store := &mem.Store{}
+	if err = store.LoadFromJSON(state); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal state file: %w", err)
+	}
+
+	currentProfileKey, err := store.ReadState(ipn.CurrentProfileStateKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read current profile state key: %w", err)
+	}
+	currentProfile, err := store.ReadState(ipn.StateKey(currentProfileKey))
+	if err != nil {
+		return nil, fmt.Errorf("failed to read current profile state: %w", err)
+	}
+
+	p := &ipn.Prefs{}
+	if err = json.Unmarshal(currentProfile, &p); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err)
+	}
+	return &p.Persist.PrivateNodeKey, nil
+}