From de5994644790b6948d1d39b12a59dfb9931df797 Mon Sep 17 00:00:00 2001 From: Adrien Raffin Date: Mon, 7 Feb 2022 16:12:05 +0100 Subject: [PATCH] feat(acls): rewrite functions to be testable Rewrite some function to get rid of the dependency on Headscale object. This allows us to write succinct test that are more easy to review and implement. The improvements of the tests allowed to write the removal of the tagged hosts from the namespace as specified here: https://tailscale.com/kb/1068/acl-tags/ --- acls.go | 187 ++++++++++-------- acls_test.go | 521 ++++++++++++++++++++++++++++++++++++++++++++++++++- machine.go | 13 ++ 3 files changed, 646 insertions(+), 75 deletions(-) diff --git a/acls.go b/acls.go index c86e315f..9dd1260d 100644 --- a/acls.go +++ b/acls.go @@ -2,7 +2,6 @@ package headscale import ( "encoding/json" - "errors" "fmt" "io" "os" @@ -86,6 +85,11 @@ func (h *Headscale) UpdateACLRules() error { func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} + machines, err := h.ListAllMachines() + if err != nil { + return nil, err + } + for index, acl := range h.aclPolicy.ACLs { if acl.Action != "accept" { return nil, errInvalidAction @@ -93,7 +97,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { srcIPs := []string{} for innerIndex, user := range acl.Users { - srcs, err := h.generateACLPolicySrcIP(user) + srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, user) if err != nil { log.Error(). Msgf("Error parsing ACL %d, User %d", index, innerIndex) @@ -105,7 +109,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { destPorts := []tailcfg.NetPortRange{} for innerIndex, ports := range acl.Ports { - dests, err := h.generateACLPolicyDestPorts(ports) + dests, err := h.generateACLPolicyDestPorts(machines, *h.aclPolicy, ports) if err != nil { log.Error(). Msgf("Error parsing ACL %d, Port %d", index, innerIndex) @@ -124,11 +128,13 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { return rules, nil } -func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) { - return h.expandAlias(u) +func (h *Headscale) generateACLPolicySrcIP(machines []Machine, aclPolicy ACLPolicy, u string) ([]string, error) { + return expandAlias(machines, aclPolicy, u) } func (h *Headscale) generateACLPolicyDestPorts( + machines []Machine, + aclPolicy ACLPolicy, d string, ) ([]tailcfg.NetPortRange, error) { tokens := strings.Split(d, ":") @@ -149,11 +155,11 @@ func (h *Headscale) generateACLPolicyDestPorts( alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) } - expanded, err := h.expandAlias(alias) + expanded, err := expandAlias(machines, aclPolicy, alias) if err != nil { return nil, err } - ports, err := h.expandPorts(tokens[len(tokens)-1]) + ports, err := expandPorts(tokens[len(tokens)-1]) if err != nil { return nil, err } @@ -177,52 +183,40 @@ func (h *Headscale) generateACLPolicyDestPorts( // - a group // - a tag // and transform these in IPAddresses -func (h *Headscale) expandAlias(alias string) ([]string, error) { +func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) { + ips := []string{} if alias == "*" { return []string{"*"}, nil } if strings.HasPrefix(alias, "group:") { - namespaces, err := h.expandGroup(alias) + namespaces, err := expandGroup(aclPolicy, alias) if err != nil { - return nil, err + return ips, err } - ips := []string{} for _, n := range namespaces { - nodes, err := h.ListMachinesInNamespace(n) - if err != nil { - return nil, errInvalidNamespace - } + nodes := listMachinesInNamespace(machines, n) for _, node := range nodes { ips = append(ips, node.IPAddresses.ToStringSlice()...) } } - return ips, nil } if strings.HasPrefix(alias, "tag:") { - var ips []string - owners, err := h.expandTagOwners(alias) + owners, err := expandTagOwners(aclPolicy, alias) if err != nil { - return nil, err + return ips, err } for _, namespace := range owners { - machines, err := h.ListMachinesInNamespace(namespace) - if err != nil { - if errors.Is(err, errNamespaceNotFound) { - continue - } else { - return nil, err - } - } + machines := listMachinesInNamespace(machines, namespace) for _, machine := range machines { if len(machine.HostInfo) == 0 { continue } hi, err := machine.GetHostInfo() if err != nil { - return nil, err + return ips, err } for _, t := range hi.RequestTags { if alias == t { @@ -234,75 +228,75 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) { return ips, nil } - n, err := h.GetNamespace(alias) - if err == nil { - nodes, err := h.ListMachinesInNamespace(n.Name) - if err != nil { - return nil, err - } - ips := []string{} - for _, n := range nodes { - ips = append(ips, n.IPAddresses.ToStringSlice()...) - } - + // if alias is a namespace + nodes := listMachinesInNamespace(machines, alias) + nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias) + if err != nil { + return ips, err + } + for _, n := range nodes { + ips = append(ips, n.IPAddresses.ToStringSlice()...) + } + if len(ips) > 0 { return ips, nil } - if h, ok := h.aclPolicy.Hosts[alias]; ok { + // if alias is an host + if h, ok := aclPolicy.Hosts[alias]; ok { return []string{h.String()}, nil } + // if alias is an IP ip, err := netaddr.ParseIP(alias) if err == nil { return []string{ip.String()}, nil } + // if alias is an CIDR cidr, err := netaddr.ParseIPPrefix(alias) if err == nil { return []string{cidr.String()}, nil } - return nil, errInvalidUserSection + return ips, errInvalidUserSection } -// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group -// a group cannot be composed of groups -func (h *Headscale) expandTagOwners(owner string) ([]string, error) { - var owners []string - ows, ok := h.aclPolicy.TagOwners[owner] - if !ok { - return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, owner) +// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones +// that are correctly tagged since they should not be listed as being in the namespace +// we assume in this function that we only have nodes from 1 namespace. +func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace string) ([]Machine, error) { + out := []Machine{} + tags := []string{} + for tag, ns := range aclPolicy.TagOwners { + if containsString(ns, namespace) { + tags = append(tags, tag) + } } - for _, ow := range ows { - if strings.HasPrefix(ow, "group:") { - gs, err := h.expandGroup(ow) - if err != nil { - return []string{}, err + // for each machine if tag is in tags list, don't append it. + for _, machine := range nodes { + if len(machine.HostInfo) == 0 { + out = append(out, machine) + continue + } + hi, err := machine.GetHostInfo() + if err != nil { + return out, err + } + found := false + for _, t := range hi.RequestTags { + if containsString(tags, t) { + found = true + break } - owners = append(owners, gs...) - } else { - owners = append(owners, ow) + } + if !found { + out = append(out, machine) } } - return owners, nil + return out, nil } -// expandGroup will return the list of namespace inside the group -// after some validation -func (h *Headscale) expandGroup(group string) ([]string, error) { - gs, ok := h.aclPolicy.Groups[group] - if !ok { - return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup) - } - for _, g := range gs { - if strings.HasPrefix(g, "group:") { - return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup) - } - } - return gs, nil -} - -func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { +func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { if portsStr == "*" { return &[]tailcfg.PortRange{ {First: portRangeBegin, Last: portRangeEnd}, @@ -344,3 +338,50 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { return &ports, nil } + +func listMachinesInNamespace(machines []Machine, namespace string) []Machine { + out := []Machine{} + for _, machine := range machines { + if machine.Namespace.Name == namespace { + out = append(out, machine) + } + } + return out +} + +// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group +// a group cannot be composed of groups +func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) { + var owners []string + ows, ok := aclPolicy.TagOwners[tag] + if !ok { + return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag) + } + for _, ow := range ows { + if strings.HasPrefix(ow, "group:") { + gs, err := expandGroup(aclPolicy, ow) + if err != nil { + return []string{}, err + } + owners = append(owners, gs...) + } else { + owners = append(owners, ow) + } + } + return owners, nil +} + +// expandGroup will return the list of namespace inside the group +// after some validation +func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) { + gs, ok := aclPolicy.Groups[group] + if !ok { + return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup) + } + for _, g := range gs { + if strings.HasPrefix(g, "group:") { + return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup) + } + } + return gs, nil +} diff --git a/acls_test.go b/acls_test.go index f2fb6a03..8bedb47d 100644 --- a/acls_test.go +++ b/acls_test.go @@ -2,10 +2,13 @@ package headscale import ( "errors" + "reflect" + "testing" "gopkg.in/check.v1" "gorm.io/datatypes" "inet.af/netaddr" + "tailscale.com/tailcfg" ) func (s *Suite) TestWrongPath(c *check.C) { @@ -267,9 +270,16 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { } err = app.UpdateACLRules() c.Assert(err, check.IsNil) - c.Logf("Rules: %v", app.aclRules) c.Assert(app.aclRules, check.HasLen, 1) - c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 0) + c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) + c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2") + c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2) + c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) + c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) + c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1") + c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) + c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) + c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1") } func (s *Suite) TestPortRange(c *check.C) { @@ -385,3 +395,510 @@ func (s *Suite) TestPortGroup(c *check.C) { c.Assert(len(ips), check.Equals, 1) c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) } + +func Test_expandGroup(t *testing.T) { + type args struct { + aclPolicy ACLPolicy + group string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "simple test", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}}, + }, + group: "group:test", + }, + want: []string{"g1", "foo", "test"}, + wantErr: false, + }, + { + name: "InexistantGroup", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}}, + }, + group: "group:bar", + }, + want: []string{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandGroup(tt.args.aclPolicy, tt.args.group) + if (err != nil) != tt.wantErr { + t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandGroup() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_expandTagOwners(t *testing.T) { + type args struct { + aclPolicy ACLPolicy + tag string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "simple tag", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"namespace1"}}, + }, + tag: "tag:test", + }, + want: []string{"namespace1"}, + wantErr: false, + }, + { + name: "tag and group", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"n1", "bar"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, + }, + tag: "tag:test", + }, + want: []string{"n1", "bar"}, + wantErr: false, + }, + { + name: "namespace and group", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"n1", "bar"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}}, + }, + tag: "tag:test", + }, + want: []string{"n1", "bar", "home"}, + wantErr: false, + }, + { + name: "invalid tag", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:foo": []string{"group:foo", "home"}}, + }, + tag: "tag:test", + }, + want: []string{}, + wantErr: true, + }, + { + name: "invalid group", + args: args{ + aclPolicy: ACLPolicy{ + Groups: Groups{"group:bar": []string{"n1", "foo"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}}, + }, + tag: "tag:test", + }, + want: []string{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag) + if (err != nil) != tt.wantErr { + t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandTagOwners() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_expandPorts(t *testing.T) { + type args struct { + portsStr string + } + tests := []struct { + name string + args args + want *[]tailcfg.PortRange + wantErr bool + }{ + { + name: "wildcard", + args: args{portsStr: "*"}, + want: &[]tailcfg.PortRange{ + {First: portRangeBegin, Last: portRangeEnd}, + }, + wantErr: false, + }, + { + name: "two ports", + args: args{portsStr: "80,443"}, + want: &[]tailcfg.PortRange{ + {First: 80, Last: 80}, + {First: 443, Last: 443}, + }, + wantErr: false, + }, + { + name: "a range and a port", + args: args{portsStr: "80-1024,443"}, + want: &[]tailcfg.PortRange{ + {First: 80, Last: 1024}, + {First: 443, Last: 443}, + }, + wantErr: false, + }, + { + name: "out of bounds", + args: args{portsStr: "854038"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port", + args: args{portsStr: "85a38"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port in first", + args: args{portsStr: "a-80"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port in last", + args: args{portsStr: "80-85a38"}, + want: nil, + wantErr: true, + }, + { + name: "wrong port format", + args: args{portsStr: "80-85a38-3"}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandPorts(tt.args.portsStr) + if (err != nil) != tt.wantErr { + t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandPorts() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_listMachinesInNamespace(t *testing.T) { + type args struct { + machines []Machine + namespace string + } + tests := []struct { + name string + args args + want []Machine + }{ + { + name: "1 machine in namespace", + args: args{ + machines: []Machine{ + {Namespace: Namespace{Name: "test"}}, + }, + namespace: "test", + }, + want: []Machine{ + {Namespace: Namespace{Name: "test"}}, + }, + }, + { + name: "3 machines, 2 in namespace", + args: args{ + machines: []Machine{ + {ID: 1, Namespace: Namespace{Name: "test"}}, + {ID: 2, Namespace: Namespace{Name: "foo"}}, + {ID: 3, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "foo", + }, + want: []Machine{ + {ID: 2, Namespace: Namespace{Name: "foo"}}, + {ID: 3, Namespace: Namespace{Name: "foo"}}, + }, + }, + { + name: "5 machines, 0 in namespace", + args: args{ + machines: []Machine{ + {ID: 1, Namespace: Namespace{Name: "test"}}, + {ID: 2, Namespace: Namespace{Name: "foo"}}, + {ID: 3, Namespace: Namespace{Name: "foo"}}, + {ID: 4, Namespace: Namespace{Name: "foo"}}, + {ID: 5, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "bar", + }, + want: []Machine{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := listMachinesInNamespace(tt.args.machines, tt.args.namespace); !reflect.DeepEqual(got, tt.want) { + t.Errorf("listMachinesInNamespace() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_expandAlias(t *testing.T) { + type args struct { + machines []Machine + aclPolicy ACLPolicy + alias string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "wildcard", + args: args{ + alias: "*", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.78.84.227")}}, + }, + aclPolicy: ACLPolicy{}, + }, + want: []string{"*"}, + wantErr: false, + }, + { + name: "simple group", + args: args{ + alias: "group:foo", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}}, + }, + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"foo", "bar"}}, + }, + }, + want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + wantErr: false, + }, + { + name: "wrong group", + args: args{ + alias: "group:test", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}}, + }, + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"foo", "bar"}}, + }, + }, + want: []string{}, + wantErr: true, + }, + { + name: "simple ipaddress", + args: args{ + alias: "10.0.0.3", + machines: []Machine{}, + aclPolicy: ACLPolicy{}, + }, + want: []string{"10.0.0.3"}, + wantErr: false, + }, + { + name: "private network", + args: args{ + alias: "homeNetwork", + machines: []Machine{}, + aclPolicy: ACLPolicy{ + Hosts: Hosts{"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24")}, + }, + }, + want: []string{"192.168.1.0/24"}, + wantErr: false, + }, + { + name: "simple host", + args: args{ + alias: "10.0.0.1", + machines: []Machine{}, + aclPolicy: ACLPolicy{}, + }, + want: []string{"10.0.0.1"}, + wantErr: false, + }, + { + name: "simple CIDR", + args: args{ + alias: "10.0.0.0/16", + machines: []Machine{}, + aclPolicy: ACLPolicy{}, + }, + want: []string{"10.0.0.0/16"}, + wantErr: false, + }, + { + name: "simple tag", + args: args{ + alias: "tag:test", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"foo"}}, + }, + }, + want: []string{"100.64.0.1", "100.64.0.2"}, + wantErr: false, + }, + { + name: "No tag defined", + args: args{ + alias: "tag:foo", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}}, + }, + aclPolicy: ACLPolicy{ + Groups: Groups{"group:foo": []string{"foo", "bar"}}, + TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, + }, + }, + want: []string{}, + wantErr: true, + }, + { + name: "list host in namespace without correctly tagged servers", + args: args{ + alias: "foo", + machines: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"foo"}}, + }, + }, + want: []string{"100.64.0.4"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias) + if (err != nil) != tt.wantErr { + t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("expandAlias() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_excludeCorrectlyTaggedNodes(t *testing.T) { + type args struct { + aclPolicy ACLPolicy + nodes []Machine + namespace string + } + tests := []struct { + name string + args args + want []Machine + wantErr bool + }{ + { + name: "exclude nodes with valid tags", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:test": []string{"foo"}}, + }, + nodes: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "foo", + }, + want: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + wantErr: false, + }, + { + name: "all nodes have invalid tags, don't exclude them", + args: args{ + aclPolicy: ACLPolicy{ + TagOwners: TagOwners{"tag:foo": []string{"foo"}}, + }, + nodes: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + namespace: "foo", + }, + want: []Machine{ + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")}, + {IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace) + if (err != nil) != tt.wantErr { + t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/machine.go b/machine.go index 231c179d..01fc9276 100644 --- a/machine.go +++ b/machine.go @@ -119,6 +119,19 @@ func (machine Machine) isExpired() bool { return time.Now().UTC().After(*machine.Expiry) } +func (h *Headscale) ListAllMachines() ([]Machine, error) { + machines := []Machine{} + if err := h.db.Preload("AuthKey"). + Preload("AuthKey.Namespace"). + Preload("Namespace"). + Where("registered"). + Find(&machines).Error; err != nil { + return nil, err + } + + return machines, nil +} + func containsAddresses(inputs []string, addrs MachineAddresses) bool { for _, addr := range addrs.ToStringSlice() { if containsString(inputs, addr) {