diff --git a/Dockerfile.debug b/Dockerfile.debug index e5066060..cf55bd74 100644 --- a/Dockerfile.debug +++ b/Dockerfile.debug @@ -8,7 +8,7 @@ ENV GOPATH /go WORKDIR /go/src/headscale RUN apt-get update \ - && apt-get install --no-install-recommends --yes less jq \ + && apt-get install --no-install-recommends --yes less jq sqlite3 \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean RUN mkdir -p /var/run/headscale diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index e5a47953..44faeb91 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -20,9 +20,14 @@ import ( "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" "tailscale.com/util/set" ) +func init() { + schema.RegisterSerializer("text", TextSerialiser{}) +} + var errDatabaseNotSupported = errors.New("database type not supported") // KV is a key-value store in a psql table. For future use... @@ -33,7 +38,8 @@ type KV struct { } type HSDatabase struct { - DB *gorm.DB + DB *gorm.DB + cfg *types.DatabaseConfig baseDomain string } @@ -191,7 +197,7 @@ func NewHeadscaleDatabase( type NodeAux struct { ID uint64 - EnabledRoutes types.IPPrefixes + EnabledRoutes []netip.Prefix `gorm:"serializer:json"` } nodesAux := []NodeAux{} @@ -214,7 +220,7 @@ func NewHeadscaleDatabase( } err = tx.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). + Where("node_id = ? AND prefix = ?", node.ID, prefix). First(&types.Route{}). Error if err == nil { @@ -229,7 +235,7 @@ func NewHeadscaleDatabase( NodeID: node.ID, Advertised: true, Enabled: true, - Prefix: types.IPPrefix(prefix), + Prefix: prefix, } if err := tx.Create(&route).Error; err != nil { log.Error().Err(err).Msg("Error creating route") @@ -476,7 +482,8 @@ func NewHeadscaleDatabase( } db := HSDatabase{ - DB: dbConn, + DB: dbConn, + cfg: &cfg, baseDomain: baseDomain, } @@ -676,6 +683,10 @@ func (hsdb *HSDatabase) Close() error { return err } + if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog { + db.Exec("VACUUM") + } + return db.Close() } diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 157ede8b..d92a73e5 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -13,13 +13,14 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gorm.io/gorm" ) func TestMigrations(t *testing.T) { - ipp := func(p string) types.IPPrefix { - return types.IPPrefix(netip.MustParsePrefix(p)) + ipp := func(p string) netip.Prefix { + return netip.MustParsePrefix(p) } r := func(id uint64, p string, a, e, i bool) types.Route { return types.Route{ @@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) { r(31, "::/0", true, false, false), r(32, "192.168.0.24/32", true, true, true), } - if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool { - return x == y - })); diff != "" { + if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" { t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) } }, @@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) { r(13, "::/0", true, true, false), r(13, "10.18.80.2/32", true, true, true), } - if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool { - return x == y - })); diff != "" { + if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" { t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) } }, @@ -172,6 +169,29 @@ func TestMigrations(t *testing.T) { } }, }, + { + dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite", + wantFunc: func(t *testing.T, h *HSDatabase) { + nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) + assert.NoError(t, err) + + for _, node := range nodes { + assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey") + assert.Contains(t, node.MachineKey.String(), "mkey:") + assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey") + assert.Contains(t, node.NodeKey.String(), "nodekey:") + assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey") + assert.Contains(t, node.DiscoKey.String(), "discokey:") + assert.NotNil(t, node.IPv4) + assert.NotNil(t, node.IPv4) + assert.Len(t, node.Endpoints, 1) + assert.NotNil(t, node.Hostinfo) + assert.NotNil(t, node.MachineKey) + } + }, + }, } for _, tt := range tests { diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index b56d2d74..b9a75823 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -1,7 +1,6 @@ package db import ( - "database/sql" "fmt" "net/netip" "strings" @@ -294,15 +293,7 @@ func TestBackfillIPAddresses(t *testing.T) { v4 := fmt.Sprintf("100.64.0.%d", i) v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i) return &types.Node{ - IPv4DatabaseField: sql.NullString{ - Valid: true, - String: v4, - }, IPv4: nap(v4), - IPv6DatabaseField: sql.NullString{ - Valid: true, - String: v6, - }, IPv6: nap(v6), } } @@ -334,15 +325,7 @@ func TestBackfillIPAddresses(t *testing.T) { want: types.Nodes{ &types.Node{ - IPv4DatabaseField: sql.NullString{ - Valid: true, - String: "100.64.0.1", - }, IPv4: nap("100.64.0.1"), - IPv6DatabaseField: sql.NullString{ - Valid: true, - String: "fd7a:115c:a1e0::1", - }, IPv6: nap("fd7a:115c:a1e0::1"), }, }, @@ -367,15 +350,7 @@ func TestBackfillIPAddresses(t *testing.T) { want: types.Nodes{ &types.Node{ - IPv4DatabaseField: sql.NullString{ - Valid: true, - String: "100.64.0.1", - }, IPv4: nap("100.64.0.1"), - IPv6DatabaseField: sql.NullString{ - Valid: true, - String: "fd7a:115c:a1e0::1", - }, IPv6: nap("fd7a:115c:a1e0::1"), }, }, @@ -400,10 +375,6 @@ func TestBackfillIPAddresses(t *testing.T) { want: types.Nodes{ &types.Node{ - IPv4DatabaseField: sql.NullString{ - Valid: true, - String: "100.64.0.1", - }, IPv4: nap("100.64.0.1"), }, }, @@ -428,10 +399,6 @@ func TestBackfillIPAddresses(t *testing.T) { want: types.Nodes{ &types.Node{ - IPv6DatabaseField: sql.NullString{ - Valid: true, - String: "fd7a:115c:a1e0::1", - }, IPv6: nap("fd7a:115c:a1e0::1"), }, }, @@ -477,13 +444,9 @@ func TestBackfillIPAddresses(t *testing.T) { comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{}, "ID", - "MachineKeyDatabaseField", - "NodeKeyDatabaseField", - "DiscoKeyDatabaseField", "User", "UserID", "Endpoints", - "HostinfoDatabaseField", "Hostinfo", "Routes", "CreatedAt", diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 639354b3..a4cd9e0b 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -1,6 +1,7 @@ package db import ( + "encoding/json" "errors" "fmt" "net/netip" @@ -207,21 +208,26 @@ func SetTags( ) error { if len(tags) == 0 { // if no tags are provided, we remove all forced tags - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil { return fmt.Errorf("failed to remove tags for node in the database: %w", err) } return nil } - var newTags types.StringList + var newTags []string for _, tag := range tags { if !slices.Contains(newTags, tag) { newTags = append(newTags, tag) } } - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil { + b, err := json.Marshal(newTags) + if err != nil { + return err + } + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil { return fmt.Errorf("failed to update tags for node in the database: %w", err) } @@ -569,7 +575,7 @@ func enableRoutes(tx *gorm.DB, for _, prefix := range newRoutes { route := types.Route{} err := tx.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). + Where("node_id = ? AND prefix = ?", node.ID, prefix.String()). First(&route).Error if err == nil { route.Enabled = true diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 8451a906..1edaa06e 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { nodeKey := key.NewNode() machineKey := key.NewMachine() - v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) + v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1)) node := types.Node{ ID: types.NodeID(index), MachineKey: machineKey.Public(), @@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { adminNode, err := db.GetNodeByID(1) c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User) + c.Assert(adminNode.IPv4, check.NotNil) + c.Assert(adminNode.IPv6, check.IsNil) c.Assert(err, check.IsNil) testNode, err := db.GetNodeByID(2) @@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { adminPeers, err := db.ListPeers(adminNode.ID) c.Assert(err, check.IsNil) + c.Assert(len(adminPeers), check.Equals, 9) testPeers, err := db.ListPeers(testNode.ID) c.Assert(err, check.IsNil) + c.Assert(len(testPeers), check.Equals, 9) adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) c.Assert(err, check.IsNil) @@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules) - + c.Log(peersOfAdminNode) c.Log(peersOfTestNode) + c.Assert(len(peersOfTestNode), check.Equals, 9) c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1") c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3") c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5") - c.Log(peersOfAdminNode) c.Assert(len(peersOfAdminNode), check.Equals, 9) c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2") c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4") @@ -346,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) + c.Assert(node.ForcedTags, check.DeepEquals, sTags) // assign duplicate tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} @@ -357,7 +361,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert( node.ForcedTags, check.DeepEquals, - types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), + []string{"tag:bar", "tag:test", "tag:unknown"}, ) // test removing tags @@ -365,7 +369,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{})) + c.Assert(node.ForcedTags, check.DeepEquals, []string{}) } func TestHeadscale_generateGivenName(t *testing.T) { diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 96420211..feacde61 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -77,7 +77,7 @@ func CreatePreAuthKey( Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, - Tags: types.StringList(aclTags), + Tags: aclTags, } if err := tx.Save(&key).Error; err != nil { diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 0012d64e..fa27ea7c 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -49,7 +49,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) { err := tx. Preload("Node"). Preload("Node.User"). - Where("prefix = ?", types.IPPrefix(pref)). + Where("prefix = ?", pref.String()). Find(&routes).Error if err != nil { return nil, err @@ -286,7 +286,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool { var count int64 tx.Model(&types.Route{}). Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?", - route.Prefix, + route.Prefix.String(), route.NodeID, true, true).Count(&count) @@ -297,7 +297,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) { var route types.Route err := tx. Preload("Node"). - Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). + Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true). First(&route).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err @@ -392,7 +392,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { if !exists { route := types.Route{ NodeID: node.ID.Uint64(), - Prefix: types.IPPrefix(prefix), + Prefix: prefix, Advertised: true, Enabled: false, } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index d71df312..0e6535f9 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -290,7 +290,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } var ( - ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } + ipp = func(s string) netip.Prefix { return netip.MustParsePrefix(s) } mkNode = func(nid types.NodeID) types.Node { return types.Node{ID: nid} } @@ -301,7 +301,7 @@ var np = func(nid types.NodeID) *types.Node { return &no } -var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route { +var r = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route { return types.Route{ Model: gorm.Model{ ID: id, @@ -313,7 +313,7 @@ var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary } } -var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route { +var rp = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route { ro := r(id, nid, prefix, enabled, primary) return &ro } @@ -1069,7 +1069,7 @@ func TestFailoverRouteTx(t *testing.T) { } func TestFailoverRoute(t *testing.T) { - r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route { + r := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route { return types.Route{ Model: gorm.Model{ ID: id, @@ -1082,7 +1082,7 @@ func TestFailoverRoute(t *testing.T) { IsPrimary: primary, } } - rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route { + rp := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route { ro := r(id, nid, prefix, enabled, primary) return &ro } @@ -1205,13 +1205,6 @@ func TestFailoverRoute(t *testing.T) { }, } - cmps := append( - util.Comparers, - cmp.Comparer(func(x, y types.IPPrefix) bool { - return netip.Prefix(x) == netip.Prefix(y) - }), - ) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes) @@ -1235,7 +1228,7 @@ func TestFailoverRoute(t *testing.T) { "old": gotf.old, } - if diff := cmp.Diff(want, got, cmps...); diff != "" { + if diff := cmp.Diff(want, got, util.Comparers...); diff != "" { t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff) } } diff --git a/hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite b/hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite new file mode 100644 index 00000000..53d6c327 Binary files /dev/null and b/hscontrol/db/testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite differ diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go new file mode 100644 index 00000000..9c0beef4 --- /dev/null +++ b/hscontrol/db/text_serialiser.go @@ -0,0 +1,99 @@ +package db + +import ( + "context" + "encoding" + "fmt" + "reflect" + + "gorm.io/gorm/schema" +) + +// Got from https://github.com/xdg-go/strum/blob/main/types.go +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func isTextUnmarshaler(rv reflect.Value) bool { + return rv.Type().Implements(textUnmarshalerType) +} + +func maybeInstantiatePtr(rv reflect.Value) { + if rv.Kind() == reflect.Ptr && rv.IsNil() { + np := reflect.New(rv.Type().Elem()) + rv.Set(np) + } +} + +func decodingError(name string, err error) error { + return fmt.Errorf("error decoding to %s: %w", name, err) +} + +// TextSerialiser implements the Serialiser interface for fields that +// have a type that implements encoding.TextUnmarshaler. +type TextSerialiser struct{} + +func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + // If the field is a pointer, we need to dereference it to get the actual type + // so we do not end with a second pointer. + if fieldValue.Elem().Kind() == reflect.Ptr { + fieldValue = fieldValue.Elem() + } + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return fmt.Errorf("failed to unmarshal text value: %#v", dbValue) + } + + if isTextUnmarshaler(fieldValue) { + maybeInstantiatePtr(fieldValue) + f := fieldValue.MethodByName("UnmarshalText") + args := []reflect.Value{reflect.ValueOf(bytes)} + ret := f.Call(args) + if !ret[0].IsNil() { + return decodingError(field.Name, ret[0].Interface().(error)) + } + + // If the underlying field is to a pointer type, we need to + // assign the value as a pointer to it. + // If it is not a pointer, we need to assign the value to the + // field. + dstField := field.ReflectValueOf(ctx, dst) + if dstField.Kind() == reflect.Ptr { + dstField.Set(fieldValue) + } else { + dstField.Set(fieldValue.Elem()) + } + return nil + } else { + return fmt.Errorf("unsupported type: %T", fieldValue.Interface()) + } + } + + return +} + +func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + switch v := fieldValue.(type) { + case encoding.TextMarshaler: + // If the value is nil, we return nil, however, go nil values are not + // always comparable, particularly when reflection is involved: + // https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8 + if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) { + return nil, nil + } + b, err := v.MarshalText() + if err != nil { + return nil, err + } + return string(b), nil + default: + return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v) + } +} diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 89db69dc..24355993 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -196,19 +196,19 @@ func Test_fullMapResponse(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ { - Prefix: types.IPPrefix(tsaddr.AllIPv4()), + Prefix: tsaddr.AllIPv4(), Advertised: true, Enabled: true, IsPrimary: false, }, { - Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")), + Prefix: netip.MustParsePrefix("192.168.0.0/24"), Advertised: true, Enabled: true, IsPrimary: true, }, { - Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")), + Prefix: netip.MustParsePrefix("172.0.0.0/10"), Advertised: true, Enabled: false, IsPrimary: true, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 6e22cdcf..b6692c16 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -109,19 +109,19 @@ func TestTailNode(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{}, Routes: []types.Route{ { - Prefix: types.IPPrefix(tsaddr.AllIPv4()), + Prefix: tsaddr.AllIPv4(), Advertised: true, Enabled: true, IsPrimary: false, }, { - Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")), + Prefix: netip.MustParsePrefix("192.168.0.0/24"), Advertised: true, Enabled: true, IsPrimary: true, }, { - Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")), + Prefix: netip.MustParsePrefix("172.0.0.0/10"), Advertised: true, Enabled: false, IsPrimary: true, diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index f657d26f..7a552456 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -595,6 +595,11 @@ func (pol *ACLPolicy) ExpandAlias( // excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones // that are correctly tagged since they should not be listed as being in the user // we assume in this function that we only have nodes from 1 user. +// +// TODO(kradalby): It is quite hard to understand what this function is doing, +// it seems like it trying to ensure that we dont include nodes that are tagged +// when we look up the nodes owned by a user. +// This should be refactored to be more clear as part of the Tags work in #1369 func excludeCorrectlyTaggedNodes( aclPolicy *ACLPolicy, nodes types.Nodes, @@ -613,17 +618,16 @@ func excludeCorrectlyTaggedNodes( for _, node := range nodes { found := false - if node.Hostinfo == nil { - continue - } + if node.Hostinfo != nil { + for _, t := range node.Hostinfo.RequestTags { + if slices.Contains(tags, t) { + found = true - for _, t := range node.Hostinfo.RequestTags { - if slices.Contains(tags, t) { - found = true - - break + break + } } } + if len(node.ForcedTags) > 0 { found = true } @@ -981,7 +985,10 @@ func FilterNodesByACL( continue } + log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname) + if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { + log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname) result = append(result, peer) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 20981224..cfcba77a 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2385,7 +2385,7 @@ func TestReduceFilterRules(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, - ForcedTags: types.StringList{"tag:access-servers"}, + ForcedTags: []string{"tag:access-servers"}, }, peers: types.Nodes{ &types.Node{ @@ -3182,7 +3182,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { Routes: types.Routes{ types.Route{ NodeID: 2, - Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")), + Prefix: netip.MustParsePrefix("10.33.0.0/16"), IsPrimary: true, Enabled: true, }, @@ -3215,7 +3215,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { Routes: types.Routes{ types.Route{ NodeID: 2, - Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")), + Prefix: netip.MustParsePrefix("10.33.0.0/16"), IsPrimary: true, Enabled: true, }, @@ -3225,13 +3225,6 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, } - // TODO(kradalby): Remove when we have gotten rid of IPPrefix type - prefixComparer := cmp.Comparer(func(x, y types.IPPrefix) bool { - return x == y - }) - comparers := append([]cmp.Option{}, util.Comparers...) - comparers = append(comparers, prefixComparer) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := FilterNodesByACL( @@ -3239,7 +3232,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { tt.args.nodes, tt.args.rules, ) - if diff := cmp.Diff(tt.want, got, comparers...); diff != "" { + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) } }) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 033639ae..755265f3 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand/v2" "net/http" + "net/netip" "slices" "strings" "time" @@ -448,13 +449,13 @@ func (m *mapSession) handleEndpointUpdate() { sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo) // The node might not set NetInfo if it has not changed and if - // the full HostInfo object is overrwritten, the information is lost. + // the full HostInfo object is overwritten, the information is lost. // If there is no NetInfo, keep the previous one. // From 1.66 the client only sends it if changed: // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 // TODO(kradalby): evaulate if we need better comparing of hostinfo // before we take the changes. - if m.req.Hostinfo.NetInfo == nil { + if m.req.Hostinfo.NetInfo == nil && m.node.Hostinfo != nil { m.req.Hostinfo.NetInfo = m.node.Hostinfo.NetInfo } m.node.Hostinfo = m.req.Hostinfo @@ -661,8 +662,15 @@ func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { return false, false } + if old == nil && new != nil { + return true, true + } + // Routes - oldRoutes := old.RoutableIPs + oldRoutes := make([]netip.Prefix, 0) + if old != nil { + oldRoutes = old.RoutableIPs + } newRoutes := new.RoutableIPs tsaddr.SortPrefixes(oldRoutes) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 35f5e5e4..32ad8a67 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -2,11 +2,7 @@ package types import ( "context" - "database/sql/driver" - "encoding/json" "errors" - "fmt" - "net/netip" "time" "tailscale.com/tailcfg" @@ -21,74 +17,6 @@ const ( var ErrCannotParsePrefix = errors.New("cannot parse prefix") -type IPPrefix netip.Prefix - -func (i *IPPrefix) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - prefix, err := netip.ParsePrefix(value) - if err != nil { - return err - } - *i = IPPrefix(prefix) - - return nil - default: - return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefix) Value() (driver.Value, error) { - prefixStr := netip.Prefix(i).String() - - return prefixStr, nil -} - -type IPPrefixes []netip.Prefix - -func (i *IPPrefixes) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefixes) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err -} - -type StringList []string - -func (i *StringList) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i StringList) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err -} - type StateUpdateType int func (su StateUpdateType) String() string { diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 04ca9f8d..0eb937a1 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -1,8 +1,6 @@ package types import ( - "database/sql" - "encoding/json" "errors" "fmt" "net/netip" @@ -15,7 +13,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -51,54 +48,16 @@ func (id NodeID) String() string { type Node struct { ID NodeID `gorm:"primary_key"` - // MachineKeyDatabaseField is the string representation of MachineKey - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use MachineKey instead. - MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"` - MachineKey key.MachinePublic `gorm:"-"` + MachineKey key.MachinePublic `gorm:"serializer:text"` + NodeKey key.NodePublic `gorm:"serializer:text"` + DiscoKey key.DiscoPublic `gorm:"serializer:text"` - // NodeKeyDatabaseField is the string representation of NodeKey - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use NodeKey instead. - NodeKeyDatabaseField string `gorm:"column:node_key"` - NodeKey key.NodePublic `gorm:"-"` + Endpoints []netip.AddrPort `gorm:"serializer:json"` - // DiscoKeyDatabaseField is the string representation of DiscoKey - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use DiscoKey instead. - DiscoKeyDatabaseField string `gorm:"column:disco_key"` - DiscoKey key.DiscoPublic `gorm:"-"` + Hostinfo *tailcfg.Hostinfo `gorm:"serializer:json"` - // EndpointsDatabaseField is the string list representation of Endpoints - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use Endpoints instead. - EndpointsDatabaseField StringList `gorm:"column:endpoints"` - Endpoints []netip.AddrPort `gorm:"-"` - - // EndpointsDatabaseField is the string list representation of Endpoints - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use Endpoints instead. - HostinfoDatabaseField string `gorm:"column:host_info"` - Hostinfo *tailcfg.Hostinfo `gorm:"-"` - - // IPv4DatabaseField is the string representation of v4 address, - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use V4 instead. - IPv4DatabaseField sql.NullString `gorm:"column:ipv4"` - IPv4 *netip.Addr `gorm:"-"` - - // IPv6DatabaseField is the string representation of v4 address, - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use V6 instead. - IPv6DatabaseField sql.NullString `gorm:"column:ipv6"` - IPv6 *netip.Addr `gorm:"-"` + IPv4 *netip.Addr `gorm:"serializer:text"` + IPv6 *netip.Addr `gorm:"serializer:text"` // Hostname represents the name given by the Tailscale // client during registration @@ -116,7 +75,7 @@ type Node struct { RegisterMethod string - ForcedTags StringList + ForcedTags []string `gorm:"serializer:json"` // TODO(kradalby): This seems like irrelevant information? AuthKeyID *uint64 `sql:"DEFAULT:NULL"` @@ -216,16 +175,20 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { src := node.IPs() allowedIPs := node2.IPs() + // TODO(kradalby): Regenerate this everytime the filter change, instead of + // every time we use it. + matchers := make([]matcher.Match, len(filter)) + for i, rule := range filter { + matchers[i] = matcher.MatchFromFilterRule(rule) + } + for _, route := range node2.Routes { if route.Enabled { allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix).Addr()) } } - for _, rule := range filter { - // TODO(kradalby): Cache or pregen this - matcher := matcher.MatchFromFilterRule(rule) - + for _, matcher := range matchers { if !matcher.SrcsContainsIPs(src) { continue } @@ -255,109 +218,6 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { return found } -// BeforeSave is a hook that ensures that some values that -// cannot be directly marshalled into database values are stored -// correctly in the database. -// This currently means storing the keys as strings. -func (node *Node) BeforeSave(tx *gorm.DB) error { - node.MachineKeyDatabaseField = node.MachineKey.String() - node.NodeKeyDatabaseField = node.NodeKey.String() - node.DiscoKeyDatabaseField = node.DiscoKey.String() - - var endpoints StringList - for _, addrPort := range node.Endpoints { - endpoints = append(endpoints, addrPort.String()) - } - - node.EndpointsDatabaseField = endpoints - - hi, err := json.Marshal(node.Hostinfo) - if err != nil { - return fmt.Errorf("marshalling Hostinfo to store in db: %w", err) - } - node.HostinfoDatabaseField = string(hi) - - if node.IPv4 != nil { - node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = node.IPv4.String(), true - } else { - node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = "", false - } - - if node.IPv6 != nil { - node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = node.IPv6.String(), true - } else { - node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = "", false - } - - return nil -} - -// AfterFind is a hook that ensures that Node objects fields that -// has a different type in the database is unwrapped and populated -// correctly. -// This currently unmarshals all the keys, stored as strings, into -// the proper types. -func (node *Node) AfterFind(tx *gorm.DB) error { - var machineKey key.MachinePublic - if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil { - return fmt.Errorf("unmarshalling machine key from db: %w", err) - } - node.MachineKey = machineKey - - var nodeKey key.NodePublic - if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil { - return fmt.Errorf("unmarshalling node key from db: %w", err) - } - node.NodeKey = nodeKey - - // DiscoKey might be empty if a node has not sent it to headscale. - // This means that this might fail if the disco key is empty. - if node.DiscoKeyDatabaseField != "" { - var discoKey key.DiscoPublic - if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil { - return fmt.Errorf("unmarshalling disco key from db: %w", err) - } - node.DiscoKey = discoKey - } - - endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField)) - for idx, ep := range node.EndpointsDatabaseField { - addrPort, err := netip.ParseAddrPort(ep) - if err != nil { - return fmt.Errorf("parsing endpoint from db: %w", err) - } - - endpoints[idx] = addrPort - } - node.Endpoints = endpoints - - var hi tailcfg.Hostinfo - if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil { - return fmt.Errorf("unmarshalling hostinfo from database: %w", err) - } - node.Hostinfo = &hi - - if node.IPv4DatabaseField.Valid { - ip, err := netip.ParseAddr(node.IPv4DatabaseField.String) - if err != nil { - return fmt.Errorf("parsing IPv4 from database: %w", err) - } - - node.IPv4 = &ip - } - - if node.IPv6DatabaseField.Valid { - ip, err := netip.ParseAddr(node.IPv6DatabaseField.String) - if err != nil { - return fmt.Errorf("parsing IPv6 from database: %w", err) - } - - node.IPv6 = &ip - } - - return nil -} - func (node *Node) Proto() *v1.Node { nodeProto := &v1.Node{ Id: uint64(node.ID), diff --git a/hscontrol/types/routes.go b/hscontrol/types/routes.go index 04118fa6..1f6b8a77 100644 --- a/hscontrol/types/routes.go +++ b/hscontrol/types/routes.go @@ -17,7 +17,7 @@ type Route struct { Node Node // TODO(kradalby): change this custom type to netip.Prefix - Prefix IPPrefix + Prefix netip.Prefix `gorm:"serializer:text"` Advertised bool Enabled bool @@ -31,7 +31,7 @@ func (r *Route) String() string { } func (r *Route) IsExitRoute() bool { - return tsaddr.IsExitRoute(netip.Prefix(r.Prefix)) + return tsaddr.IsExitRoute(r.Prefix) } func (r *Route) IsAnnouncable() bool { @@ -59,8 +59,8 @@ func (rs Routes) Primaries() Routes { return res } -func (rs Routes) PrefixMap() map[IPPrefix][]Route { - res := map[IPPrefix][]Route{} +func (rs Routes) PrefixMap() map[netip.Prefix][]Route { + res := map[netip.Prefix][]Route{} for _, route := range rs { if _, ok := res[route.Prefix]; ok { @@ -80,7 +80,7 @@ func (rs Routes) Proto() []*v1.Route { protoRoute := v1.Route{ Id: uint64(route.ID), Node: route.Node.Proto(), - Prefix: netip.Prefix(route.Prefix).String(), + Prefix: route.Prefix.String(), Advertised: route.Advertised, Enabled: route.Enabled, IsPrimary: route.IsPrimary, diff --git a/hscontrol/types/routes_test.go b/hscontrol/types/routes_test.go index ead4c595..b3600482 100644 --- a/hscontrol/types/routes_test.go +++ b/hscontrol/types/routes_test.go @@ -10,16 +10,11 @@ import ( ) func TestPrefixMap(t *testing.T) { - ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) } - - // TODO(kradalby): Remove when we have gotten rid of IPPrefix type - prefixComparer := cmp.Comparer(func(x, y IPPrefix) bool { - return x == y - }) + ipp := func(s string) netip.Prefix { return netip.MustParsePrefix(s) } tests := []struct { rs Routes - want map[IPPrefix][]Route + want map[netip.Prefix][]Route }{ { rs: Routes{ @@ -27,7 +22,7 @@ func TestPrefixMap(t *testing.T) { Prefix: ipp("10.0.0.0/24"), }, }, - want: map[IPPrefix][]Route{ + want: map[netip.Prefix][]Route{ ipp("10.0.0.0/24"): Routes{ Route{ Prefix: ipp("10.0.0.0/24"), @@ -44,7 +39,7 @@ func TestPrefixMap(t *testing.T) { Prefix: ipp("10.0.1.0/24"), }, }, - want: map[IPPrefix][]Route{ + want: map[netip.Prefix][]Route{ ipp("10.0.0.0/24"): Routes{ Route{ Prefix: ipp("10.0.0.0/24"), @@ -68,7 +63,7 @@ func TestPrefixMap(t *testing.T) { Enabled: false, }, }, - want: map[IPPrefix][]Route{ + want: map[netip.Prefix][]Route{ ipp("10.0.0.0/24"): Routes{ Route{ Prefix: ipp("10.0.0.0/24"), @@ -86,7 +81,7 @@ func TestPrefixMap(t *testing.T) { for idx, tt := range tests { t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) { got := tt.rs.PrefixMap() - if diff := cmp.Diff(tt.want, got, prefixComparer, util.MkeyComparer, util.NkeyComparer, util.DkeyComparer); diff != "" { + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff) } }) diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 244470f2..509052a3 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -16,6 +16,8 @@ func DefaultConfigEnv() map[string]string { "HEADSCALE_POLICY_PATH": "", "HEADSCALE_DATABASE_TYPE": "sqlite", "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", + "HEADSCALE_DATABASE_DEBUG": "1", + "HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD": "1", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_PREFIXES_V4": "100.64.0.0/10", "HEADSCALE_PREFIXES_V6": "fd7a:115c:a1e0::/48",