From a52f1df1806538368bd671b198fe1e975806ade5 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 May 2025 13:57:26 +0200 Subject: [PATCH] policy: remove v1 code (#2600) * policy: remove v1 code Signed-off-by: Kristoffer Dalby * db: update test with v1 removal Signed-off-by: Kristoffer Dalby * integration: start moving to v2 policy Signed-off-by: Kristoffer Dalby * policy: add ssh unmarshal tests Signed-off-by: Kristoffer Dalby * changelog: add entry Signed-off-by: Kristoffer Dalby * policy: remove v1 comment Signed-off-by: Kristoffer Dalby * integration: remove comment out case Signed-off-by: Kristoffer Dalby * cleanup skipv1 Signed-off-by: Kristoffer Dalby * policy: remove v1 prefix workaround Signed-off-by: Kristoffer Dalby * policy: add all node ips if prefix/host is ts ip Signed-off-by: Kristoffer Dalby --------- Signed-off-by: Kristoffer Dalby --- CHANGELOG.md | 5 + hscontrol/db/node_test.go | 3 +- hscontrol/mapper/mapper_test.go | 2 +- hscontrol/policy/pm.go | 25 +- hscontrol/policy/policy_test.go | 41 +- hscontrol/policy/route_approval_test.go | 17 +- hscontrol/policy/v1/acls.go | 996 -------- hscontrol/policy/v1/acls_test.go | 2797 ----------------------- hscontrol/policy/v1/acls_types.go | 123 - hscontrol/policy/v1/policy.go | 188 -- hscontrol/policy/v1/policy_test.go | 180 -- hscontrol/policy/v2/types.go | 433 +++- hscontrol/policy/v2/types_test.go | 258 ++- integration/acl_test.go | 466 ++-- integration/cli_test.go | 98 +- integration/control.go | 4 +- integration/hsic/hsic.go | 17 +- integration/route_test.go | 233 +- integration/scenario.go | 5 - integration/ssh_test.go | 132 +- integration/utils.go | 72 +- 21 files changed, 1258 insertions(+), 4837 deletions(-) delete mode 100644 hscontrol/policy/v1/acls.go delete mode 100644 hscontrol/policy/v1/acls_test.go delete mode 100644 hscontrol/policy/v1/acls_types.go delete mode 100644 hscontrol/policy/v1/policy.go delete mode 100644 hscontrol/policy/v1/policy_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 91a23a05..43c9f2a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ - Policy: Zero or empty destination port is no longer allowed [#2606](https://github.com/juanfont/headscale/pull/2606) +### Changes + +- Remove policy v1 code + [#2600](https://github.com/juanfont/headscale/pull/2600) + ## 0.26.0 (2025-05-14) ### BREAKING diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index fd9313e1..56c967f1 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -435,8 +435,7 @@ func TestAutoApproveRoutes(t *testing.T) { for _, tt := range tests { pmfs := policy.PolicyManagerFuncsForTest([]byte(tt.acl)) for i, pmf := range pmfs { - version := i + 1 - t.Run(fmt.Sprintf("%s-policyv%d", tt.name, version), func(t *testing.T) { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { adb, err := newSQLiteTestDB() require.NoError(t, err) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index dfce60bb..8d2c60bb 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -263,7 +263,7 @@ func Test_fullMapResponse(t *testing.T) { // { // name: "empty-node", // node: types.Node{}, - // pol: &policyv1.ACLPolicy{}, + // pol: &policyv2.Policy{}, // dnsConfig: &tailcfg.DNSConfig{}, // baseDomain: "", // want: nil, diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index b90d2efc..c4758929 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -5,17 +5,11 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" - "tailscale.com/envknob" "tailscale.com/tailcfg" ) -var ( - polv1 = envknob.Bool("HEADSCALE_POLICY_V1") -) - type PolicyManager interface { // Filter returns the current filter rules for the entire tailnet and the associated matchers. Filter() ([]tailcfg.FilterRule, []matcher.Match) @@ -33,21 +27,13 @@ type PolicyManager interface { DebugString() string } -// NewPolicyManager returns a new policy manager, the version is determined by -// the environment flag "HEADSCALE_POLICY_V1". +// NewPolicyManager returns a new policy manager. func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { var polMan PolicyManager var err error - if polv1 { - polMan, err = policyv1.NewPolicyManager(pol, users, nodes) - if err != nil { - return nil, err - } - } else { - polMan, err = policyv2.NewPolicyManager(pol, users, nodes) - if err != nil { - return nil, err - } + polMan, err = policyv2.NewPolicyManager(pol, users, nodes) + if err != nil { + return nil, err } return polMan, err @@ -73,9 +59,6 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([ func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) { var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error) - polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) { - return policyv1.NewPolicyManager(pol, u, n) - }) polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) { return policyv2.NewPolicyManager(pol, u, n) }) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 00c00f78..83d69eb8 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -490,18 +490,6 @@ func TestReduceFilterRules(t *testing.T) { {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, - // This should not be included I believe, seems like - // this is a bug in the v1 code. - // For example: - // If a src or dst includes "64.0.0.0/2:*", it will include 100.64/16 range, which - // means that it will need to fetch the IPv6 addrs of the node to include the full range. - // Clearly, if a user sets the dst to be "64.0.0.0/2:*", it is likely more of a exit node - // and this would be strange behaviour. - // TODO(kradalby): Remove before launch. - {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, - // End {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, @@ -824,8 +812,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.pol)) { - version := idx + 1 - t.Run(fmt.Sprintf("%s-v%d", tt.name, version), func(t *testing.T) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { var pm PolicyManager var err error pm, err = pmf(users, append(tt.peers, tt.node)) @@ -1644,10 +1631,6 @@ func TestSSHPolicyRules(t *testing.T) { wantSSH *tailcfg.SSHPolicy expectErr bool errorMessage string - - // There are some tests that will not pass on V1 since we do not - // have the same kind of error handling as V2, so we skip them. - skipV1 bool }{ { name: "group-to-user", @@ -1681,10 +1664,6 @@ func TestSSHPolicyRules(t *testing.T) { }, }, }}, - - // It looks like the group implementation in v1 is broken, so - // we skip this test for v1 and not let it hold up v2 replacing it. - skipV1: true, }, { name: "group-to-tag", @@ -1722,10 +1701,6 @@ func TestSSHPolicyRules(t *testing.T) { }, }, }}, - - // It looks like the group implementation in v1 is broken, so - // we skip this test for v1 and not let it hold up v2 replacing it. - skipV1: true, }, { name: "tag-to-user", @@ -1826,10 +1801,6 @@ func TestSSHPolicyRules(t *testing.T) { }, }, }}, - - // It looks like the group implementation in v1 is broken, so - // we skip this test for v1 and not let it hold up v2 replacing it. - skipV1: true, }, { name: "check-period-specified", @@ -1901,7 +1872,6 @@ func TestSSHPolicyRules(t *testing.T) { }`, expectErr: true, errorMessage: `SSH action "invalid" is not valid, must be accept or check`, - skipV1: true, }, { name: "invalid-check-period", @@ -1920,7 +1890,6 @@ func TestSSHPolicyRules(t *testing.T) { }`, expectErr: true, errorMessage: "not a valid duration string", - skipV1: true, }, { name: "multiple-ssh-users-with-autogroup", @@ -1972,18 +1941,12 @@ func TestSSHPolicyRules(t *testing.T) { }`, expectErr: true, errorMessage: "autogroup \"autogroup:invalid\" is not supported", - skipV1: true, }, } for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { - version := idx + 1 - t.Run(fmt.Sprintf("%s-v%d", tt.name, version), func(t *testing.T) { - if version == 1 && tt.skipV1 { - t.Skip() - } - + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { var pm PolicyManager var err error pm, err = pmf(users, append(tt.peers, &tt.targetNode)) diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 90d5f98e..19d61d82 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -60,7 +60,6 @@ func TestNodeCanApproveRoute(t *testing.T) { route netip.Prefix policy string canApprove bool - skipV1 bool }{ { name: "allow-all-routes-for-admin-user", @@ -766,10 +765,10 @@ func TestNodeCanApproveRoute(t *testing.T) { canApprove: false, }, { - name: "empty-policy", - node: normalNode, - route: p("192.168.1.0/24"), - policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, + name: "empty-policy", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, canApprove: false, }, } @@ -789,13 +788,7 @@ func TestNodeCanApproveRoute(t *testing.T) { } for i, pm := range policyManagers { - versionNum := i + 1 - if versionNum == 1 && tt.skipV1 { - // Skip V1 policy manager for specific tests - continue - } - - t.Run(fmt.Sprintf("PolicyV%d", versionNum), func(t *testing.T) { + t.Run(fmt.Sprintf("policy-index%d", i), func(t *testing.T) { result := pm.NodeCanApproveRoute(&tt.node, tt.route) if diff := cmp.Diff(tt.canApprove, result); diff != "" { diff --git a/hscontrol/policy/v1/acls.go b/hscontrol/policy/v1/acls.go deleted file mode 100644 index 9ab1b244..00000000 --- a/hscontrol/policy/v1/acls.go +++ /dev/null @@ -1,996 +0,0 @@ -package v1 - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "net/netip" - "os" - "slices" - "strconv" - "strings" - "time" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" - "github.com/tailscale/hujson" - "go4.org/netipx" - "tailscale.com/tailcfg" -) - -var ( - ErrEmptyPolicy = errors.New("empty policy") - ErrInvalidAction = errors.New("invalid action") - ErrInvalidGroup = errors.New("invalid group") - ErrInvalidTag = errors.New("invalid tag") - ErrInvalidPortFormat = errors.New("invalid port format") - ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") -) - -const ( - portRangeBegin = 0 - portRangeEnd = 65535 - expectedTokenItems = 2 -) - -// For some reason golang.org/x/net/internal/iana is an internal package. -const ( - protocolICMP = 1 // Internet Control Message - protocolIGMP = 2 // Internet Group Management - protocolIPv4 = 4 // IPv4 encapsulation - protocolTCP = 6 // Transmission Control - protocolEGP = 8 // Exterior Gateway Protocol - protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) - protocolUDP = 17 // User Datagram - protocolGRE = 47 // Generic Routing Encapsulation - protocolESP = 50 // Encap Security Payload - protocolAH = 51 // Authentication Header - protocolIPv6ICMP = 58 // ICMP for IPv6 - protocolSCTP = 132 // Stream Control Transmission Protocol - ProtocolFC = 133 // Fibre Channel -) - -// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. -func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { - log.Debug(). - Str("func", "LoadACLPolicy"). - Str("path", path). - Msg("Loading ACL policy from path") - - policyFile, err := os.Open(path) - if err != nil { - return nil, err - } - defer policyFile.Close() - - policyBytes, err := io.ReadAll(policyFile) - if err != nil { - return nil, err - } - - log.Debug(). - Str("path", path). - Bytes("file", policyBytes). - Msg("Loading ACLs") - - return LoadACLPolicyFromBytes(policyBytes) -} - -func LoadACLPolicyFromBytes(acl []byte) (*ACLPolicy, error) { - var policy ACLPolicy - - ast, err := hujson.Parse(acl) - if err != nil { - return nil, fmt.Errorf("parsing hujson, err: %w", err) - } - - ast.Standardize() - acl = ast.Pack() - - if err := json.Unmarshal(acl, &policy); err != nil { - return nil, fmt.Errorf("unmarshalling policy, err: %w", err) - } - - if policy.IsZero() { - return nil, ErrEmptyPolicy - } - - return &policy, nil -} - -func GenerateFilterAndSSHRulesForTests( - policy *ACLPolicy, - node *types.Node, - peers types.Nodes, - users []types.User, -) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { - // If there is no policy defined, we default to allow all - if policy == nil { - return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil - } - - rules, err := policy.CompileFilterRules(users, append(peers, node)) - if err != nil { - return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err - } - - log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - - sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) - if err != nil { - return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err - } - - return rules, sshPolicy, nil -} - -// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a -// set of Tailscale compatible FilterRules used to allow traffic on clients. -func (pol *ACLPolicy) CompileFilterRules( - users []types.User, - nodes types.Nodes, -) ([]tailcfg.FilterRule, error) { - if pol == nil { - return tailcfg.FilterAllowAll, nil - } - - var rules []tailcfg.FilterRule - - for index, acl := range pol.ACLs { - if acl.Action != "accept" { - return nil, ErrInvalidAction - } - - var srcIPs []string - for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, users, nodes) - if err != nil { - return nil, fmt.Errorf( - "parsing policy, acl index: %d->%d: %w", - index, - srcIndex, - err, - ) - } - srcIPs = append(srcIPs, srcs...) - } - - protocols, isWildcard, err := parseProtocol(acl.Protocol) - if err != nil { - return nil, fmt.Errorf("parsing policy, protocol err: %w ", err) - } - - destPorts := []tailcfg.NetPortRange{} - for _, dest := range acl.Destinations { - alias, port, err := parseDestination(dest) - if err != nil { - return nil, err - } - - expanded, err := pol.ExpandAlias( - nodes, - users, - alias, - ) - if err != nil { - return nil, err - } - - ports, err := expandPorts(port, isWildcard) - if err != nil { - return nil, err - } - - var dests []tailcfg.NetPortRange - for _, dest := range expanded.Prefixes() { - for _, port := range *ports { - pr := tailcfg.NetPortRange{ - IP: dest.String(), - Ports: port, - } - dests = append(dests, pr) - } - } - destPorts = append(destPorts, dests...) - } - - rules = append(rules, tailcfg.FilterRule{ - SrcIPs: srcIPs, - DstPorts: destPorts, - IPProto: protocols, - }) - } - - return rules, nil -} - -func (pol *ACLPolicy) CompileSSHPolicy( - node *types.Node, - users []types.User, - peers types.Nodes, -) (*tailcfg.SSHPolicy, error) { - if pol == nil { - return nil, nil - } - - var rules []*tailcfg.SSHRule - - acceptAction := tailcfg.SSHAction{ - Message: "", - Reject: false, - Accept: true, - SessionDuration: 0, - AllowAgentForwarding: true, - HoldAndDelegate: "", - AllowLocalPortForwarding: true, - } - - rejectAction := tailcfg.SSHAction{ - Message: "", - Reject: true, - Accept: false, - SessionDuration: 0, - AllowAgentForwarding: false, - HoldAndDelegate: "", - AllowLocalPortForwarding: false, - } - - for index, sshACL := range pol.SSHs { - var dest netipx.IPSetBuilder - for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), users, src) - if err != nil { - return nil, err - } - dest.AddSet(expanded) - } - - destSet, err := dest.IPSet() - if err != nil { - return nil, err - } - - if !node.InIPSet(destSet) { - continue - } - - action := rejectAction - switch sshACL.Action { - case "accept": - action = acceptAction - case "check": - checkAction, err := sshCheckAction(sshACL.CheckPeriod) - if err != nil { - return nil, fmt.Errorf( - "parsing SSH policy, parsing check duration, index: %d: %w", - index, - err, - ) - } else { - action = *checkAction - } - default: - return nil, fmt.Errorf( - "parsing SSH policy, unknown action %q, index: %d: %w", - sshACL.Action, - index, - err, - ) - } - - var principals []*tailcfg.SSHPrincipal - for innerIndex, srcToken := range sshACL.Sources { - if isWildcard(srcToken) { - principals = []*tailcfg.SSHPrincipal{{ - Any: true, - }} - break - } - - // If the token is a group, expand the users and validate - // them. Then use the .Username() to get the login name - // that corresponds with the User info in the netmap. - if isGroup(srcToken) { - usersFromGroup, err := pol.expandUsersFromGroup(srcToken) - if err != nil { - return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err) - } - - for _, userStr := range usersFromGroup { - user, err := findUserFromToken(users, userStr) - if err != nil { - log.Trace().Err(err).Msg("user not found") - continue - } - - principals = append(principals, &tailcfg.SSHPrincipal{ - UserLogin: user.Username(), - }) - } - - continue - } - - // Try to check if the token is a user, if it is, then we - // can use the .Username() to get the login name that - // corresponds with the User info in the netmap. - // TODO(kradalby): This is a bit of a hack, and it should go - // away with the new policy where users can be reliably determined. - if user, err := findUserFromToken(users, srcToken); err == nil { - principals = append(principals, &tailcfg.SSHPrincipal{ - UserLogin: user.Username(), - }) - continue - } - - // This is kind of then non-ideal scenario where we dont really know - // what to do with the token, so we expand it to IP addresses of nodes. - // The pro here is that we have a pretty good lockdown on the mapping - // between users and node, but it can explode if a user owns many nodes. - ips, err := pol.ExpandAlias( - peers, - users, - srcToken, - ) - if err != nil { - return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err) - } - for addr := range util.IPSetAddrIter(ips) { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: addr.String(), - }) - } - } - - userMap := make(map[string]string, len(sshACL.Users)) - for _, user := range sshACL.Users { - userMap[user] = "=" - } - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) - } - - return &tailcfg.SSHPolicy{ - Rules: rules, - }, nil -} - -func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { - sessionLength, err := time.ParseDuration(duration) - if err != nil { - return nil, err - } - - return &tailcfg.SSHAction{ - Message: "", - Reject: false, - Accept: true, - SessionDuration: sessionLength, - AllowAgentForwarding: true, - HoldAndDelegate: "", - AllowLocalPortForwarding: true, - }, nil -} - -func parseDestination(dest string) (string, string, error) { - var tokens []string - - // Check if there is a IPv4/6:Port combination, IPv6 has more than - // three ":". - tokens = strings.Split(dest, ":") - if len(tokens) < expectedTokenItems || len(tokens) > 3 { - port := tokens[len(tokens)-1] - - maybeIPv6Str := strings.TrimSuffix(dest, ":"+port) - log.Trace().Str("maybeIPv6Str", maybeIPv6Str).Msg("") - - filteredMaybeIPv6Str := maybeIPv6Str - if strings.Contains(maybeIPv6Str, "/") { - networkParts := strings.Split(maybeIPv6Str, "/") - filteredMaybeIPv6Str = networkParts[0] - } - - if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() { - log.Trace().Err(err).Msg("trying to parse as IPv6") - - return "", "", fmt.Errorf( - "failed to parse destination, tokens %v: %w", - tokens, - ErrInvalidPortFormat, - ) - } else { - tokens = []string{maybeIPv6Str, port} - } - } - - var alias string - // We can have here stuff like: - // git-server:* - // 192.168.1.0/24:22 - // fd7a:115c:a1e0::2:22 - // fd7a:115c:a1e0::2/128:22 - // tag:montreal-webserver:80,443 - // tag:api-server:443 - // example-host-1:* - if len(tokens) == expectedTokenItems { - alias = tokens[0] - } else { - alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) - } - - return alias, tokens[len(tokens)-1], nil -} - -// parseProtocol reads the proto field of the ACL and generates a list of -// protocols that will be allowed, following the IANA IP protocol number -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// -// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, -// as per Tailscale behaviour (see tailcfg.FilterRule). -// -// Also returns a boolean indicating if the protocol -// requires all the destinations to use wildcard as port number (only TCP, -// UDP and SCTP support specifying ports). -func parseProtocol(protocol string) ([]int, bool, error) { - switch protocol { - case "": - return nil, false, nil - case "igmp": - return []int{protocolIGMP}, true, nil - case "ipv4", "ip-in-ip": - return []int{protocolIPv4}, true, nil - case "tcp": - return []int{protocolTCP}, false, nil - case "egp": - return []int{protocolEGP}, true, nil - case "igp": - return []int{protocolIGP}, true, nil - case "udp": - return []int{protocolUDP}, false, nil - case "gre": - return []int{protocolGRE}, true, nil - case "esp": - return []int{protocolESP}, true, nil - case "ah": - return []int{protocolAH}, true, nil - case "sctp": - return []int{protocolSCTP}, false, nil - case "icmp": - return []int{protocolICMP, protocolIPv6ICMP}, true, nil - - default: - protocolNumber, err := strconv.Atoi(protocol) - if err != nil { - return nil, false, fmt.Errorf("parsing protocol number: %w", err) - } - needsWildcard := protocolNumber != protocolTCP && - protocolNumber != protocolUDP && - protocolNumber != protocolSCTP - - return []int{protocolNumber}, needsWildcard, nil - } -} - -// expandSource returns a set of Source IPs that would be associated -// with the given src alias. -func (pol *ACLPolicy) expandSource( - src string, - users []types.User, - nodes types.Nodes, -) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, users, src) - if err != nil { - return []string{}, err - } - - var prefixes []string - for _, prefix := range ipSet.Prefixes() { - prefixes = append(prefixes, prefix.String()) - } - - return prefixes, nil -} - -// expandalias has an input of either -// - a user -// - a group -// - a tag -// - a host -// - an ip -// - a cidr -// - an autogroup -// and transform these in IPAddresses. -func (pol *ACLPolicy) ExpandAlias( - nodes types.Nodes, - users []types.User, - alias string, -) (*netipx.IPSet, error) { - if isWildcard(alias) { - return util.ParseIPSet("*", nil) - } - - build := netipx.IPSetBuilder{} - - log.Debug(). - Str("alias", alias). - Msg("Expanding") - - // if alias is a group - if isGroup(alias) { - return pol.expandIPsFromGroup(alias, users, nodes) - } - - // if alias is a tag - if isTag(alias) { - return pol.expandIPsFromTag(alias, users, nodes) - } - - if isAutoGroup(alias) { - return expandAutoGroup(alias) - } - - // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { - return ips, err - } - - // if alias is an host - // Note, this is recursive. - if h, ok := pol.Hosts[alias]; ok { - log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - - return pol.ExpandAlias(nodes, users, h.String()) - } - - // if alias is an IP - if ip, err := netip.ParseAddr(alias); err == nil { - return pol.expandIPsFromSingleIP(ip, nodes) - } - - // if alias is an IP Prefix (CIDR) - if prefix, err := netip.ParsePrefix(alias); err == nil { - return pol.expandIPsFromIPPrefix(prefix, nodes) - } - - log.Warn().Msgf("No IPs found with the alias %v", alias) - - return build.IPSet() -} - -// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones -// that are correctly tagged since they should not be listed as being in the user -// we assume in this function that we only have nodes from 1 user. -// -// TODO(kradalby): It is quite hard to understand what this function is doing, -// it seems like it trying to ensure that we dont include nodes that are tagged -// when we look up the nodes owned by a user. -// This should be refactored to be more clear as part of the Tags work in #1369. -func excludeCorrectlyTaggedNodes( - aclPolicy *ACLPolicy, - nodes types.Nodes, - user string, -) types.Nodes { - var out types.Nodes - var tags []string - for tag := range aclPolicy.TagOwners { - owners, _ := expandOwnersFromTag(aclPolicy, user) - ns := append(owners, user) - if slices.Contains(ns, user) { - tags = append(tags, tag) - } - } - // for each node if tag is in tags list, don't append it. - for _, node := range nodes { - found := false - - if node.Hostinfo != nil { - for _, t := range node.Hostinfo.RequestTags { - if slices.Contains(tags, t) { - found = true - - break - } - } - } - - if len(node.ForcedTags) > 0 { - found = true - } - if !found { - out = append(out, node) - } - } - - return out -} - -func expandPorts(portsStr string, isWild bool) (*[]tailcfg.PortRange, error) { - if isWildcard(portsStr) { - return &[]tailcfg.PortRange{ - {First: portRangeBegin, Last: portRangeEnd}, - }, nil - } - - if isWild { - return nil, ErrWildcardIsNeeded - } - - var ports []tailcfg.PortRange - for _, portStr := range strings.Split(portsStr, ",") { - log.Trace().Msgf("parsing portstring: %s", portStr) - rang := strings.Split(portStr, "-") - switch len(rang) { - case 1: - port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) - if err != nil { - return nil, err - } - ports = append(ports, tailcfg.PortRange{ - First: uint16(port), - Last: uint16(port), - }) - - case expectedTokenItems: - start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) - if err != nil { - return nil, err - } - last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16) - if err != nil { - return nil, err - } - ports = append(ports, tailcfg.PortRange{ - First: uint16(start), - Last: uint16(last), - }) - - default: - return nil, ErrInvalidPortFormat - } - } - - return &ports, nil -} - -// expandOwnersFromTag will return a list of user. An owner can be either a user or a group -// a group cannot be composed of groups. -func expandOwnersFromTag( - pol *ACLPolicy, - tag string, -) ([]string, error) { - noTagErr := fmt.Errorf( - "%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", - ErrInvalidTag, - tag, - ) - if pol == nil { - return []string{}, noTagErr - } - var owners []string - ows, ok := pol.TagOwners[tag] - if !ok { - return []string{}, noTagErr - } - for _, owner := range ows { - if isGroup(owner) { - gs, err := pol.expandUsersFromGroup(owner) - if err != nil { - return []string{}, err - } - owners = append(owners, gs...) - } else { - owners = append(owners, owner) - } - } - - return owners, nil -} - -// expandUsersFromGroup will return the list of user inside the group -// after some validation. -func (pol *ACLPolicy) expandUsersFromGroup( - group string, -) ([]string, error) { - var users []string - log.Trace().Caller().Interface("pol", pol).Msg("test") - aclGroups, ok := pol.Groups[group] - if !ok { - return []string{}, fmt.Errorf( - "group %v isn't registered. %w", - group, - ErrInvalidGroup, - ) - } - for _, group := range aclGroups { - if isGroup(group) { - return []string{}, fmt.Errorf( - "%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", - ErrInvalidGroup, - ) - } - users = append(users, group) - } - - return users, nil -} - -func (pol *ACLPolicy) expandIPsFromGroup( - group string, - users []types.User, - nodes types.Nodes, -) (*netipx.IPSet, error) { - var build netipx.IPSetBuilder - - userTokens, err := pol.expandUsersFromGroup(group) - if err != nil { - return &netipx.IPSet{}, err - } - for _, user := range userTokens { - filteredNodes := filterNodesByUser(nodes, users, user) - for _, node := range filteredNodes { - node.AppendToIPSet(&build) - } - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromTag( - alias string, - users []types.User, - nodes types.Nodes, -) (*netipx.IPSet, error) { - var build netipx.IPSetBuilder - - // check for forced tags - for _, node := range nodes { - if slices.Contains(node.ForcedTags, alias) { - node.AppendToIPSet(&build) - } - } - - // find tag owners - owners, err := expandOwnersFromTag(pol, alias) - if err != nil { - if errors.Is(err, ErrInvalidTag) { - ipSet, _ := build.IPSet() - if len(ipSet.Prefixes()) == 0 { - return ipSet, fmt.Errorf( - "%w. %v isn't owned by a TagOwner and no forced tags are defined", - ErrInvalidTag, - alias, - ) - } - - return build.IPSet() - } else { - return nil, err - } - } - - // filter out nodes per tag owner - for _, user := range owners { - nodes := filterNodesByUser(nodes, users, user) - for _, node := range nodes { - if node.Hostinfo == nil { - continue - } - - if slices.Contains(node.Hostinfo.RequestTags, alias) { - node.AppendToIPSet(&build) - } - } - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromUser( - user string, - users []types.User, - nodes types.Nodes, -) (*netipx.IPSet, error) { - var build netipx.IPSetBuilder - - filteredNodes := filterNodesByUser(nodes, users, user) - filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) - - // shortcurcuit if we have no nodes to get ips from. - if len(filteredNodes) == 0 { - return nil, nil // nolint - } - - for _, node := range filteredNodes { - node.AppendToIPSet(&build) - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromSingleIP( - ip netip.Addr, - nodes types.Nodes, -) (*netipx.IPSet, error) { - log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip") - - matches := nodes.FilterByIP(ip) - - var build netipx.IPSetBuilder - build.Add(ip) - - for _, node := range matches { - node.AppendToIPSet(&build) - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromIPPrefix( - prefix netip.Prefix, - nodes types.Nodes, -) (*netipx.IPSet, error) { - log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") - var build netipx.IPSetBuilder - build.AddPrefix(prefix) - - // This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6 - // addresses for the hosts that belong to tailscale. This doesn't really affect stuff like subnet routers. - for _, node := range nodes { - for _, ip := range node.IPs() { - // log.Trace(). - // Msgf("checking if node ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String()) - if prefix.Contains(ip) { - node.AppendToIPSet(&build) - } - } - } - - return build.IPSet() -} - -func expandAutoGroup(alias string) (*netipx.IPSet, error) { - switch { - case strings.HasPrefix(alias, "autogroup:internet"): - return util.TheInternet(), nil - - default: - return nil, fmt.Errorf("unknown autogroup %q", alias) - } -} - -func isWildcard(str string) bool { - return str == "*" -} - -func isGroup(str string) bool { - return strings.HasPrefix(str, "group:") -} - -func isTag(str string) bool { - return strings.HasPrefix(str, "tag:") -} - -func isAutoGroup(str string) bool { - return strings.HasPrefix(str, "autogroup:") -} - -// TagsOfNode will return the tags of the current node. -// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. -// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. -func (pol *ACLPolicy) TagsOfNode( - users []types.User, - node *types.Node, -) ([]string, []string) { - var validTags []string - var invalidTags []string - - // TODO(kradalby): Why is this sometimes nil? coming from tailNode? - if node == nil { - return validTags, invalidTags - } - - validTagMap := make(map[string]bool) - invalidTagMap := make(map[string]bool) - if node.Hostinfo != nil { - for _, tag := range node.Hostinfo.RequestTags { - owners, err := expandOwnersFromTag(pol, tag) - if errors.Is(err, ErrInvalidTag) { - invalidTagMap[tag] = true - - continue - } - var found bool - for _, owner := range owners { - user, err := findUserFromToken(users, owner) - if err != nil { - log.Trace().Caller().Err(err).Msg("could not determine user to filter tags by") - } - - if node.User.ID == user.ID { - found = true - } - } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true - } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) - } - } - - return validTags, invalidTags -} - -// filterNodesByUser returns a list of nodes that match the given userToken from a -// policy. -// Matching nodes are determined by first matching the user token to a user by checking: -// - If it is an ID that mactches the user database ID -// - It is the Provider Identifier from OIDC -// - It matches the username or email of a user -// -// If the token matches more than one user, zero nodes will returned. -func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { - var out types.Nodes - - user, err := findUserFromToken(users, userToken) - if err != nil { - log.Trace().Caller().Err(err).Msg("could not determine user to filter nodes by") - return out - } - - for _, node := range nodes { - if node.User.ID == user.ID { - out = append(out, node) - } - } - - return out -} - -var ( - ErrorNoUserMatching = errors.New("no user matching") - ErrorMultipleUserMatching = errors.New("multiple users matching") -) - -// findUserFromToken finds and returns a user based on the given token, prioritizing matches by ProviderIdentifier, followed by email or name. -// If no matching user is found, it returns an error of type ErrorNoUserMatching. -// If multiple users match the token, it returns an error indicating multiple matches. -func findUserFromToken(users []types.User, token string) (types.User, error) { - var potentialUsers []types.User - - // This adds the v2 support to looking up users with the new required - // policyv2 format where usernames have @ at the end if they are not emails. - token = strings.TrimSuffix(token, "@") - - for _, user := range users { - if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == token { - // Prioritize ProviderIdentifier match and exit early - return user, nil - } - - if user.Email == token || user.Name == token { - potentialUsers = append(potentialUsers, user) - } - } - - if len(potentialUsers) == 0 { - return types.User{}, fmt.Errorf("user with token %q not found: %w", token, ErrorNoUserMatching) - } - - if len(potentialUsers) > 1 { - return types.User{}, fmt.Errorf("multiple users with token %q found: %w", token, ErrorNoUserMatching) - } - - return potentialUsers[0], nil -} diff --git a/hscontrol/policy/v1/acls_test.go b/hscontrol/policy/v1/acls_test.go deleted file mode 100644 index f2871064..00000000 --- a/hscontrol/policy/v1/acls_test.go +++ /dev/null @@ -1,2797 +0,0 @@ -package v1 - -import ( - "database/sql" - "errors" - "math/rand/v2" - "net/netip" - "slices" - "sort" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" - "github.com/stretchr/testify/require" - "go4.org/netipx" - "gopkg.in/check.v1" - "gorm.io/gorm" - "tailscale.com/tailcfg" -) - -var iap = func(ipStr string) *netip.Addr { - ip := netip.MustParseAddr(ipStr) - return &ip -} - -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&Suite{}) - -type Suite struct{} - -func (s *Suite) TestWrongPath(c *check.C) { - _, err := LoadACLPolicyFromPath("asdfg") - c.Assert(err, check.NotNil) -} - -func TestParsing(t *testing.T) { - tests := []struct { - name string - format string - acl string - want []tailcfg.FilterRule - wantErr bool - }{ - { - name: "invalid-hujson", - format: "hujson", - acl: ` -{ - `, - want: []tailcfg.FilterRule{}, - wantErr: true, - }, - { - name: "valid-hujson-invalid-content", - format: "hujson", - acl: ` -{ - "valid_json": true, - "but_a_policy_though": false -} - `, - want: []tailcfg.FilterRule{}, - wantErr: true, - }, - { - name: "invalid-cidr", - format: "hujson", - acl: ` -{"example-host-1": "100.100.100.100/42"} - `, - want: []tailcfg.FilterRule{}, - wantErr: true, - }, - { - name: "basic-rule", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - "192.168.1.0/24" - ], - "dst": [ - "*:22,3389", - "host-1:*", - ], - }, - ], -} - `, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, - {IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, - {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "parse-protocol", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "tcp", - "dst": [ - "host-1:*", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "udp", - "dst": [ - "host-1:53", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "icmp", - "dst": [ - "host-1:*", - ], - }, - ], -}`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - IPProto: []int{protocolTCP}, - }, - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, - }, - IPProto: []int{protocolUDP}, - }, - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - IPProto: []int{protocolICMP, protocolIPv6ICMP}, - }, - }, - wantErr: false, - }, - { - name: "port-wildcard", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "port-range", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - ], - "dst": [ - "host-1:5400-5500", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.100.101.0/24"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.100.100.100/32", - Ports: tailcfg.PortRange{First: 5400, Last: 5500}, - }, - }, - }, - }, - wantErr: false, - }, - { - name: "port-group", - format: "hujson", - acl: ` -{ - "groups": { - "group:example": [ - "testuser", - ], - }, - - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"200.200.200.200/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "port-user", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"200.200.200.200/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "ipv6", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100/32", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "*", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pol, err := LoadACLPolicyFromBytes([]byte(tt.acl)) - - if tt.wantErr && err == nil { - t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) - - return - } else if !tt.wantErr && err != nil { - t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if err != nil { - return - } - - user := types.User{ - Model: gorm.Model{ID: 1}, - Name: "testuser", - } - rules, err := pol.CompileFilterRules( - []types.User{ - user, - }, - types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), - }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: user, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) - - if (err != nil) != tt.wantErr { - t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if diff := cmp.Diff(tt.want, rules); diff != "" { - t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func (s *Suite) TestRuleInvalidGeneration(c *check.C) { - acl := []byte(` -{ - // Declare static groups of users beyond those in the identity service. - "groups": { - "group:example": [ - "user1@example.com", - "user2@example.com", - ], - }, - // Declare hostname aliases to use in place of IP addresses or subnets. - "hosts": { - "example-host-1": "100.100.100.100", - "example-host-2": "100.100.101.100/24", - }, - // Define who is allowed to use which tags. - "tagOwners": { - // Everyone in the montreal-admins or global-admins group are - // allowed to tag servers as montreal-webserver. - "tag:montreal-webserver": [ - "group:montreal-admins", - "group:global-admins", - ], - // Only a few admins are allowed to create API servers. - "tag:api-server": [ - "group:global-admins", - "example-host-1", - ], - }, - // Access control lists. - "acls": [ - // Engineering users, plus the president, can access port 22 (ssh) - // and port 3389 (remote desktop protocol) on all servers, and all - // ports on git-server or ci-server. - { - "action": "accept", - "src": [ - "group:engineering", - "president@example.com" - ], - "dst": [ - "*:22,3389", - "git-server:*", - "ci-server:*" - ], - }, - // Allow engineer users to access any port on a device tagged with - // tag:production. - { - "action": "accept", - "src": [ - "group:engineers" - ], - "dst": [ - "tag:production:*" - ], - }, - // Allow servers in the my-subnet host and 192.168.1.0/24 to access hosts - // on both networks. - { - "action": "accept", - "src": [ - "my-subnet", - "192.168.1.0/24" - ], - "dst": [ - "my-subnet:*", - "192.168.1.0/24:*" - ], - }, - // Allow every user of your network to access anything on the network. - // Comment out this section if you want to define specific ACL - // restrictions above. - { - "action": "accept", - "src": [ - "*" - ], - "dst": [ - "*:*" - ], - }, - // All users in Montreal are allowed to access the Montreal web - // servers. - { - "action": "accept", - "src": [ - "group:montreal-users" - ], - "dst": [ - "tag:montreal-webserver:80,443" - ], - }, - // Montreal web servers are allowed to make outgoing connections to - // the API servers, but only on https port 443. - // In contrast, this doesn't grant API servers the right to initiate - // any connections. - { - "action": "accept", - "src": [ - "tag:montreal-webserver" - ], - "dst": [ - "tag:api-server:443" - ], - }, - ], - // Declare tests to check functionality of ACL rules - "tests": [ - { - "src": "user1@example.com", - "accept": [ - "example-host-1:22", - "example-host-2:80" - ], - "deny": [ - "example-host-2:100" - ], - }, - { - "src": "user2@example.com", - "accept": [ - "100.60.3.4:22" - ], - }, - ], -} - `) - pol, err := LoadACLPolicyFromBytes(acl) - c.Assert(pol.ACLs, check.HasLen, 6) - c.Assert(err, check.IsNil) - - rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) - c.Assert(err, check.NotNil) - c.Assert(rules, check.IsNil) -} - -// TODO(kradalby): Make tests values safe, independent and descriptive. -func (s *Suite) TestInvalidAction(c *check.C) { - pol := &ACLPolicy{ - ACLs: []ACL{ - { - Action: "invalidAction", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - } - _, _, err := GenerateFilterAndSSHRulesForTests( - pol, - &types.Node{}, - types.Nodes{}, - []types.User{}, - ) - c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) -} - -func (s *Suite) TestInvalidGroupInGroup(c *check.C) { - // this ACL is wrong because the group in Sources sections doesn't exist - pol := &ACLPolicy{ - Groups: Groups{ - "group:test": []string{"foo"}, - "group:error": []string{"foo", "group:test"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:error"}, - Destinations: []string{"*:*"}, - }, - }, - } - _, _, err := GenerateFilterAndSSHRulesForTests( - pol, - &types.Node{}, - types.Nodes{}, - []types.User{}, - ) - c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) -} - -func (s *Suite) TestInvalidTagOwners(c *check.C) { - // this ACL is wrong because no tagOwners own the requested tag for the server - pol := &ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"tag:foo"}, - Destinations: []string{"*:*"}, - }, - }, - } - - _, _, err := GenerateFilterAndSSHRulesForTests( - pol, - &types.Node{}, - types.Nodes{}, - []types.User{}, - ) - c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) -} - -func Test_expandGroup(t *testing.T) { - type field struct { - pol ACLPolicy - } - type args struct { - group string - stripEmail bool - } - tests := []struct { - name string - field field - args args - want []string - wantErr bool - }{ - { - name: "simple test", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1", "user2", "user3"}, - "group:foo": []string{"user2", "user3"}, - }, - }, - }, - args: args{ - group: "group:test", - }, - want: []string{"user1", "user2", "user3"}, - wantErr: false, - }, - { - name: "InexistentGroup", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1", "user2", "user3"}, - "group:foo": []string{"user2", "user3"}, - }, - }, - }, - args: args{ - group: "group:undefined", - }, - want: []string{}, - wantErr: true, - }, - { - name: "Expand emails in group", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:admin": []string{ - "joe.bar@gmail.com", - "john.doe@yahoo.fr", - }, - }, - }, - }, - args: args{ - group: "group:admin", - }, - want: []string{"joe.bar@gmail.com", "john.doe@yahoo.fr"}, - wantErr: false, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := test.field.pol.expandUsersFromGroup( - test.args.group, - ) - - if (err != nil) != test.wantErr { - t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr) - - return - } - - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func Test_expandTagOwners(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - tag string - } - tests := []struct { - name string - args args - want []string - wantErr bool - }{ - { - name: "simple tag expansion", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:test": []string{"user1"}}, - }, - tag: "tag:test", - }, - want: []string{"user1"}, - wantErr: false, - }, - { - name: "expand with tag and group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{"group:foo": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, - }, - tag: "tag:test", - }, - want: []string{"user1", "user2"}, - wantErr: false, - }, - { - name: "expand with user and group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{"group:foo": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}}, - }, - tag: "tag:test", - }, - want: []string{"user1", "user2", "user3"}, - wantErr: false, - }, - { - name: "invalid tag", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}}, - }, - tag: "tag:test", - }, - want: []string{}, - wantErr: true, - }, - { - name: "invalid group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{"group:bar": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}}, - }, - tag: "tag:test", - }, - want: []string{}, - wantErr: true, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := expandOwnersFromTag( - test.args.aclPolicy, - test.args.tag, - ) - if (err != nil) != test.wantErr { - t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr) - - return - } - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandTagOwners() = (-want +got):\n%s", diff) - } - }) - } -} - -func Test_expandPorts(t *testing.T) { - type args struct { - portsStr string - needsWildcard bool - } - tests := []struct { - name string - args args - want *[]tailcfg.PortRange - wantErr bool - }{ - { - name: "wildcard", - args: args{portsStr: "*", needsWildcard: true}, - want: &[]tailcfg.PortRange{ - {First: portRangeBegin, Last: portRangeEnd}, - }, - wantErr: false, - }, - { - name: "needs wildcard but does not require it", - args: args{portsStr: "*", needsWildcard: false}, - want: &[]tailcfg.PortRange{ - {First: portRangeBegin, Last: portRangeEnd}, - }, - wantErr: false, - }, - { - name: "needs wildcard but gets port", - args: args{portsStr: "80,443", needsWildcard: true}, - want: nil, - wantErr: true, - }, - { - name: "two Destinations", - args: args{portsStr: "80,443", needsWildcard: false}, - want: &[]tailcfg.PortRange{ - {First: 80, Last: 80}, - {First: 443, Last: 443}, - }, - wantErr: false, - }, - { - name: "a range and a port", - args: args{portsStr: "80-1024,443", needsWildcard: false}, - want: &[]tailcfg.PortRange{ - {First: 80, Last: 1024}, - {First: 443, Last: 443}, - }, - wantErr: false, - }, - { - name: "out of bounds", - args: args{portsStr: "854038", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port", - args: args{portsStr: "85a38", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port in first", - args: args{portsStr: "a-80", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port in last", - args: args{portsStr: "80-85a38", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port format", - args: args{portsStr: "80-85a38-3", needsWildcard: false}, - want: nil, - wantErr: true, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := expandPorts(test.args.portsStr, test.args.needsWildcard) - if (err != nil) != test.wantErr { - t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr) - - return - } - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandPorts() = (-want +got):\n%s", diff) - } - }) - } -} - -func Test_filterNodesByUser(t *testing.T) { - users := []types.User{ - {Model: gorm.Model{ID: 1}, Name: "marc"}, - {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, - { - Model: gorm.Model{ID: 3}, - Name: "mikael", - Email: "mikael@headscale.net", - ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true}, - }, - {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, - {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, - {Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"}, - {Model: gorm.Model{ID: 7}, Name: "1"}, - {Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"}, - {Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"}, - {Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"}, - } - - type args struct { - nodes types.Nodes - user string - } - tests := []struct { - name string - args args - want types.Nodes - }{ - { - name: "1 node in user", - args: args{ - nodes: types.Nodes{ - &types.Node{User: users[1]}, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{User: users[1]}, - }, - }, - { - name: "3 nodes, 2 in user", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[1]}, - &types.Node{ID: 2, User: users[0]}, - &types.Node{ID: 3, User: users[0]}, - }, - user: "marc", - }, - want: types.Nodes{ - &types.Node{ID: 2, User: users[0]}, - &types.Node{ID: 3, User: users[0]}, - }, - }, - { - name: "5 nodes, 0 in user", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[1]}, - &types.Node{ID: 2, User: users[0]}, - &types.Node{ID: 3, User: users[0]}, - &types.Node{ID: 4, User: users[0]}, - &types.Node{ID: 5, User: users[0]}, - }, - user: "mickael", - }, - want: nil, - }, - { - name: "match-by-provider-ident", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[1]}, - &types.Node{ID: 2, User: users[2]}, - }, - user: "http://oidc.org/1234", - }, - want: types.Nodes{ - &types.Node{ID: 2, User: users[2]}, - }, - }, - { - name: "match-by-email", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[1]}, - &types.Node{ID: 2, User: users[2]}, - &types.Node{ID: 8, User: users[7]}, - }, - user: "joe@headscale.net", - }, - want: types.Nodes{ - &types.Node{ID: 1, User: users[1]}, - }, - }, - { - name: "multi-match-is-zero", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[1]}, - &types.Node{ID: 2, User: users[2]}, - &types.Node{ID: 3, User: users[3]}, - }, - user: "mikael@headscale.net", - }, - want: nil, - }, - { - name: "multi-email-first-match-is-zero", - args: args{ - nodes: types.Nodes{ - // First match email, then provider id - &types.Node{ID: 3, User: users[3]}, - &types.Node{ID: 2, User: users[2]}, - }, - user: "mikael@headscale.net", - }, - want: nil, - }, - { - name: "multi-username-first-match-is-zero", - args: args{ - nodes: types.Nodes{ - // First match username, then provider id - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 2, User: users[2]}, - }, - user: "mikael", - }, - want: nil, - }, - { - name: "all-users-duplicate-username-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - }, - user: "mikael", - }, - want: nil, - }, - { - name: "all-users-unique-username-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - }, - user: "marc", - }, - want: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - }, - }, - { - name: "all-users-no-username-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - }, - user: "not-working", - }, - want: nil, - }, - { - name: "all-users-duplicate-email-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - }, - user: "mikael@headscale.net", - }, - want: nil, - }, - { - name: "all-users-duplicate-email-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - &types.Node{ID: 8, User: users[7]}, - }, - user: "joe@headscale.net", - }, - want: types.Nodes{ - &types.Node{ID: 2, User: users[1]}, - }, - }, - { - name: "email-as-username-duplicate", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[7]}, - &types.Node{ID: 2, User: users[8]}, - }, - user: "alex@headscale.net", - }, - want: nil, - }, - { - name: "all-users-no-email-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - }, - user: "not-working@headscale.net", - }, - want: nil, - }, - { - name: "all-users-provider-id-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - &types.Node{ID: 6, User: users[5]}, - }, - user: "http://oidc.org/1234", - }, - want: types.Nodes{ - &types.Node{ID: 3, User: users[2]}, - }, - }, - { - name: "all-users-no-provider-id-random-order", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: users[0]}, - &types.Node{ID: 2, User: users[1]}, - &types.Node{ID: 3, User: users[2]}, - &types.Node{ID: 4, User: users[3]}, - &types.Node{ID: 5, User: users[4]}, - &types.Node{ID: 6, User: users[5]}, - }, - user: "http://oidc.org/4321", - }, - want: nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for range 1000 { - ns := test.args.nodes - rand.Shuffle(len(ns), func(i, j int) { - ns[i], ns[j] = ns[j], ns[i] - }) - us := users - rand.Shuffle(len(us), func(i, j int) { - us[i], us[j] = us[j], us[i] - }) - got := filterNodesByUser(ns, us, test.args.user) - sort.Slice(got, func(i, j int) bool { - return got[i].ID < got[j].ID - }) - - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) - } - } - }) - } -} - -func Test_expandAlias(t *testing.T) { - set := func(ips []string, prefixes []string) *netipx.IPSet { - var builder netipx.IPSetBuilder - - for _, ip := range ips { - builder.Add(netip.MustParseAddr(ip)) - } - - for _, pre := range prefixes { - builder.AddPrefix(netip.MustParsePrefix(pre)) - } - - s, _ := builder.IPSet() - - return s - } - - users := []types.User{ - {Model: gorm.Model{ID: 1}, Name: "joe"}, - {Model: gorm.Model{ID: 2}, Name: "marc"}, - {Model: gorm.Model{ID: 3}, Name: "mickael"}, - } - - type field struct { - pol ACLPolicy - } - type args struct { - nodes types.Nodes - aclPolicy ACLPolicy - alias string - } - tests := []struct { - name string - field field - args args - want *netipx.IPSet - wantErr bool - }{ - { - name: "wildcard", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "*", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - }, - &types.Node{ - IPv4: iap("100.78.84.227"), - }, - }, - }, - want: set([]string{}, []string{ - "0.0.0.0/0", - "::/0", - }), - wantErr: false, - }, - { - name: "simple group", - field: field{ - pol: ACLPolicy{ - Groups: Groups{"group:accountant": []string{"joe", "marc"}}, - }, - }, - args: args{ - alias: "group:accountant", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: users[0], - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: users[0], - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: users[1], - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: users[2], - }, - }, - }, - want: set([]string{ - "100.64.0.1", "100.64.0.2", "100.64.0.3", - }, []string{}), - wantErr: false, - }, - { - name: "wrong group", - field: field{ - pol: ACLPolicy{ - Groups: Groups{"group:accountant": []string{"joe", "marc"}}, - }, - }, - args: args{ - alias: "group:hr", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: users[0], - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: users[0], - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: users[1], - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: users[2], - }, - }, - }, - want: set([]string{}, []string{}), - wantErr: true, - }, - { - name: "simple ipaddress", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.3", - nodes: types.Nodes{}, - }, - want: set([]string{ - "10.0.0.3", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ip passed through", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.1", - nodes: types.Nodes{}, - }, - want: set([]string{ - "10.0.0.1", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ipv4 single ipv4", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.1", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("10.0.0.1"), - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "10.0.0.1", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ipv4 single dual stack", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.1", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("10.0.0.1"), - IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ipv6 single dual stack", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("10.0.0.1"), - IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by hostname alias", - field: field{ - pol: ACLPolicy{ - Hosts: Hosts{ - "testy": netip.MustParsePrefix("10.0.0.132/32"), - }, - }, - }, - args: args{ - alias: "testy", - nodes: types.Nodes{}, - }, - want: set([]string{}, []string{"10.0.0.132/32"}), - wantErr: false, - }, - { - name: "private network", - field: field{ - pol: ACLPolicy{ - Hosts: Hosts{ - "homeNetwork": netip.MustParsePrefix("192.168.1.0/24"), - }, - }, - }, - args: args{ - alias: "homeNetwork", - nodes: types.Nodes{}, - }, - want: set([]string{}, []string{"192.168.1.0/24"}), - wantErr: false, - }, - { - name: "simple CIDR", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.0/16", - nodes: types.Nodes{}, - aclPolicy: ACLPolicy{}, - }, - want: set([]string{}, []string{"10.0.0.0/16"}), - wantErr: false, - }, - { - name: "simple tag", - field: field{ - pol: ACLPolicy{ - TagOwners: TagOwners{"tag:hr-webserver": []string{"joe"}}, - }, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: users[1], - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: users[0], - }, - }, - }, - want: set([]string{ - "100.64.0.1", "100.64.0.2", - }, []string{}), - wantErr: false, - }, - { - name: "No tag defined", - field: field{ - pol: ACLPolicy{ - Groups: Groups{"group:accountant": []string{"joe", "marc"}}, - TagOwners: TagOwners{ - "tag:accountant-webserver": []string{"group:accountant"}, - }, - }, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{}, []string{}), - wantErr: true, - }, - { - name: "Forced tag defined", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: users[0], - ForcedTags: []string{"tag:hr-webserver"}, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: users[0], - ForcedTags: []string{"tag:hr-webserver"}, - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: users[1], - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: users[2], - }, - }, - }, - want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), - wantErr: false, - }, - { - name: "Forced tag with legitimate tagOwner", - field: field{ - pol: ACLPolicy{ - TagOwners: TagOwners{ - "tag:hr-webserver": []string{"joe"}, - }, - }, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: users[0], - ForcedTags: []string{"tag:hr-webserver"}, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: users[1], - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: users[2], - }, - }, - }, - want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), - wantErr: false, - }, - { - name: "list host in user without correctly tagged servers", - field: field{ - pol: ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - }, - args: args{ - alias: "joe", - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.3"), - User: users[1], - Hostinfo: &tailcfg.Hostinfo{}, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: users[0], - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - want: set([]string{"100.64.0.4"}, []string{}), - wantErr: false, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := test.field.pol.ExpandAlias( - test.args.nodes, - users, - test.args.alias, - ) - if (err != nil) != test.wantErr { - t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr) - - return - } - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandAlias() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func Test_excludeCorrectlyTaggedNodes(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - nodes types.Nodes - user string - } - tests := []struct { - name string - args args - want types.Nodes - wantErr bool - }{ - { - name: "exclude nodes with valid tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - { - name: "exclude nodes with valid tags, and owner is in a group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{ - "group:accountant": []string{"joe", "bar"}, - }, - TagOwners: TagOwners{ - "tag:accountant-webserver": []string{"group:accountant"}, - }, - }, - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - { - name: "exclude nodes with valid tags and with forced tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - ForcedTags: []string{"tag:accountant-webserver"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - { - name: "all nodes have invalid tags, don't exclude them", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web1", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web2", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web1", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web2", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got := excludeCorrectlyTaggedNodes( - test.args.aclPolicy, - test.args.nodes, - test.args.user, - ) - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff) - } - }) - } -} - -func TestACLPolicy_generateFilterRules(t *testing.T) { - type field struct { - pol ACLPolicy - } - type args struct { - nodes types.Nodes - } - tests := []struct { - name string - field field - args args - want []tailcfg.FilterRule - wantErr bool - }{ - { - name: "no-policy", - field: field{}, - args: args{}, - want: nil, - wantErr: false, - }, - { - name: "allow-all", - field: field{ - pol: ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - }, - }, - args: args{ - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - }, - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "0.0.0.0/0", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - IP: "::/0", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - wantErr: false, - }, - { - name: "host1-can-reach-host2-full", - field: field{ - pol: ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"100.64.0.2"}, - Destinations: []string{"100.64.0.1:*"}, - }, - }, - }, - }, - args: args{ - nodes: types.Nodes{ - &types.Node{ - IPv4: iap("100.64.0.1"), - IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, - }, - &types.Node{ - IPv4: iap("100.64.0.2"), - IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, - }, - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{ - "100.64.0.2/32", - "fd7a:115c:a1e0:ab12:4843:2222:6273:2222/128", - }, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.1/32", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - IP: "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.field.pol.CompileFilterRules( - []types.User{}, - tt.args.nodes, - ) - if (err != nil) != tt.wantErr { - t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if diff := cmp.Diff(tt.want, got); diff != "" { - log.Trace().Interface("got", got).Msg("result") - t.Errorf("ACLgenerateFilterRules() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -// tsExitNodeDest is the list of destination IP ranges that are allowed when -// you dump the filter list from a Tailscale node connected to Tailscale SaaS. -var tsExitNodeDest = []tailcfg.NetPortRange{ - { - IP: "0.0.0.0-9.255.255.255", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "11.0.0.0-100.63.255.255", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "100.128.0.0-169.253.255.255", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "169.255.0.0-172.15.255.255", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "172.32.0.0-192.167.255.255", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "192.169.0.0-255.255.255.255", - Ports: tailcfg.PortRangeAny, - }, - { - IP: "2000::-3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - Ports: tailcfg.PortRangeAny, - }, -} - -func Test_getTags(t *testing.T) { - users := []types.User{ - { - Model: gorm.Model{ID: 1}, - Name: "joe", - }, - } - type args struct { - aclPolicy *ACLPolicy - node *types.Node - } - tests := []struct { - name string - args args - wantInvalid []string - wantValid []string - }{ - { - name: "valid tag one nodes", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:valid"}, - }, - }, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: nil, - }, - { - name: "invalid tag and valid tag one nodes", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:valid", "tag:invalid"}, - }, - }, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "multiple invalid and identical tags, should return only one invalid tag", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{ - "tag:invalid", - "tag:valid", - "tag:invalid", - }, - }, - }, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "only invalid tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - }, - wantValid: nil, - wantInvalid: []string{"tag:invalid", "very-invalid"}, - }, - { - name: "empty ACLPolicy should return empty tags and should not panic", - args: args{ - aclPolicy: &ACLPolicy{}, - node: &types.Node{ - User: users[0], - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - }, - wantValid: nil, - wantInvalid: []string{"tag:invalid", "very-invalid"}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - gotValid, gotInvalid := test.args.aclPolicy.TagsOfNode( - users, - test.args.node, - ) - for _, valid := range gotValid { - if !slices.Contains(test.wantValid, valid) { - t.Errorf( - "valids: getTags() = %v, want %v", - gotValid, - test.wantValid, - ) - - break - } - } - for _, invalid := range gotInvalid { - if !slices.Contains(test.wantInvalid, invalid) { - t.Errorf( - "invalids: getTags() = %v, want %v", - gotInvalid, - test.wantInvalid, - ) - - break - } - } - }) - } -} - -func TestParseDestination(t *testing.T) { - tests := []struct { - dest string - wantAlias string - wantPort string - }{ - { - dest: "git-server:*", - wantAlias: "git-server", - wantPort: "*", - }, - { - dest: "192.168.1.0/24:22", - wantAlias: "192.168.1.0/24", - wantPort: "22", - }, - { - dest: "192.168.1.1:22", - wantAlias: "192.168.1.1", - wantPort: "22", - }, - { - dest: "fd7a:115c:a1e0::2:22", - wantAlias: "fd7a:115c:a1e0::2", - wantPort: "22", - }, - { - dest: "fd7a:115c:a1e0::2/128:22", - wantAlias: "fd7a:115c:a1e0::2/128", - wantPort: "22", - }, - { - dest: "tag:montreal-webserver:80,443", - wantAlias: "tag:montreal-webserver", - wantPort: "80,443", - }, - { - dest: "tag:api-server:443", - wantAlias: "tag:api-server", - wantPort: "443", - }, - { - dest: "example-host-1:*", - wantAlias: "example-host-1", - wantPort: "*", - }, - } - - for _, tt := range tests { - t.Run(tt.dest, func(t *testing.T) { - alias, port, _ := parseDestination(tt.dest) - - if alias != tt.wantAlias { - t.Errorf("unexpected alias: want(%s) != got(%s)", tt.wantAlias, alias) - } - - if port != tt.wantPort { - t.Errorf("unexpected port: want(%s) != got(%s)", tt.wantPort, port) - } - }) - } -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Sources section. -func TestValidExpandTagOwnersInSources(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testnodes", - RequestTags: []string{"tag:test"}, - } - - user := types.User{ - Model: gorm.Model{ID: 1}, - Name: "user1", - } - - node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: user, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - pol := &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"tag:test"}, - Destinations: []string{"*:*"}, - }, - }, - } - - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) - require.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}}, - {IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff) - } -} - -// need a test with: -// tag on a host that isn't owned by a tag owners. So the user -// of the host should be valid. -func TestInvalidTagValidUser(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testnodes", - RequestTags: []string{"tag:foo"}, - } - - node := &types.Node{ - ID: 1, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Model: gorm.Model{ID: 1}, - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - pol := &ACLPolicy{ - TagOwners: TagOwners{"tag:test": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"*:*"}, - }, - }, - } - - got, _, err := GenerateFilterAndSSHRulesForTests( - pol, - node, - types.Nodes{}, - []types.User{node.User}, - ) - require.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}}, - {IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff) - } -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Destinations section. -func TestValidExpandTagOwnersInDestinations(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testnodes", - RequestTags: []string{"tag:test"}, - } - - node := &types.Node{ - ID: 1, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Model: gorm.Model{ID: 1}, - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - pol := &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"tag:test:*"}, - }, - }, - } - - // rules, _, err := GenerateFilterRules(pol, &node, peers, false) - // c.Assert(err, check.IsNil) - // - // c.Assert(rules, check.HasLen, 1) - // c.Assert(rules[0].DstPorts, check.HasLen, 1) - // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - - got, _, err := GenerateFilterAndSSHRulesForTests( - pol, - node, - types.Nodes{}, - []types.User{node.User}, - ) - require.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf( - "TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", - diff, - ) - } -} - -// tag on a host is owned by a tag owner, the tag is valid. -// an ACL rule is matching the tag to a user. It should not be valid since the -// host should be tied to the tag now. -func TestValidTagInvalidUser(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "webserver", - RequestTags: []string{"tag:webapp"}, - } - user := types.User{ - Model: gorm.Model{ID: 1}, - Name: "user1", - } - - node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: user, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - hostInfo2 := tailcfg.Hostinfo{ - OS: "debian", - Hostname: "Hostname", - } - - nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: user, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo2, - } - - pol := &ACLPolicy{ - TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"tag:webapp:80,443"}, - }, - }, - } - - got, _, err := GenerateFilterAndSSHRulesForTests( - pol, - node, - types.Nodes{nodes2}, - []types.User{user}, - ) - require.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.2/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, - {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff) - } -} - -func TestFindUserByToken(t *testing.T) { - tests := []struct { - name string - users []types.User - token string - want types.User - wantErr bool - }{ - { - name: "exact match by ProviderIdentifier", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "token1"}}, - {Email: "user2@example.com"}, - }, - token: "token1", - want: types.User{ProviderIdentifier: sql.NullString{Valid: true, String: "token1"}}, - wantErr: false, - }, - { - name: "no matches found", - users: []types.User{ - {Email: "user1@example.com"}, - {Name: "username"}, - }, - token: "nonexistent-token", - want: types.User{}, - wantErr: true, - }, - { - name: "multiple matches by email and name", - users: []types.User{ - {Email: "token2", Name: "notoken"}, - {Name: "token2", Email: "notoken@example.com"}, - }, - token: "token2", - want: types.User{}, - wantErr: true, - }, - { - name: "match by email", - users: []types.User{ - {Email: "token3@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "othertoken"}}, - }, - token: "token3@example.com", - want: types.User{Email: "token3@example.com"}, - wantErr: false, - }, - { - name: "match by name", - users: []types.User{ - {Name: "token4"}, - {Email: "user5@example.com"}, - }, - token: "token4", - want: types.User{Name: "token4"}, - wantErr: false, - }, - { - name: "provider identifier takes precedence over email and name matches", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "token5"}}, - {Email: "token5@example.com", Name: "token5"}, - }, - token: "token5", - want: types.User{ProviderIdentifier: sql.NullString{Valid: true, String: "token5"}}, - wantErr: false, - }, - { - name: "empty token finds no users", - users: []types.User{ - {Email: "user6@example.com"}, - {Name: "username6"}, - }, - token: "", - want: types.User{}, - wantErr: true, - }, - // Test case 1: Duplicate Emails with Unique ProviderIdentifiers - { - name: "duplicate emails with unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid1"}, Email: "user@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid2"}, Email: "user@example.com"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - - // Test case 2: Duplicate Names with Unique ProviderIdentifiers - { - name: "duplicate names with unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid3"}, Name: "John Doe"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid4"}, Name: "John Doe"}, - }, - token: "John Doe", - want: types.User{}, - wantErr: true, - }, - - // Test case 3: Duplicate Emails and Names with Unique ProviderIdentifiers - { - name: "duplicate emails and names with unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid5"}, Email: "user@example.com", Name: "John Doe"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid6"}, Email: "user@example.com", Name: "John Doe"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - - // Test case 4: Unique Names without ProviderIdentifiers - { - name: "unique names without provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "johndoe@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "Jane Smith", Email: "janesmith@example.com"}, - }, - token: "John Doe", - want: types.User{ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "johndoe@example.com"}, - wantErr: false, - }, - - // Test case 5: Duplicate Emails without ProviderIdentifiers but Unique Names - { - name: "duplicate emails without provider identifiers but unique names", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "Jane Smith", Email: "user@example.com"}, - }, - token: "John Doe", - want: types.User{ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - wantErr: false, - }, - - // Test case 6: Duplicate Names and Emails without ProviderIdentifiers - { - name: "duplicate names and emails without provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - }, - token: "John Doe", - want: types.User{}, - wantErr: true, - }, - - // Test case 7: Multiple Users with the Same Email but Different Names and Unique ProviderIdentifiers - { - name: "multiple users with same email, different names, unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid7"}, Email: "user@example.com", Name: "John Doe"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid8"}, Email: "user@example.com", Name: "Jane Smith"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - - // Test case 8: Multiple Users with the Same Name but Different Emails and Unique ProviderIdentifiers - { - name: "multiple users with same name, different emails, unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid9"}, Email: "johndoe@example.com", Name: "John Doe"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid10"}, Email: "janedoe@example.com", Name: "John Doe"}, - }, - token: "John Doe", - want: types.User{}, - wantErr: true, - }, - - // Test case 9: Multiple Users with Same Email and Name but Unique ProviderIdentifiers - { - name: "multiple users with same email and name, unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid11"}, Email: "user@example.com", Name: "John Doe"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid12"}, Email: "user@example.com", Name: "John Doe"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - - // Test case 10: Multiple Users without ProviderIdentifiers but with Unique Names and Emails - { - name: "multiple users without provider identifiers, unique names and emails", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "johndoe@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "Jane Smith", Email: "janesmith@example.com"}, - }, - token: "John Doe", - want: types.User{ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "johndoe@example.com"}, - wantErr: false, - }, - - // Test case 11: Multiple Users without ProviderIdentifiers and Duplicate Emails but Unique Names - { - name: "multiple users without provider identifiers, duplicate emails but unique names", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "Jane Smith", Email: "user@example.com"}, - }, - token: "John Doe", - want: types.User{ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - wantErr: false, - }, - - // Test case 12: Multiple Users without ProviderIdentifiers and Duplicate Names but Unique Emails - { - name: "multiple users without provider identifiers, duplicate names but unique emails", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "johndoe@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "janedoe@example.com"}, - }, - token: "John Doe", - want: types.User{}, - wantErr: true, - }, - - // Test case 13: Multiple Users without ProviderIdentifiers and Duplicate Both Names and Emails - { - name: "multiple users without provider identifiers, duplicate names and emails", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - }, - token: "John Doe", - want: types.User{}, - wantErr: true, - }, - - // Test case 14: Multiple Users with Same Email Without ProviderIdentifiers - { - name: "multiple users with same email without provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "user@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "Jane Smith", Email: "user@example.com"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - - // Test case 15: Multiple Users with Same Name Without ProviderIdentifiers - { - name: "multiple users with same name without provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "johndoe@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "John Doe", Email: "janedoe@example.com"}, - }, - token: "John Doe", - want: types.User{}, - wantErr: true, - }, - { - name: "Name field used as email address match", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid3"}, Name: "user@example.com", Email: "another@example.com"}, - }, - token: "user@example.com", - want: types.User{ProviderIdentifier: sql.NullString{Valid: true, String: "pid3"}, Name: "user@example.com", Email: "another@example.com"}, - wantErr: false, - }, - { - name: "multiple users with same name as email and unique provider identifiers", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid4"}, Name: "user@example.com", Email: "user1@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: true, String: "pid5"}, Name: "user@example.com", Email: "user2@example.com"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - { - name: "no provider identifier and duplicate names as emails", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user@example.com", Email: "another1@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user@example.com", Email: "another2@example.com"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - { - name: "name as email with multiple matches when provider identifier is not set", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user@example.com", Email: "another1@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user@example.com", Email: "another2@example.com"}, - }, - token: "user@example.com", - want: types.User{}, - wantErr: true, - }, - { - name: "test-v2-format-working", - users: []types.User{ - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user1", Email: "another1@example.com"}, - {ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user2", Email: "another2@example.com"}, - }, - token: "user2", - want: types.User{ProviderIdentifier: sql.NullString{Valid: false, String: ""}, Name: "user2", Email: "another2@example.com"}, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotUser, err := findUserFromToken(tt.users, tt.token) - if (err != nil) != tt.wantErr { - t.Errorf("findUserFromToken() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := cmp.Diff(tt.want, gotUser, util.Comparers...); diff != "" { - t.Errorf("findUserFromToken() unexpected result (-want +got):\n%s", diff) - } - }) - } -} diff --git a/hscontrol/policy/v1/acls_types.go b/hscontrol/policy/v1/acls_types.go deleted file mode 100644 index c7c59328..00000000 --- a/hscontrol/policy/v1/acls_types.go +++ /dev/null @@ -1,123 +0,0 @@ -package v1 - -import ( - "encoding/json" - "net/netip" - "strings" - - "github.com/tailscale/hujson" -) - -// ACLPolicy represents a Tailscale ACL Policy. -type ACLPolicy struct { - Groups Groups `json:"groups"` - Hosts Hosts `json:"hosts"` - TagOwners TagOwners `json:"tagOwners"` - ACLs []ACL `json:"acls"` - Tests []ACLTest `json:"tests"` - AutoApprovers AutoApprovers `json:"autoApprovers"` - SSHs []SSH `json:"ssh"` -} - -// ACL is a basic rule for the ACL Policy. -type ACL struct { - Action string `json:"action"` - Protocol string `json:"proto"` - Sources []string `json:"src"` - Destinations []string `json:"dst"` -} - -// Groups references a series of alias in the ACL rules. -type Groups map[string][]string - -// Hosts are alias for IP addresses or subnets. -type Hosts map[string]netip.Prefix - -// TagOwners specify what users (users?) are allow to use certain tags. -type TagOwners map[string][]string - -// ACLTest is not implemented, but should be used to check if a certain rule is allowed. -type ACLTest struct { - Source string `json:"src"` - Accept []string `json:"accept"` - Deny []string `json:"deny,omitempty"` -} - -// AutoApprovers specify which users, groups or tags have their advertised routes -// or exit node status automatically enabled. -type AutoApprovers struct { - Routes map[string][]string `json:"routes"` - ExitNode []string `json:"exitNode"` -} - -// SSH controls who can ssh into which machines. -type SSH struct { - Action string `json:"action"` - Sources []string `json:"src"` - Destinations []string `json:"dst"` - Users []string `json:"users"` - CheckPeriod string `json:"checkPeriod,omitempty"` -} - -// UnmarshalJSON allows to parse the Hosts directly into netip objects. -func (hosts *Hosts) UnmarshalJSON(data []byte) error { - newHosts := Hosts{} - hostIPPrefixMap := make(map[string]string) - ast, err := hujson.Parse(data) - if err != nil { - return err - } - ast.Standardize() - data = ast.Pack() - err = json.Unmarshal(data, &hostIPPrefixMap) - if err != nil { - return err - } - for host, prefixStr := range hostIPPrefixMap { - if !strings.Contains(prefixStr, "/") { - prefixStr += "/32" - } - prefix, err := netip.ParsePrefix(prefixStr) - if err != nil { - return err - } - newHosts[host] = prefix - } - *hosts = newHosts - - return nil -} - -// IsZero is perhaps a bit naive here. -func (pol ACLPolicy) IsZero() bool { - if len(pol.Groups) == 0 && len(pol.Hosts) == 0 && len(pol.ACLs) == 0 && len(pol.SSHs) == 0 { - return true - } - - return false -} - -// GetRouteApprovers returns the list of autoApproving users, groups or tags for a given IPPrefix. -func (autoApprovers *AutoApprovers) GetRouteApprovers( - prefix netip.Prefix, -) ([]string, error) { - if prefix.Bits() == 0 { - return autoApprovers.ExitNode, nil // 0.0.0.0/0, ::/0 or equivalent - } - - approverAliases := make([]string, 0) - - for autoApprovedPrefix, autoApproverAliases := range autoApprovers.Routes { - autoApprovedPrefix, err := netip.ParsePrefix(autoApprovedPrefix) - if err != nil { - return nil, err - } - - if prefix.Bits() >= autoApprovedPrefix.Bits() && - autoApprovedPrefix.Contains(prefix.Masked().Addr()) { - approverAliases = append(approverAliases, autoApproverAliases...) - } - } - - return approverAliases, nil -} diff --git a/hscontrol/policy/v1/policy.go b/hscontrol/policy/v1/policy.go deleted file mode 100644 index c2e9520a..00000000 --- a/hscontrol/policy/v1/policy.go +++ /dev/null @@ -1,188 +0,0 @@ -package v1 - -import ( - "fmt" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "io" - "net/netip" - "os" - "sync" - - "slices" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/rs/zerolog/log" - "tailscale.com/tailcfg" - "tailscale.com/util/deephash" -) - -func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (*PolicyManager, error) { - policyFile, err := os.Open(path) - if err != nil { - return nil, err - } - defer policyFile.Close() - - policyBytes, err := io.ReadAll(policyFile) - if err != nil { - return nil, err - } - - return NewPolicyManager(policyBytes, users, nodes) -} - -func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { - var pol *ACLPolicy - var err error - if polB != nil && len(polB) > 0 { - pol, err = LoadACLPolicyFromBytes(polB) - if err != nil { - return nil, fmt.Errorf("parsing policy: %w", err) - } - } - - pm := PolicyManager{ - pol: pol, - users: users, - nodes: nodes, - } - - _, err = pm.updateLocked() - if err != nil { - return nil, err - } - - return &pm, nil -} - -type PolicyManager struct { - mu sync.Mutex - pol *ACLPolicy - polHash deephash.Sum - - users []types.User - nodes types.Nodes - - filter []tailcfg.FilterRule - filterHash deephash.Sum -} - -// updateLocked updates the filter rules based on the current policy and nodes. -// It must be called with the lock held. -func (pm *PolicyManager) updateLocked() (bool, error) { - filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) - if err != nil { - return false, fmt.Errorf("compiling filter rules: %w", err) - } - - polHash := deephash.Hash(pm.pol) - filterHash := deephash.Hash(&filter) - - if polHash == pm.polHash && filterHash == pm.filterHash { - return false, nil - } - - pm.filter = filter - pm.filterHash = filterHash - pm.polHash = polHash - - return true, nil -} - -func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { - pm.mu.Lock() - defer pm.mu.Unlock() - return pm.filter, matcher.MatchesFromFilterRules(pm.filter) -} - -func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { - pm.mu.Lock() - defer pm.mu.Unlock() - - return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) -} - -func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { - if len(polB) == 0 { - return false, nil - } - - pol, err := LoadACLPolicyFromBytes(polB) - if err != nil { - return false, fmt.Errorf("parsing policy: %w", err) - } - - pm.mu.Lock() - defer pm.mu.Unlock() - - pm.pol = pol - - return pm.updateLocked() -} - -// SetUsers updates the users in the policy manager and updates the filter rules. -func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { - pm.mu.Lock() - defer pm.mu.Unlock() - - pm.users = users - return pm.updateLocked() -} - -// SetNodes updates the nodes in the policy manager and updates the filter rules. -func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { - pm.mu.Lock() - defer pm.mu.Unlock() - pm.nodes = nodes - return pm.updateLocked() -} - -func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool { - if pm == nil || pm.pol == nil { - return false - } - - pm.mu.Lock() - defer pm.mu.Unlock() - - tags, invalid := pm.pol.TagsOfNode(pm.users, node) - log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy") - - return slices.Contains(tags, tag) -} - -func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool { - if pm == nil || pm.pol == nil { - return false - } - - pm.mu.Lock() - defer pm.mu.Unlock() - - approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) - - for _, approvedAlias := range approvers { - if approvedAlias == node.User.Username() { - return true - } else { - ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias) - if err != nil { - return false - } - - // approvedIPs should contain all of node's IPs if it matches the rule, so check for first - if ips != nil && ips.Contains(*node.IPv4) { - return true - } - } - } - return false -} - -func (pm *PolicyManager) Version() int { - return 1 -} - -func (pm *PolicyManager) DebugString() string { - return "not implemented for v1" -} diff --git a/hscontrol/policy/v1/policy_test.go b/hscontrol/policy/v1/policy_test.go deleted file mode 100644 index c9f98079..00000000 --- a/hscontrol/policy/v1/policy_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package v1 - -import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/gorm" - "tailscale.com/tailcfg" -) - -func TestPolicySetChange(t *testing.T) { - users := []types.User{ - { - Model: gorm.Model{ID: 1}, - Name: "testuser", - }, - } - tests := []struct { - name string - users []types.User - nodes types.Nodes - policy []byte - wantUsersChange bool - wantNodesChange bool - wantPolicyChange bool - wantFilter []tailcfg.FilterRule - wantMatchers []matcher.Match - }{ - { - name: "set-nodes", - nodes: types.Nodes{ - { - IPv4: iap("100.64.0.2"), - User: users[0], - }, - }, - wantNodesChange: false, - wantFilter: []tailcfg.FilterRule{ - { - DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, - }, - }, - wantMatchers: []matcher.Match{ - matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}), - }, - }, - { - name: "set-users", - users: users, - wantUsersChange: false, - wantFilter: []tailcfg.FilterRule{ - { - DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, - }, - }, - wantMatchers: []matcher.Match{ - matcher.MatchFromStrings([]string{}, []string{"100.64.0.1/32"}), - }, - }, - { - name: "set-users-and-node", - users: users, - nodes: types.Nodes{ - { - IPv4: iap("100.64.0.2"), - User: users[0], - }, - }, - wantUsersChange: false, - wantNodesChange: true, - wantFilter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.2/32"}, - DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, - }, - }, - wantMatchers: []matcher.Match{ - matcher.MatchFromStrings([]string{"100.64.0.2/32"}, []string{"100.64.0.1/32"}), - }, - }, - { - name: "set-policy", - policy: []byte(` -{ -"acls": [ - { - "action": "accept", - "src": [ - "100.64.0.61", - ], - "dst": [ - "100.64.0.62:*", - ], - }, - ], -} - `), - wantPolicyChange: true, - wantFilter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.61/32"}, - DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, - }, - }, - wantMatchers: []matcher.Match{ - matcher.MatchFromStrings([]string{"100.64.0.61/32"}, []string{"100.64.0.62/32"}), - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pol := ` -{ - "groups": { - "group:example": [ - "testuser", - ], - }, - - "hosts": { - "host-1": "100.64.0.1", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -` - pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) - require.NoError(t, err) - - if tt.policy != nil { - change, err := pm.SetPolicy(tt.policy) - require.NoError(t, err) - - assert.Equal(t, tt.wantPolicyChange, change) - } - - if tt.users != nil { - change, err := pm.SetUsers(tt.users) - require.NoError(t, err) - - assert.Equal(t, tt.wantUsersChange, change) - } - - if tt.nodes != nil { - change, err := pm.SetNodes(tt.nodes) - require.NoError(t, err) - - assert.Equal(t, tt.wantNodesChange, change) - } - - filter, matchers := pm.Filter() - if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { - t.Errorf("TestPolicySetChange() unexpected filter (-want +got):\n%s", diff) - } - if diff := cmp.Diff( - tt.wantMatchers, - matchers, - cmp.AllowUnexported(matcher.Match{}), - ); diff != "" { - t.Errorf("TestPolicySetChange() unexpected matchers (-want +got):\n%s", diff) - } - }) - } -} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 580a1980..941a645b 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -33,6 +33,60 @@ func (a Asterix) String() string { return "*" } +// MarshalJSON marshals the Asterix to JSON. +func (a Asterix) MarshalJSON() ([]byte, error) { + return []byte(`"*"`), nil +} + +// MarshalJSON marshals the AliasWithPorts to JSON. +func (a AliasWithPorts) MarshalJSON() ([]byte, error) { + if a.Alias == nil { + return []byte(`""`), nil + } + + var alias string + switch v := a.Alias.(type) { + case *Username: + alias = string(*v) + case *Group: + alias = string(*v) + case *Tag: + alias = string(*v) + case *Host: + alias = string(*v) + case *Prefix: + alias = v.String() + case *AutoGroup: + alias = string(*v) + case Asterix: + alias = "*" + default: + return nil, fmt.Errorf("unknown alias type: %T", v) + } + + // If no ports are specified + if len(a.Ports) == 0 { + return json.Marshal(alias) + } + + // Check if it's the wildcard port range + if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 { + return json.Marshal(fmt.Sprintf("%s:*", alias)) + } + + // Otherwise, format as "alias:ports" + var ports []string + for _, port := range a.Ports { + if port.First == port.Last { + ports = append(ports, fmt.Sprintf("%d", port.First)) + } else { + ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last)) + } + } + + return json.Marshal(fmt.Sprintf("%s:%s", alias, strings.Join(ports, ","))) +} + func (a Asterix) UnmarshalJSON(b []byte) error { return nil } @@ -63,6 +117,16 @@ func (u *Username) String() string { return string(*u) } +// MarshalJSON marshals the Username to JSON. +func (u Username) MarshalJSON() ([]byte, error) { + return json.Marshal(string(u)) +} + +// MarshalJSON marshals the Prefix to JSON. +func (p Prefix) MarshalJSON() ([]byte, error) { + return json.Marshal(p.String()) +} + func (u *Username) UnmarshalJSON(b []byte) error { *u = Username(strings.Trim(string(b), `"`)) if err := u.Validate(); err != nil { @@ -163,10 +227,25 @@ func (g Group) CanBeAutoApprover() bool { return true } +// String returns the string representation of the Group. func (g Group) String() string { return string(g) } +func (h Host) String() string { + return string(h) +} + +// MarshalJSON marshals the Host to JSON. +func (h Host) MarshalJSON() ([]byte, error) { + return json.Marshal(string(h)) +} + +// MarshalJSON marshals the Group to JSON. +func (g Group) MarshalJSON() ([]byte, error) { + return json.Marshal(string(g)) +} + func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -244,6 +323,11 @@ func (t Tag) String() string { return string(t) } +// MarshalJSON marshals the Tag to JSON. +func (t Tag) MarshalJSON() ([]byte, error) { + return json.Marshal(string(t)) +} + // Host is a string that represents a hostname. type Host string @@ -279,7 +363,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSe // If the IP is a single host, look for a node to ensure we add all the IPs of // the node to the IPSet. - // appendIfNodeHasIP(nodes, &ips, pref) + appendIfNodeHasIP(nodes, &ips, netip.Prefix(pref)) // TODO(kradalby): I am a bit unsure what is the correct way to do this, // should a host with a non single IP be able to resolve the full host (inc all IPs). @@ -355,30 +439,25 @@ func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IP ips.AddPrefix(netip.Prefix(p)) // If the IP is a single host, look for a node to ensure we add all the IPs of // the node to the IPSet. - // appendIfNodeHasIP(nodes, &ips, pref) - - // TODO(kradalby): I am a bit unsure what is the correct way to do this, - // should a host with a non single IP be able to resolve the full host (inc all IPs). - // Currently this is done because the old implementation did this, we might want to - // drop it before releasing. - // For example: - // If a src or dst includes "64.0.0.0/2:*", it will include 100.64/16 range, which - // means that it will need to fetch the IPv6 addrs of the node to include the full range. - // Clearly, if a user sets the dst to be "64.0.0.0/2:*", it is likely more of a exit node - // and this would be strange behaviour. - ipsTemp, err := ips.IPSet() - if err != nil { - errs = append(errs, err) - } - for _, node := range nodes { - if node.InIPSet(ipsTemp) { - node.AppendToIPSet(&ips) - } - } + appendIfNodeHasIP(nodes, &ips, netip.Prefix(p)) return buildIPSetMultiErr(&ips, errs) } +// appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the +// IP address in the prefix. +func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref netip.Prefix) { + if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) { + return + } + + for _, node := range nodes { + if node.HasIP(pref.Addr()) { + node.AppendToIPSet(ips) + } + } +} + // AutoGroup is a special string which is always prefixed with `autogroup:` type AutoGroup string @@ -415,6 +494,11 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { return nil } +// MarshalJSON marshals the AutoGroup to JSON. +func (ag AutoGroup) MarshalJSON() ([]byte, error) { + return json.Marshal(string(ag)) +} + func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -644,6 +728,37 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { return nil } +// MarshalJSON marshals the Aliases to JSON. +func (a Aliases) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte("[]"), nil + } + + aliases := make([]string, len(a)) + for i, alias := range a { + switch v := alias.(type) { + case *Username: + aliases[i] = string(*v) + case *Group: + aliases[i] = string(*v) + case *Tag: + aliases[i] = string(*v) + case *Host: + aliases[i] = string(*v) + case *Prefix: + aliases[i] = v.String() + case *AutoGroup: + aliases[i] = string(*v) + case Asterix: + aliases[i] = "*" + default: + return nil, fmt.Errorf("unknown alias type: %T", v) + } + } + + return json.Marshal(aliases) +} + func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -702,6 +817,29 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { return nil } +// MarshalJSON marshals the AutoApprovers to JSON. +func (aa AutoApprovers) MarshalJSON() ([]byte, error) { + if aa == nil { + return []byte("[]"), nil + } + + approvers := make([]string, len(aa)) + for i, approver := range aa { + switch v := approver.(type) { + case *Username: + approvers[i] = string(*v) + case *Tag: + approvers[i] = string(*v) + case *Group: + approvers[i] = string(*v) + default: + return nil, fmt.Errorf("unknown auto approver type: %T", v) + } + } + + return json.Marshal(approvers) +} + func parseAutoApprover(s string) (AutoApprover, error) { switch { case isUser(s): @@ -771,6 +909,27 @@ func (o *Owners) UnmarshalJSON(b []byte) error { return nil } +// MarshalJSON marshals the Owners to JSON. +func (o Owners) MarshalJSON() ([]byte, error) { + if o == nil { + return []byte("[]"), nil + } + + owners := make([]string, len(o)) + for i, owner := range o { + switch v := owner.(type) { + case *Username: + owners[i] = string(*v) + case *Group: + owners[i] = string(*v) + default: + return nil, fmt.Errorf("unknown owner type: %T", v) + } + } + + return json.Marshal(owners) +} + func parseOwner(s string) (Owner, error) { switch { case isUser(s): @@ -857,22 +1016,64 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { return err } - var pref Prefix - err := pref.parseString(value) - if err != nil { - return fmt.Errorf("Hostname %q contains an invalid IP address: %q", key, value) + var prefix Prefix + if err := prefix.parseString(value); err != nil { + return fmt.Errorf(`Hostname "%s" contains an invalid IP address: "%s"`, key, value) } - (*h)[host] = pref + (*h)[host] = prefix } + return nil } +// MarshalJSON marshals the Hosts to JSON. +func (h Hosts) MarshalJSON() ([]byte, error) { + if h == nil { + return []byte("{}"), nil + } + + rawHosts := make(map[string]string) + for host, prefix := range h { + rawHosts[string(host)] = prefix.String() + } + + return json.Marshal(rawHosts) +} + func (h Hosts) exist(name Host) bool { _, ok := h[name] return ok } +// MarshalJSON marshals the TagOwners to JSON. +func (to TagOwners) MarshalJSON() ([]byte, error) { + if to == nil { + return []byte("{}"), nil + } + + rawTagOwners := make(map[string][]string) + for tag, owners := range to { + tagStr := string(tag) + ownerStrs := make([]string, len(owners)) + + for i, owner := range owners { + switch v := owner.(type) { + case *Username: + ownerStrs[i] = string(*v) + case *Group: + ownerStrs[i] = string(*v) + default: + return nil, fmt.Errorf("unknown owner type: %T", v) + } + } + + rawTagOwners[tagStr] = ownerStrs + } + + return json.Marshal(rawTagOwners) +} + // TagOwners are a map of Tag to a list of the UserEntities that own the tag. type TagOwners map[Tag]Owners @@ -926,8 +1127,32 @@ func resolveTagOwners(p *Policy, users types.Users, nodes types.Nodes) (map[Tag] } type AutoApproverPolicy struct { - Routes map[netip.Prefix]AutoApprovers `json:"routes"` - ExitNode AutoApprovers `json:"exitNode"` + Routes map[netip.Prefix]AutoApprovers `json:"routes,omitempty"` + ExitNode AutoApprovers `json:"exitNode,omitempty"` +} + +// MarshalJSON marshals the AutoApproverPolicy to JSON. +func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) { + // Marshal empty policies as empty object + if ap.Routes == nil && ap.ExitNode == nil { + return []byte("{}"), nil + } + + type Alias AutoApproverPolicy + + // Create a new object to avoid marshalling nil slices as null instead of empty arrays + obj := Alias(ap) + + // Initialize empty maps/slices to ensure they're marshalled as empty objects/arrays instead of null + if obj.Routes == nil { + obj.Routes = make(map[netip.Prefix]AutoApprovers) + } + + if obj.ExitNode == nil { + obj.ExitNode = AutoApprovers{} + } + + return json.Marshal(&obj) } // resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet. @@ -1011,14 +1236,17 @@ type Policy struct { // callers using it should panic if not validated bool `json:"-"` - Groups Groups `json:"groups"` - Hosts Hosts `json:"hosts"` - TagOwners TagOwners `json:"tagOwners"` - ACLs []ACL `json:"acls"` - AutoApprovers AutoApproverPolicy `json:"autoApprovers"` - SSHs []SSH `json:"ssh"` + Groups Groups `json:"groups,omitempty"` + Hosts Hosts `json:"hosts,omitempty"` + TagOwners TagOwners `json:"tagOwners,omitempty"` + ACLs []ACL `json:"acls,omitempty"` + AutoApprovers AutoApproverPolicy `json:"autoApprovers,omitempty"` + SSHs []SSH `json:"ssh,omitempty"` } +// MarshalJSON is deliberately not implemented for Policy. +// We use the default JSON marshalling behavior provided by the Go runtime. + var ( // TODO(kradalby): Add these checks for tagOwners and autoApprovers autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged} @@ -1320,6 +1548,24 @@ type SSH struct { // It can be a list of usernames, groups, tags or autogroups. type SSHSrcAliases []Alias +// MarshalJSON marshals the Groups to JSON. +func (g Groups) MarshalJSON() ([]byte, error) { + if g == nil { + return []byte("{}"), nil + } + + raw := make(map[string][]string) + for group, usernames := range g { + users := make([]string, len(usernames)) + for i, username := range usernames { + users[i] = string(username) + } + raw[string(group)] = users + } + + return json.Marshal(raw) +} + func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc err := json.Unmarshal(b, &aliases) @@ -1333,12 +1579,98 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { case *Username, *Group, *Tag, *AutoGroup: (*a)[i] = alias.Alias default: - return fmt.Errorf("type %T not supported", alias.Alias) + return fmt.Errorf( + "alias %T is not supported for SSH source", + alias.Alias, + ) } } return nil } +func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Tag, *AutoGroup, *Host, + // Asterix and Group is actually not supposed to be supported, + // however we do not support autogroups at the moment + // so we will leave it in as there is no other option + // to dynamically give all access + // https://tailscale.com/kb/1193/tailscale-ssh#dst + // TODO(kradalby): remove this when we support autogroup:tagged and autogroup:member + Asterix: + (*a)[i] = alias.Alias + default: + return fmt.Errorf( + "alias %T is not supported for SSH destination", + alias.Alias, + ) + } + } + return nil +} + +// MarshalJSON marshals the SSHDstAliases to JSON. +func (a SSHDstAliases) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte("[]"), nil + } + + aliases := make([]string, len(a)) + for i, alias := range a { + switch v := alias.(type) { + case *Username: + aliases[i] = string(*v) + case *Tag: + aliases[i] = string(*v) + case *AutoGroup: + aliases[i] = string(*v) + case *Host: + aliases[i] = string(*v) + case Asterix: + aliases[i] = "*" + default: + return nil, fmt.Errorf("unknown SSH destination alias type: %T", v) + } + } + + return json.Marshal(aliases) +} + +// MarshalJSON marshals the SSHSrcAliases to JSON. +func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte("[]"), nil + } + + aliases := make([]string, len(a)) + for i, alias := range a { + switch v := alias.(type) { + case *Username: + aliases[i] = string(*v) + case *Group: + aliases[i] = string(*v) + case *Tag: + aliases[i] = string(*v) + case *AutoGroup: + aliases[i] = string(*v) + case Asterix: + aliases[i] = "*" + default: + return nil, fmt.Errorf("unknown SSH source alias type: %T", v) + } + } + + return json.Marshal(aliases) +} + func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder var errs []error @@ -1359,38 +1691,17 @@ func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) // It can be a list of usernames, tags or autogroups. type SSHDstAliases []Alias -func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { - var aliases []AliasEnc - err := json.Unmarshal(b, &aliases) - if err != nil { - return err - } - - *a = make([]Alias, len(aliases)) - for i, alias := range aliases { - switch alias.Alias.(type) { - case *Username, *Tag, *AutoGroup, - // Asterix and Group is actually not supposed to be supported, - // however we do not support autogroups at the moment - // so we will leave it in as there is no other option - // to dynamically give all access - // https://tailscale.com/kb/1193/tailscale-ssh#dst - // TODO(kradalby): remove this when we support autogroup:tagged and autogroup:member - Asterix: - (*a)[i] = alias.Alias - default: - return fmt.Errorf("type %T not supported", alias.Alias) - } - } - return nil -} - type SSHUser string func (u SSHUser) String() string { return string(u) } +// MarshalJSON marshals the SSHUser to JSON. +func (u SSHUser) MarshalJSON() ([]byte, error) { + return json.Marshal(string(u)) +} + // unmarshalPolicy takes a byte slice and unmarshals it into a Policy struct. // In addition to unmarshalling, it will also validate the policy. // This is the only entrypoint of reading a policy from a file or other source. diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 3e9de7d7..ac2fc3b1 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -10,6 +10,9 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/prometheus/common/model" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go4.org/netipx" xmaps "golang.org/x/exp/maps" @@ -19,6 +22,83 @@ import ( "tailscale.com/types/ptr" ) +// TestUnmarshalPolicy tests the unmarshalling of JSON into Policy objects and the marshalling +// back to JSON (round-trip). It ensures that: +// 1. JSON can be correctly unmarshalled into a Policy object +// 2. A Policy object can be correctly marshalled back to JSON +// 3. The unmarshalled Policy matches the expected Policy +// 4. The marshalled and then unmarshalled Policy is semantically equivalent to the original +// (accounting for nil vs empty map/slice differences) +// +// This test also verifies that all the required struct fields are properly marshalled and +// unmarshalled, maintaining semantic equivalence through a complete JSON round-trip. + +// TestMarshalJSON tests explicit marshalling of Policy objects to JSON. +// This test ensures our custom MarshalJSON methods properly encode +// the various data structures used in the Policy. +func TestMarshalJSON(t *testing.T) { + // Create a complex test policy + policy := &Policy{ + Groups: Groups{ + Group("group:example"): []Username{Username("user@example.com")}, + }, + Hosts: Hosts{ + "host-1": Prefix(mp("100.100.100.100/32")), + }, + TagOwners: TagOwners{ + Tag("tag:test"): Owners{up("user@example.com")}, + }, + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Username("user@example.com")), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Username("other@example.com")), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + }, + } + + // Marshal the policy to JSON + marshalled, err := json.MarshalIndent(policy, "", " ") + require.NoError(t, err) + + // Make sure all expected fields are present in the JSON + jsonString := string(marshalled) + assert.Contains(t, jsonString, "group:example") + assert.Contains(t, jsonString, "user@example.com") + assert.Contains(t, jsonString, "host-1") + assert.Contains(t, jsonString, "100.100.100.100/32") + assert.Contains(t, jsonString, "tag:test") + assert.Contains(t, jsonString, "accept") + assert.Contains(t, jsonString, "tcp") + assert.Contains(t, jsonString, "80") + + // Unmarshal back to verify round trip + var roundTripped Policy + err = json.Unmarshal(marshalled, &roundTripped) + require.NoError(t, err) + + // Compare the original and round-tripped policies + cmps := append(util.Comparers, + cmp.Comparer(func(x, y Prefix) bool { + return x == y + }), + cmpopts.IgnoreUnexported(Policy{}), + cmpopts.EquateEmpty(), + ) + + if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" { + t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) + } +} + func TestUnmarshalPolicy(t *testing.T) { tests := []struct { name string @@ -511,6 +591,138 @@ func TestUnmarshalPolicy(t *testing.T) { `, wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`, }, + { + name: "ssh-basic", + input: ` +{ + "groups": { + "group:admins": ["admin@example.com"] + }, + "tagOwners": { + "tag:servers": ["group:admins"] + }, + "ssh": [ + { + "action": "accept", + "src": [ + "group:admins" + ], + "dst": [ + "tag:servers" + ], + "users": ["root", "admin"] + } + ] +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("admin@example.com")}, + }, + TagOwners: TagOwners{ + Tag("tag:servers"): Owners{gp("group:admins")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{ + gp("group:admins"), + }, + Destinations: SSHDstAliases{ + tp("tag:servers"), + }, + Users: []SSHUser{ + SSHUser("root"), + SSHUser("admin"), + }, + }, + }, + }, + }, + { + name: "ssh-with-tag-and-user", + input: ` +{ + "tagOwners": { + "tag:web": ["admin@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": [ + "tag:web" + ], + "dst": [ + "admin@example.com" + ], + "users": ["*"] + } + ] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:web"): Owners{ptr.To(Username("admin@example.com"))}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{ + tp("tag:web"), + }, + Destinations: SSHDstAliases{ + ptr.To(Username("admin@example.com")), + }, + Users: []SSHUser{ + SSHUser("*"), + }, + }, + }, + }, + }, + { + name: "ssh-with-check-period", + input: ` +{ + "groups": { + "group:admins": ["admin@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": [ + "group:admins" + ], + "dst": [ + "admin@example.com" + ], + "users": ["root"], + "checkPeriod": "24h" + } + ] +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("admin@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{ + gp("group:admins"), + }, + Destinations: SSHDstAliases{ + ptr.To(Username("admin@example.com")), + }, + Users: []SSHUser{ + SSHUser("root"), + }, + CheckPeriod: model.Duration(24 * time.Hour), + }, + }, + }, + }, { name: "group-must-be-defined-acl-src", input: ` @@ -746,29 +958,61 @@ func TestUnmarshalPolicy(t *testing.T) { }, } - cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { - return x == y - })) - cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{})) + cmps := append(util.Comparers, + cmp.Comparer(func(x, y Prefix) bool { + return x == y + }), + cmpopts.IgnoreUnexported(Policy{}), + ) + + // For round-trip testing, we'll normalize the policies before comparing for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Test unmarshalling policy, err := unmarshalPolicy([]byte(tt.input)) if tt.wantErr == "" { if err != nil { - t.Fatalf("got %v; want no error", err) + t.Fatalf("unmarshalling: got %v; want no error", err) } } else { if err == nil { - t.Fatalf("got nil; want error %q", tt.wantErr) + t.Fatalf("unmarshalling: got nil; want error %q", tt.wantErr) } else if !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("got err %v; want error %q", err, tt.wantErr) + t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr) } + return // Skip the rest of the test if we expected an error } if diff := cmp.Diff(tt.want, policy, cmps...); diff != "" { t.Fatalf("unexpected policy (-want +got):\n%s", diff) } + + // Test round-trip marshalling/unmarshalling + if policy != nil { + // Marshal the policy back to JSON + marshalled, err := json.MarshalIndent(policy, "", " ") + if err != nil { + t.Fatalf("marshalling: %v", err) + } + + // Unmarshal it again + roundTripped, err := unmarshalPolicy(marshalled) + if err != nil { + t.Fatalf("round-trip unmarshalling: %v", err) + } + + // Add EquateEmpty to handle nil vs empty maps/slices + roundTripCmps := append(cmps, + cmpopts.EquateEmpty(), + cmpopts.IgnoreUnexported(Policy{}), + ) + + // Compare using the enhanced comparers for round-trip testing + if diff := cmp.Diff(policy, roundTripped, roundTripCmps...); diff != "" { + t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) + } + } }) } } diff --git a/integration/acl_test.go b/integration/acl_test.go index 116f298d..193b6669 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -7,50 +7,53 @@ import ( "testing" "github.com/google/go-cmp/cmp" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" + "github.com/google/go-cmp/cmp/cmpopts" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) -var veryLargeDestination = []string{ - "0.0.0.0/5:*", - "8.0.0.0/7:*", - "11.0.0.0/8:*", - "12.0.0.0/6:*", - "16.0.0.0/4:*", - "32.0.0.0/3:*", - "64.0.0.0/2:*", - "128.0.0.0/3:*", - "160.0.0.0/5:*", - "168.0.0.0/6:*", - "172.0.0.0/12:*", - "172.32.0.0/11:*", - "172.64.0.0/10:*", - "172.128.0.0/9:*", - "173.0.0.0/8:*", - "174.0.0.0/7:*", - "176.0.0.0/4:*", - "192.0.0.0/9:*", - "192.128.0.0/11:*", - "192.160.0.0/13:*", - "192.169.0.0/16:*", - "192.170.0.0/15:*", - "192.172.0.0/14:*", - "192.176.0.0/12:*", - "192.192.0.0/10:*", - "193.0.0.0/8:*", - "194.0.0.0/7:*", - "196.0.0.0/6:*", - "200.0.0.0/5:*", - "208.0.0.0/4:*", +var veryLargeDestination = []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("0.0.0.0/5"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("8.0.0.0/7"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("11.0.0.0/8"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("12.0.0.0/6"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("16.0.0.0/4"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("32.0.0.0/3"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("64.0.0.0/2"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("128.0.0.0/3"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("160.0.0.0/5"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("168.0.0.0/6"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.0.0.0/12"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.32.0.0/11"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.64.0.0/10"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.128.0.0/9"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("173.0.0.0/8"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("174.0.0.0/7"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("176.0.0.0/4"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.0.0.0/9"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.128.0.0/11"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.160.0.0/13"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.169.0.0/16"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.170.0.0/15"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.172.0.0/14"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.176.0.0/12"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.192.0.0/10"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("193.0.0.0/8"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("194.0.0.0/7"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("196.0.0.0/6"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("200.0.0.0/5"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("208.0.0.0/4"), tailcfg.PortRangeAny), } func aclScenario( t *testing.T, - policy *policyv1.ACLPolicy, + policy *policyv2.Policy, clientsPerUser int, ) *Scenario { t.Helper() @@ -108,19 +111,21 @@ func TestACLHostsInNetMapTable(t *testing.T) { // they can access minus one (them self). tests := map[string]struct { users ScenarioSpec - policy policyv1.ACLPolicy + policy policyv2.Policy want map[string]int }{ // Test that when we have no ACL, each client netmap has // the amount of peers of the total amount of clients "base-acls": { users: spec, - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ @@ -133,17 +138,21 @@ func TestACLHostsInNetMapTable(t *testing.T) { // their own user. "two-isolated-users": { users: spec, - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user1@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"user2@"}, - Destinations: []string{"user2@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ @@ -156,27 +165,35 @@ func TestACLHostsInNetMapTable(t *testing.T) { // in the netmap. "two-restricted-present-in-netmap": { users: spec, - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user1@:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, { - Action: "accept", - Sources: []string{"user2@"}, - Destinations: []string{"user2@:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user2@:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, { - Action: "accept", - Sources: []string{"user2@"}, - Destinations: []string{"user1@:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, }, }, want: map[string]int{ @@ -190,22 +207,28 @@ func TestACLHostsInNetMapTable(t *testing.T) { // need them present on the other side for the "return path". "two-ns-one-isolated": { users: spec, - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user1@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"user2@"}, - Destinations: []string{"user2@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user2@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ @@ -215,22 +238,37 @@ func TestACLHostsInNetMapTable(t *testing.T) { }, "very-large-destination-prefix-1372": { users: spec, - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: append([]string{"user1@:*"}, veryLargeDestination...), + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: append( + []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, + veryLargeDestination..., + ), }, { - Action: "accept", - Sources: []string{"user2@"}, - Destinations: append([]string{"user2@:*"}, veryLargeDestination...), + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: append( + []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + veryLargeDestination..., + ), }, { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: append([]string{"user2@:*"}, veryLargeDestination...), + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: append( + []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + veryLargeDestination..., + ), }, }, }, want: map[string]int{ @@ -240,12 +278,15 @@ func TestACLHostsInNetMapTable(t *testing.T) { }, "ipv6-acls-1470": { users: spec, - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"0.0.0.0/0:*", "::/0:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("0.0.0.0/0"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("::/0"), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ @@ -295,12 +336,14 @@ func TestACLAllowUser80Dst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user2@:80"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 80, Last: 80}), + }, }, }, }, @@ -349,15 +392,17 @@ func TestACLDenyAllPort80(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-acl-test": {"user1@", "user2@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-acl-test"): []policyv2.Username{policyv2.Username("user1@"), policyv2.Username("user2@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"group:integration-acl-test"}, - Destinations: []string{"*:22"}, + Action: "accept", + Sources: []policyv2.Alias{groupp("group:integration-acl-test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRange{First: 22, Last: 22}), + }, }, }, }, @@ -396,12 +441,14 @@ func TestACLAllowUserDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user2@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, }, @@ -452,12 +499,14 @@ func TestACLAllowStarDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, }, @@ -509,16 +558,18 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policyv1.ACLPolicy{ - Hosts: policyv1.Hosts{ - "all": netip.MustParsePrefix("100.64.0.0/24"), + &policyv2.Policy{ + Hosts: policyv2.Hosts{ + "all": policyv2.Prefix(netip.MustParsePrefix("100.64.0.0/24")), }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ // Everyone can curl test3 { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"all:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("all"), tailcfg.PortRangeAny), + }, }, }, }, @@ -606,50 +657,58 @@ func TestACLNamedHostsCanReach(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy policyv1.ACLPolicy + policy policyv2.Policy }{ "ipv4": { - policy: policyv1.ACLPolicy{ - Hosts: policyv1.Hosts{ - "test1": netip.MustParsePrefix("100.64.0.1/32"), - "test2": netip.MustParsePrefix("100.64.0.2/32"), - "test3": netip.MustParsePrefix("100.64.0.3/32"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("100.64.0.1/32")), + "test2": policyv2.Prefix(netip.MustParsePrefix("100.64.0.2/32")), + "test3": policyv2.Prefix(netip.MustParsePrefix("100.64.0.3/32")), }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ // Everyone can curl test3 { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"test3:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test3"), tailcfg.PortRangeAny), + }, }, // test1 can curl test2 { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, }, "ipv6": { - policy: policyv1.ACLPolicy{ - Hosts: policyv1.Hosts{ - "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), - "test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::1/128")), + "test2": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::2/128")), + "test3": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::3/128")), }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ // Everyone can curl test3 { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"test3:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test3"), tailcfg.PortRangeAny), + }, }, // test1 can curl test2 { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, @@ -855,71 +914,81 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy policyv1.ACLPolicy + policy policyv2.Policy }{ "ipv4": { - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"100.64.0.1"}, - Destinations: []string{"100.64.0.2:*"}, + Action: "accept", + Sources: []policyv2.Alias{prefixp("100.64.0.1/32")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("100.64.0.2/32"), tailcfg.PortRangeAny), + }, }, }, }, }, "ipv6": { - policy: policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"fd7a:115c:a1e0::1"}, - Destinations: []string{"fd7a:115c:a1e0::2:*"}, + Action: "accept", + Sources: []policyv2.Alias{prefixp("fd7a:115c:a1e0::1/128")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("fd7a:115c:a1e0::2/128"), tailcfg.PortRangeAny), + }, }, }, }, }, "hostv4cidr": { - policy: policyv1.ACLPolicy{ - Hosts: policyv1.Hosts{ - "test1": netip.MustParsePrefix("100.64.0.1/32"), - "test2": netip.MustParsePrefix("100.64.0.2/32"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("100.64.0.1/32")), + "test2": policyv2.Prefix(netip.MustParsePrefix("100.64.0.2/32")), }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, }, "hostv6cidr": { - policy: policyv1.ACLPolicy{ - Hosts: policyv1.Hosts{ - "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::1/128")), + "test2": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::2/128")), }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, }, "group": { - policy: policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:one": {"user1@"}, - "group:two": {"user2@"}, + policy: policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:one"): []policyv2.Username{policyv2.Username("user1@")}, + policyv2.Group("group:two"): []policyv2.Username{policyv2.Username("user2@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"group:one"}, - Destinations: []string{"group:two:*"}, + Action: "accept", + Sources: []policyv2.Alias{groupp("group:one")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:two"), tailcfg.PortRangeAny), + }, }, }, }, @@ -1073,15 +1142,17 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { headscale, err := scenario.Headscale() require.NoError(t, err) - p := policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + p := policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1@"}, - Destinations: []string{"user2@:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, - Hosts: policyv1.Hosts{}, + Hosts: policyv2.Hosts{}, } err = headscale.SetPolicy(&p) @@ -1089,7 +1160,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { // Get the current policy and check // if it is the same as the one we set. - var output *policyv1.ACLPolicy + var output *policyv2.Policy err = executeAndUnmarshal( headscale, []string{ @@ -1105,7 +1176,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { assert.Len(t, output.ACLs, 1) - if diff := cmp.Diff(p, *output); diff != "" { + if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" { t.Errorf("unexpected policy(-want +got):\n%s", diff) } @@ -1145,12 +1216,14 @@ func TestACLAutogroupMember(t *testing.T) { t.Parallel() scenario := aclScenario(t, - &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"autogroup:member"}, - Destinations: []string{"autogroup:member:*"}, + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupMember), tailcfg.PortRangeAny), + }, }, }, }, @@ -1201,15 +1274,18 @@ func TestACLAutogroupTagged(t *testing.T) { t.Parallel() scenario := aclScenario(t, - &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"autogroup:tagged"}, - Destinations: []string{"autogroup:tagged:*"}, + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupTagged)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupTagged), tailcfg.PortRangeAny), + }, }, }, }, + 2, ) defer scenario.ShutdownAssertNoPanics(t) diff --git a/integration/cli_test.go b/integration/cli_test.go index 435b7e55..2cff0500 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -12,12 +12,13 @@ import ( tcmp "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" "golang.org/x/exp/slices" ) @@ -912,13 +913,15 @@ func TestNodeTagCommand(t *testing.T) { ) } + + func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() tests := []struct { name string - policy *policyv1.ACLPolicy + policy *policyv2.Policy wantTag bool }{ { @@ -927,51 +930,60 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { }, { name: "with-policy-email", - policy: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:test": {"user1@test.no"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:test"): policyv2.Owners{usernameOwner("user1@test.no")}, }, }, wantTag: true, }, { name: "with-policy-username", - policy: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + policy: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:test": {"user1@"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:test"): policyv2.Owners{usernameOwner("user1@")}, }, }, wantTag: true, }, { name: "with-policy-groups", - policy: &policyv1.ACLPolicy{ - Groups: policyv1.Groups{ - "group:admins": []string{"user1@"}, + policy: &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:admins"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:test": {"group:admins"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:test"): policyv2.Owners{groupOwner("group:admins")}, }, }, wantTag: true, @@ -1746,16 +1758,19 @@ func TestPolicyCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - p := policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + p := policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:exists": {"user1@"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:exists"): policyv2.Owners{usernameOwner("user1@")}, }, } @@ -1782,7 +1797,7 @@ func TestPolicyCommand(t *testing.T) { // Get the current policy and check // if it is the same as the one we set. - var output *policyv1.ACLPolicy + var output *policyv2.Policy err = executeAndUnmarshal( headscale, []string{ @@ -1825,18 +1840,21 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - p := policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + p := policyv2.Policy{ + ACLs: []policyv2.ACL{ { // This is an unknown action, so it will return an error // and the config will not be applied. - Action: "unknown-action", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "unknown-action", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:exists": {"user1@"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:exists"): policyv2.Owners{usernameOwner("user1@")}, }, } diff --git a/integration/control.go b/integration/control.go index 22e7552b..df1d5d13 100644 --- a/integration/control.go +++ b/integration/control.go @@ -4,7 +4,7 @@ import ( "net/netip" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/ory/dockertest/v3" ) @@ -28,5 +28,5 @@ type ControlServer interface { ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) GetCert() []byte GetHostname() string - SetPolicy(*policyv1.ACLPolicy) error + SetPolicy(*policyv2.Policy) error } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index e6762cf0..35550c65 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -19,7 +19,7 @@ import ( "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" @@ -65,7 +65,7 @@ type HeadscaleInContainer struct { extraPorts []string caCerts [][]byte hostPortBindings map[string][]string - aclPolicy *policyv1.ACLPolicy + aclPolicy *policyv2.Policy env map[string]string tlsCert []byte tlsKey []byte @@ -80,7 +80,7 @@ type Option = func(c *HeadscaleInContainer) // WithACLPolicy adds a hscontrol.ACLPolicy policy to the // HeadscaleInContainer instance. -func WithACLPolicy(acl *policyv1.ACLPolicy) Option { +func WithACLPolicy(acl *policyv2.Policy) Option { return func(hsic *HeadscaleInContainer) { if acl == nil { return @@ -188,13 +188,6 @@ func WithPostgres() Option { } } -// WithPolicyV1 tells the integration test to use the old v1 filter. -func WithPolicyV1() Option { - return func(hsic *HeadscaleInContainer) { - hsic.env["HEADSCALE_POLICY_V1"] = "1" - } -} - // WithPolicy sets the policy mode for headscale func WithPolicyMode(mode types.PolicyMode) Option { return func(hsic *HeadscaleInContainer) { @@ -889,7 +882,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) { return userMap, nil } -func (h *HeadscaleInContainer) SetPolicy(pol *policyv1.ACLPolicy) error { +func (h *HeadscaleInContainer) SetPolicy(pol *policyv2.Policy) error { err := h.writePolicy(pol) if err != nil { return fmt.Errorf("writing policy file: %w", err) @@ -930,7 +923,7 @@ func (h *HeadscaleInContainer) reloadDatabasePolicy() error { return nil } -func (h *HeadscaleInContainer) writePolicy(pol *policyv1.ACLPolicy) error { +func (h *HeadscaleInContainer) writePolicy(pol *policyv2.Policy) error { pBytes, err := json.Marshal(pol) if err != nil { return fmt.Errorf("marshalling pol: %w", err) diff --git a/integration/route_test.go b/integration/route_test.go index 5a85f436..053b4582 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "sort" + "strings" "testing" "time" @@ -13,7 +14,7 @@ import ( cmpdiff "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" @@ -22,6 +23,7 @@ import ( "github.com/stretchr/testify/require" "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" "tailscale.com/types/ipproto" "tailscale.com/types/views" "tailscale.com/util/must" @@ -793,26 +795,25 @@ func TestSubnetRouteACL(t *testing.T) { err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithAcceptRoutes(), }, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( - &policyv1.ACLPolicy{ - Groups: policyv1.Groups{ - "group:admins": {user + "@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:admins"): []policyv2.Username{policyv2.Username(user + "@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"group:admins"}, - Destinations: []string{"group:admins:*"}, + Action: "accept", + Sources: []policyv2.Alias{groupp("group:admins")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:admins"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"group:admins"}, - Destinations: []string{"10.33.0.0/16:*"}, + Action: "accept", + Sources: []policyv2.Alias{groupp("group:admins")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("10.33.0.0/16"), tailcfg.PortRangeAny), + }, }, - // { - // Action: "accept", - // Sources: []string{"group:admins"}, - // Destinations: []string{"0.0.0.0/0:*"}, - // }, }, }, )) @@ -1384,29 +1385,31 @@ func TestAutoApproveMultiNetwork(t *testing.T) { tests := []struct { name string - pol *policyv1.ACLPolicy + pol *policyv2.Policy approver string spec ScenarioSpec withURL bool }{ { name: "authkey-tag", - pol: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:approve": {"user1@"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:approve"): policyv2.Owners{usernameOwner("user1@")}, }, - AutoApprovers: policyv1.AutoApprovers{ - Routes: map[string][]string{ - bigRoute.String(): {"tag:approve"}, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {tagApprover("tag:approve")}, }, - ExitNode: []string{"tag:approve"}, + ExitNode: policyv2.AutoApprovers{tagApprover("tag:approve")}, }, }, approver: "tag:approve", @@ -1427,19 +1430,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) { }, { name: "authkey-user", - pol: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - AutoApprovers: policyv1.AutoApprovers{ - Routes: map[string][]string{ - bigRoute.String(): {"user1@"}, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {usernameApprover("user1@")}, }, - ExitNode: []string{"user1@"}, + ExitNode: policyv2.AutoApprovers{usernameApprover("user1@")}, }, }, approver: "user1@", @@ -1460,22 +1465,24 @@ func TestAutoApproveMultiNetwork(t *testing.T) { }, { name: "authkey-group", - pol: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - Groups: policyv1.Groups{ - "group:approve": []string{"user1@"}, + Groups: policyv2.Groups{ + policyv2.Group("group:approve"): []policyv2.Username{policyv2.Username("user1@")}, }, - AutoApprovers: policyv1.AutoApprovers{ - Routes: map[string][]string{ - bigRoute.String(): {"group:approve"}, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {groupApprover("group:approve")}, }, - ExitNode: []string{"group:approve"}, + ExitNode: policyv2.AutoApprovers{groupApprover("group:approve")}, }, }, approver: "group:approve", @@ -1496,19 +1503,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) { }, { name: "webauth-user", - pol: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - AutoApprovers: policyv1.AutoApprovers{ - Routes: map[string][]string{ - bigRoute.String(): {"user1@"}, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {usernameApprover("user1@")}, }, - ExitNode: []string{"user1@"}, + ExitNode: policyv2.AutoApprovers{usernameApprover("user1@")}, }, }, approver: "user1@", @@ -1530,22 +1539,24 @@ func TestAutoApproveMultiNetwork(t *testing.T) { }, { name: "webauth-tag", - pol: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - TagOwners: map[string][]string{ - "tag:approve": {"user1@"}, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:approve"): policyv2.Owners{usernameOwner("user1@")}, }, - AutoApprovers: policyv1.AutoApprovers{ - Routes: map[string][]string{ - bigRoute.String(): {"tag:approve"}, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {tagApprover("tag:approve")}, }, - ExitNode: []string{"tag:approve"}, + ExitNode: policyv2.AutoApprovers{tagApprover("tag:approve")}, }, }, approver: "tag:approve", @@ -1567,22 +1578,24 @@ func TestAutoApproveMultiNetwork(t *testing.T) { }, { name: "webauth-group", - pol: &policyv1.ACLPolicy{ - ACLs: []policyv1.ACL{ + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - Groups: policyv1.Groups{ - "group:approve": []string{"user1@"}, + Groups: policyv2.Groups{ + policyv2.Group("group:approve"): []policyv2.Username{policyv2.Username("user1@")}, }, - AutoApprovers: policyv1.AutoApprovers{ - Routes: map[string][]string{ - bigRoute.String(): {"group:approve"}, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {groupApprover("group:approve")}, }, - ExitNode: []string{"group:approve"}, + ExitNode: policyv2.AutoApprovers{groupApprover("group:approve")}, }, }, approver: "group:approve", @@ -1657,7 +1670,20 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.NotNil(t, headscale) // Set the route of usernet1 to be autoapproved - tt.pol.AutoApprovers.Routes[route.String()] = []string{tt.approver} + var approvers policyv2.AutoApprovers + switch { + case strings.HasPrefix(tt.approver, "tag:"): + approvers = append(approvers, tagApprover(tt.approver)) + case strings.HasPrefix(tt.approver, "group:"): + approvers = append(approvers, groupApprover(tt.approver)) + default: + approvers = append(approvers, usernameApprover(tt.approver)) + } + if tt.pol.AutoApprovers.Routes == nil { + tt.pol.AutoApprovers.Routes = make(map[netip.Prefix]policyv2.AutoApprovers) + } + prefix := *route + tt.pol.AutoApprovers.Routes[prefix] = approvers err = headscale.SetPolicy(tt.pol) require.NoError(t, err) @@ -1767,7 +1793,8 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assertTracerouteViaIP(t, tr, routerUsernet1.MustIPv4()) // Remove the auto approval from the policy, any routes already enabled should be allowed. - delete(tt.pol.AutoApprovers.Routes, route.String()) + prefix = *route + delete(tt.pol.AutoApprovers.Routes, prefix) err = headscale.SetPolicy(tt.pol) require.NoError(t, err) @@ -1831,7 +1858,20 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the route back to the auto approver in the policy, the route should // now become available again. - tt.pol.AutoApprovers.Routes[route.String()] = []string{tt.approver} + var newApprovers policyv2.AutoApprovers + switch { + case strings.HasPrefix(tt.approver, "tag:"): + newApprovers = append(newApprovers, tagApprover(tt.approver)) + case strings.HasPrefix(tt.approver, "group:"): + newApprovers = append(newApprovers, groupApprover(tt.approver)) + default: + newApprovers = append(newApprovers, usernameApprover(tt.approver)) + } + if tt.pol.AutoApprovers.Routes == nil { + tt.pol.AutoApprovers.Routes = make(map[netip.Prefix]policyv2.AutoApprovers) + } + prefix = *route + tt.pol.AutoApprovers.Routes[prefix] = newApprovers err = headscale.SetPolicy(tt.pol) require.NoError(t, err) @@ -2070,7 +2110,9 @@ func TestSubnetRouteACLFiltering(t *testing.T) { "src": [ "node" ], - "dst": [] + "dst": [ + "*:*" + ] } ] }`) @@ -2090,8 +2132,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { weburl := fmt.Sprintf("http://%s/etc/hostname", webip) t.Logf("webservice: %s, %s", webip.String(), weburl) - // Create ACL policy - aclPolicy := &policyv1.ACLPolicy{} + aclPolicy := &policyv2.Policy{} err = json.Unmarshal([]byte(aclPolicyStr), aclPolicy) require.NoError(t, err) @@ -2121,24 +2162,23 @@ func TestSubnetRouteACLFiltering(t *testing.T) { routerClient := allClients[0] nodeClient := allClients[1] - aclPolicy.Hosts = policyv1.Hosts{ - routerUser: must.Get(routerClient.MustIPv4().Prefix(32)), - nodeUser: must.Get(nodeClient.MustIPv4().Prefix(32)), + aclPolicy.Hosts = policyv2.Hosts{ + policyv2.Host(routerUser): policyv2.Prefix(must.Get(routerClient.MustIPv4().Prefix(32))), + policyv2.Host(nodeUser): policyv2.Prefix(must.Get(nodeClient.MustIPv4().Prefix(32))), } - aclPolicy.ACLs[1].Destinations = []string{ - route.String() + ":*", + aclPolicy.ACLs[1].Destinations = []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp(route.String()), tailcfg.PortRangeAny), } - require.NoError(t, headscale.SetPolicy(aclPolicy)) // Set up the subnet routes for the router - routes := []string{ - route.String(), // This should be accessible by the client - "10.10.11.0/24", // These should NOT be accessible - "10.10.12.0/24", + routes := []netip.Prefix{ + *route, // This should be accessible by the client + netip.MustParsePrefix("10.10.11.0/24"), // These should NOT be accessible + netip.MustParsePrefix("10.10.12.0/24"), } - routeArg := "--advertise-routes=" + routes[0] + "," + routes[1] + "," + routes[2] + routeArg := "--advertise-routes=" + routes[0].String() + "," + routes[1].String() + "," + routes[2].String() command := []string{ "tailscale", "set", @@ -2208,5 +2248,4 @@ func TestSubnetRouteACLFiltering(t *testing.T) { tr, err := nodeClient.Traceroute(webip) require.NoError(t, err) assertTracerouteViaIP(t, tr, routerClient.MustIPv4()) - } diff --git a/integration/scenario.go b/integration/scenario.go index 7d4d62d1..507c248d 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -47,7 +47,6 @@ const ( ) var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES") -var usePolicyV1ForTest = envknob.Bool("HEADSCALE_POLICY_V1") var ( errNoHeadscaleAvailable = errors.New("no headscale available") @@ -414,10 +413,6 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { opts = append(opts, hsic.WithPostgres()) } - if usePolicyV1ForTest { - opts = append(opts, hsic.WithPolicyV1()) - } - headscale, err := hsic.New(s.pool, s.Networks(), opts...) if err != nil { return nil, fmt.Errorf("failed to create headscale container: %w", err) diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 25ede0c4..0bbd8711 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -7,10 +7,11 @@ import ( "testing" "time" - policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "tailscale.com/tailcfg" ) func isSSHNoAccessStdError(stderr string) bool { @@ -48,7 +49,7 @@ var retry = func(times int, sleepInterval time.Duration, return result, stderr, err } -func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario { +func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario { t.Helper() spec := ScenarioSpec{ @@ -92,23 +93,26 @@ func TestSSHOneUserToAll(t *testing.T) { t.Parallel() scenario := sshScenario(t, - &policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policyv1.SSH{ + SSHs: []policyv2.SSH{ { Action: "accept", - Sources: []string{"group:integration-test"}, - Destinations: []string{"*"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{wildcard()}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, @@ -157,23 +161,26 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { t.Parallel() scenario := sshScenario(t, - &policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1@", "user2@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@"), policyv2.Username("user2@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policyv1.SSH{ + SSHs: []policyv2.SSH{ { Action: "accept", - Sources: []string{"group:integration-test"}, - Destinations: []string{"user1@", "user2@"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{usernamep("user1@"), usernamep("user2@")}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, @@ -210,18 +217,21 @@ func TestSSHNoSSHConfigured(t *testing.T) { t.Parallel() scenario := sshScenario(t, - &policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policyv1.SSH{}, + SSHs: []policyv2.SSH{}, }, len(MustTestVersions), ) @@ -252,23 +262,26 @@ func TestSSHIsBlockedInACL(t *testing.T) { t.Parallel() scenario := sshScenario(t, - &policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:80"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRange{First: 80, Last: 80}), + }, }, }, - SSHs: []policyv1.SSH{ + SSHs: []policyv2.SSH{ { Action: "accept", - Sources: []string{"group:integration-test"}, - Destinations: []string{"user1@"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{usernamep("user1@")}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, @@ -301,30 +314,33 @@ func TestSSHUserOnlyIsolation(t *testing.T) { t.Parallel() scenario := sshScenario(t, - &policyv1.ACLPolicy{ - Groups: map[string][]string{ - "group:ssh1": {"user1@"}, - "group:ssh2": {"user2@"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:ssh1"): []policyv2.Username{policyv2.Username("user1@")}, + policyv2.Group("group:ssh2"): []policyv2.Username{policyv2.Username("user2@")}, }, - ACLs: []policyv1.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policyv1.SSH{ + SSHs: []policyv2.SSH{ { Action: "accept", - Sources: []string{"group:ssh1"}, - Destinations: []string{"user1@"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:ssh1")}, + Destinations: policyv2.SSHDstAliases{usernamep("user1@")}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, { Action: "accept", - Sources: []string{"group:ssh2"}, - Destinations: []string{"user2@"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:ssh2")}, + Destinations: policyv2.SSHDstAliases{usernamep("user2@")}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, diff --git a/integration/utils.go b/integration/utils.go index 440fa663..18721cad 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -5,15 +5,19 @@ import ( "bytes" "fmt" "io" + "net/netip" "strings" "sync" "testing" "time" "github.com/cenkalti/backoff/v4" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) const ( @@ -419,10 +423,76 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) // return peer // } // } -// } +// } // // return nil // } + +// Helper functions for creating typed policy entities + +// wildcard returns a wildcard alias (*). +func wildcard() policyv2.Alias { + return policyv2.Wildcard +} + +// usernamep returns a pointer to a Username as an Alias. +func usernamep(name string) policyv2.Alias { + return ptr.To(policyv2.Username(name)) +} + +// hostp returns a pointer to a Host. +func hostp(name string) policyv2.Alias { + return ptr.To(policyv2.Host(name)) +} + +// groupp returns a pointer to a Group as an Alias. +func groupp(name string) policyv2.Alias { + return ptr.To(policyv2.Group(name)) +} + +// tagp returns a pointer to a Tag as an Alias. +func tagp(name string) policyv2.Alias { + return ptr.To(policyv2.Tag(name)) +} + +// prefixp returns a pointer to a Prefix from a CIDR string. +func prefixp(cidr string) policyv2.Alias { + prefix := netip.MustParsePrefix(cidr) + return ptr.To(policyv2.Prefix(prefix)) +} + +// aliasWithPorts creates an AliasWithPorts structure from an alias and ports. +func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts { + return policyv2.AliasWithPorts{ + Alias: alias, + Ports: ports, + } +} + +// usernameOwner returns a Username as an Owner for use in TagOwners. +func usernameOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Username(name)) +} + +// groupOwner returns a Group as an Owner for use in TagOwners. +func groupOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Group(name)) +} + +// usernameApprover returns a Username as an AutoApprover. +func usernameApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Username(name)) +} + +// groupApprover returns a Group as an AutoApprover. +func groupApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Group(name)) +} + +// tagApprover returns a Tag as an AutoApprover. +func tagApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Tag(name)) +} // // // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus // // if there is a peer with the given hostname. If no peer is found, nil is returned.