feat(acls): rewrite functions to be testable

Rewrite some function to get rid of the dependency on Headscale object. This allows us
to write succinct test that are more easy to review and implement.

The improvements of the tests allowed to write the removal of the tagged hosts
from the namespace as specified here: https://tailscale.com/kb/1068/acl-tags/
This commit is contained in:
Adrien Raffin 2022-02-07 16:12:05 +01:00 committed by Adrien Raffin-Caboisse
parent 97eac3b938
commit de59946447
No known key found for this signature in database
GPG Key ID: 7FB60532DEBEAD6A
3 changed files with 646 additions and 75 deletions

187
acls.go
View File

@ -2,7 +2,6 @@ package headscale
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -86,6 +85,11 @@ func (h *Headscale) UpdateACLRules() error {
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
machines, err := h.ListAllMachines()
if err != nil {
return nil, err
}
for index, acl := range h.aclPolicy.ACLs { for index, acl := range h.aclPolicy.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {
return nil, errInvalidAction return nil, errInvalidAction
@ -93,7 +97,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
srcIPs := []string{} srcIPs := []string{}
for innerIndex, user := range acl.Users { for innerIndex, user := range acl.Users {
srcs, err := h.generateACLPolicySrcIP(user) srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, user)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, User %d", index, innerIndex) Msgf("Error parsing ACL %d, User %d", index, innerIndex)
@ -105,7 +109,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for innerIndex, ports := range acl.Ports { for innerIndex, ports := range acl.Ports {
dests, err := h.generateACLPolicyDestPorts(ports) dests, err := h.generateACLPolicyDestPorts(machines, *h.aclPolicy, ports)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Port %d", index, innerIndex) Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
@ -124,11 +128,13 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
return rules, nil return rules, nil
} }
func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) { func (h *Headscale) generateACLPolicySrcIP(machines []Machine, aclPolicy ACLPolicy, u string) ([]string, error) {
return h.expandAlias(u) return expandAlias(machines, aclPolicy, u)
} }
func (h *Headscale) generateACLPolicyDestPorts( func (h *Headscale) generateACLPolicyDestPorts(
machines []Machine,
aclPolicy ACLPolicy,
d string, d string,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":") tokens := strings.Split(d, ":")
@ -149,11 +155,11 @@ func (h *Headscale) generateACLPolicyDestPorts(
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
} }
expanded, err := h.expandAlias(alias) expanded, err := expandAlias(machines, aclPolicy, alias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ports, err := h.expandPorts(tokens[len(tokens)-1]) ports, err := expandPorts(tokens[len(tokens)-1])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -177,52 +183,40 @@ func (h *Headscale) generateACLPolicyDestPorts(
// - a group // - a group
// - a tag // - a tag
// and transform these in IPAddresses // and transform these in IPAddresses
func (h *Headscale) expandAlias(alias string) ([]string, error) { func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) {
ips := []string{}
if alias == "*" { if alias == "*" {
return []string{"*"}, nil return []string{"*"}, nil
} }
if strings.HasPrefix(alias, "group:") { if strings.HasPrefix(alias, "group:") {
namespaces, err := h.expandGroup(alias) namespaces, err := expandGroup(aclPolicy, alias)
if err != nil { if err != nil {
return nil, err return ips, err
} }
ips := []string{}
for _, n := range namespaces { for _, n := range namespaces {
nodes, err := h.ListMachinesInNamespace(n) nodes := listMachinesInNamespace(machines, n)
if err != nil {
return nil, errInvalidNamespace
}
for _, node := range nodes { for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
} }
return ips, nil return ips, nil
} }
if strings.HasPrefix(alias, "tag:") { if strings.HasPrefix(alias, "tag:") {
var ips []string owners, err := expandTagOwners(aclPolicy, alias)
owners, err := h.expandTagOwners(alias)
if err != nil { if err != nil {
return nil, err return ips, err
} }
for _, namespace := range owners { for _, namespace := range owners {
machines, err := h.ListMachinesInNamespace(namespace) machines := listMachinesInNamespace(machines, namespace)
if err != nil {
if errors.Is(err, errNamespaceNotFound) {
continue
} else {
return nil, err
}
}
for _, machine := range machines { for _, machine := range machines {
if len(machine.HostInfo) == 0 { if len(machine.HostInfo) == 0 {
continue continue
} }
hi, err := machine.GetHostInfo() hi, err := machine.GetHostInfo()
if err != nil { if err != nil {
return nil, err return ips, err
} }
for _, t := range hi.RequestTags { for _, t := range hi.RequestTags {
if alias == t { if alias == t {
@ -234,75 +228,75 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
return ips, nil return ips, nil
} }
n, err := h.GetNamespace(alias) // if alias is a namespace
if err == nil { nodes := listMachinesInNamespace(machines, alias)
nodes, err := h.ListMachinesInNamespace(n.Name) nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias)
if err != nil { if err != nil {
return nil, err return ips, err
} }
ips := []string{} for _, n := range nodes {
for _, n := range nodes { ips = append(ips, n.IPAddresses.ToStringSlice()...)
ips = append(ips, n.IPAddresses.ToStringSlice()...) }
} if len(ips) > 0 {
return ips, nil return ips, nil
} }
if h, ok := h.aclPolicy.Hosts[alias]; ok { // if alias is an host
if h, ok := aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil return []string{h.String()}, nil
} }
// if alias is an IP
ip, err := netaddr.ParseIP(alias) ip, err := netaddr.ParseIP(alias)
if err == nil { if err == nil {
return []string{ip.String()}, nil return []string{ip.String()}, nil
} }
// if alias is an CIDR
cidr, err := netaddr.ParseIPPrefix(alias) cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil { if err == nil {
return []string{cidr.String()}, nil return []string{cidr.String()}, nil
} }
return nil, errInvalidUserSection return ips, errInvalidUserSection
} }
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group // excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
// a group cannot be composed of groups // that are correctly tagged since they should not be listed as being in the namespace
func (h *Headscale) expandTagOwners(owner string) ([]string, error) { // we assume in this function that we only have nodes from 1 namespace.
var owners []string func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace string) ([]Machine, error) {
ows, ok := h.aclPolicy.TagOwners[owner] out := []Machine{}
if !ok { tags := []string{}
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, owner) for tag, ns := range aclPolicy.TagOwners {
if containsString(ns, namespace) {
tags = append(tags, tag)
}
} }
for _, ow := range ows { // for each machine if tag is in tags list, don't append it.
if strings.HasPrefix(ow, "group:") { for _, machine := range nodes {
gs, err := h.expandGroup(ow) if len(machine.HostInfo) == 0 {
if err != nil { out = append(out, machine)
return []string{}, err continue
}
hi, err := machine.GetHostInfo()
if err != nil {
return out, err
}
found := false
for _, t := range hi.RequestTags {
if containsString(tags, t) {
found = true
break
} }
owners = append(owners, gs...) }
} else { if !found {
owners = append(owners, ow) out = append(out, machine)
} }
} }
return owners, nil return out, nil
} }
// expandGroup will return the list of namespace inside the group func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
// after some validation
func (h *Headscale) expandGroup(group string) ([]string, error) {
gs, ok := h.aclPolicy.Groups[group]
if !ok {
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
}
for _, g := range gs {
if strings.HasPrefix(g, "group:") {
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
}
}
return gs, nil
}
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if portsStr == "*" { if portsStr == "*" {
return &[]tailcfg.PortRange{ return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd}, {First: portRangeBegin, Last: portRangeEnd},
@ -344,3 +338,50 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
return &ports, nil return &ports, nil
} }
func listMachinesInNamespace(machines []Machine, namespace string) []Machine {
out := []Machine{}
for _, machine := range machines {
if machine.Namespace.Name == namespace {
out = append(out, machine)
}
}
return out
}
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
// a group cannot be composed of groups
func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) {
var owners []string
ows, ok := aclPolicy.TagOwners[tag]
if !ok {
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag)
}
for _, ow := range ows {
if strings.HasPrefix(ow, "group:") {
gs, err := expandGroup(aclPolicy, ow)
if err != nil {
return []string{}, err
}
owners = append(owners, gs...)
} else {
owners = append(owners, ow)
}
}
return owners, nil
}
// expandGroup will return the list of namespace inside the group
// after some validation
func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) {
gs, ok := aclPolicy.Groups[group]
if !ok {
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
}
for _, g := range gs {
if strings.HasPrefix(g, "group:") {
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
}
}
return gs, nil
}

View File

@ -2,10 +2,13 @@ package headscale
import ( import (
"errors" "errors"
"reflect"
"testing"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/datatypes" "gorm.io/datatypes"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg"
) )
func (s *Suite) TestWrongPath(c *check.C) { func (s *Suite) TestWrongPath(c *check.C) {
@ -267,9 +270,16 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
} }
err = app.UpdateACLRules() err = app.UpdateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Logf("Rules: %v", app.aclRules)
c.Assert(app.aclRules, check.HasLen, 1) c.Assert(app.aclRules, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 0) c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2")
c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2)
c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1")
c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1")
} }
func (s *Suite) TestPortRange(c *check.C) { func (s *Suite) TestPortRange(c *check.C) {
@ -385,3 +395,510 @@ func (s *Suite) TestPortGroup(c *check.C) {
c.Assert(len(ips), check.Equals, 1) c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
} }
func Test_expandGroup(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
group string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "simple test",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
},
group: "group:test",
},
want: []string{"g1", "foo", "test"},
wantErr: false,
},
{
name: "InexistantGroup",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
},
group: "group:bar",
},
want: []string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandGroup(tt.args.aclPolicy, tt.args.group)
if (err != nil) != tt.wantErr {
t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandGroup() = %v, want %v", got, tt.want)
}
})
}
}
func Test_expandTagOwners(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
tag string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "simple tag",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"namespace1"}},
},
tag: "tag:test",
},
want: []string{"namespace1"},
wantErr: false,
},
{
name: "tag and group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"n1", "bar"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
},
tag: "tag:test",
},
want: []string{"n1", "bar"},
wantErr: false,
},
{
name: "namespace and group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"n1", "bar"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
},
tag: "tag:test",
},
want: []string{"n1", "bar", "home"},
wantErr: false,
},
{
name: "invalid tag",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "home"}},
},
tag: "tag:test",
},
want: []string{},
wantErr: true,
},
{
name: "invalid group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{"group:bar": []string{"n1", "foo"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
},
tag: "tag:test",
},
want: []string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag)
if (err != nil) != tt.wantErr {
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandTagOwners() = %v, want %v", got, tt.want)
}
})
}
}
func Test_expandPorts(t *testing.T) {
type args struct {
portsStr string
}
tests := []struct {
name string
args args
want *[]tailcfg.PortRange
wantErr bool
}{
{
name: "wildcard",
args: args{portsStr: "*"},
want: &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
},
wantErr: false,
},
{
name: "two ports",
args: args{portsStr: "80,443"},
want: &[]tailcfg.PortRange{
{First: 80, Last: 80},
{First: 443, Last: 443},
},
wantErr: false,
},
{
name: "a range and a port",
args: args{portsStr: "80-1024,443"},
want: &[]tailcfg.PortRange{
{First: 80, Last: 1024},
{First: 443, Last: 443},
},
wantErr: false,
},
{
name: "out of bounds",
args: args{portsStr: "854038"},
want: nil,
wantErr: true,
},
{
name: "wrong port",
args: args{portsStr: "85a38"},
want: nil,
wantErr: true,
},
{
name: "wrong port in first",
args: args{portsStr: "a-80"},
want: nil,
wantErr: true,
},
{
name: "wrong port in last",
args: args{portsStr: "80-85a38"},
want: nil,
wantErr: true,
},
{
name: "wrong port format",
args: args{portsStr: "80-85a38-3"},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandPorts(tt.args.portsStr)
if (err != nil) != tt.wantErr {
t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandPorts() = %v, want %v", got, tt.want)
}
})
}
}
func Test_listMachinesInNamespace(t *testing.T) {
type args struct {
machines []Machine
namespace string
}
tests := []struct {
name string
args args
want []Machine
}{
{
name: "1 machine in namespace",
args: args{
machines: []Machine{
{Namespace: Namespace{Name: "test"}},
},
namespace: "test",
},
want: []Machine{
{Namespace: Namespace{Name: "test"}},
},
},
{
name: "3 machines, 2 in namespace",
args: args{
machines: []Machine{
{ID: 1, Namespace: Namespace{Name: "test"}},
{ID: 2, Namespace: Namespace{Name: "foo"}},
{ID: 3, Namespace: Namespace{Name: "foo"}},
},
namespace: "foo",
},
want: []Machine{
{ID: 2, Namespace: Namespace{Name: "foo"}},
{ID: 3, Namespace: Namespace{Name: "foo"}},
},
},
{
name: "5 machines, 0 in namespace",
args: args{
machines: []Machine{
{ID: 1, Namespace: Namespace{Name: "test"}},
{ID: 2, Namespace: Namespace{Name: "foo"}},
{ID: 3, Namespace: Namespace{Name: "foo"}},
{ID: 4, Namespace: Namespace{Name: "foo"}},
{ID: 5, Namespace: Namespace{Name: "foo"}},
},
namespace: "bar",
},
want: []Machine{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := listMachinesInNamespace(tt.args.machines, tt.args.namespace); !reflect.DeepEqual(got, tt.want) {
t.Errorf("listMachinesInNamespace() = %v, want %v", got, tt.want)
}
})
}
}
func Test_expandAlias(t *testing.T) {
type args struct {
machines []Machine
aclPolicy ACLPolicy
alias string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "wildcard",
args: args{
alias: "*",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.78.84.227")}},
},
aclPolicy: ACLPolicy{},
},
want: []string{"*"},
wantErr: false,
},
{
name: "simple group",
args: args{
alias: "group:foo",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
},
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"foo", "bar"}},
},
},
want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"},
wantErr: false,
},
{
name: "wrong group",
args: args{
alias: "group:test",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
},
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"foo", "bar"}},
},
},
want: []string{},
wantErr: true,
},
{
name: "simple ipaddress",
args: args{
alias: "10.0.0.3",
machines: []Machine{},
aclPolicy: ACLPolicy{},
},
want: []string{"10.0.0.3"},
wantErr: false,
},
{
name: "private network",
args: args{
alias: "homeNetwork",
machines: []Machine{},
aclPolicy: ACLPolicy{
Hosts: Hosts{"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24")},
},
},
want: []string{"192.168.1.0/24"},
wantErr: false,
},
{
name: "simple host",
args: args{
alias: "10.0.0.1",
machines: []Machine{},
aclPolicy: ACLPolicy{},
},
want: []string{"10.0.0.1"},
wantErr: false,
},
{
name: "simple CIDR",
args: args{
alias: "10.0.0.0/16",
machines: []Machine{},
aclPolicy: ACLPolicy{},
},
want: []string{"10.0.0.0/16"},
wantErr: false,
},
{
name: "simple tag",
args: args{
alias: "tag:test",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"foo"}},
},
},
want: []string{"100.64.0.1", "100.64.0.2"},
wantErr: false,
},
{
name: "No tag defined",
args: args{
alias: "tag:foo",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
},
aclPolicy: ACLPolicy{
Groups: Groups{"group:foo": []string{"foo", "bar"}},
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
},
},
want: []string{},
wantErr: true,
},
{
name: "list host in namespace without correctly tagged servers",
args: args{
alias: "foo",
machines: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"foo"}},
},
},
want: []string{"100.64.0.4"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias)
if (err != nil) != tt.wantErr {
t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
}
})
}
}
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
nodes []Machine
namespace string
}
tests := []struct {
name string
args args
want []Machine
wantErr bool
}{
{
name: "exclude nodes with valid tags",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"foo"}},
},
nodes: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
namespace: "foo",
},
want: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
wantErr: false,
},
{
name: "all nodes have invalid tags, don't exclude them",
args: args{
aclPolicy: ACLPolicy{
TagOwners: TagOwners{"tag:foo": []string{"foo"}},
},
nodes: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
namespace: "foo",
},
want: []Machine{
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace)
if (err != nil) != tt.wantErr {
t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -119,6 +119,19 @@ func (machine Machine) isExpired() bool {
return time.Now().UTC().After(*machine.Expiry) return time.Now().UTC().After(*machine.Expiry)
} }
func (h *Headscale) ListAllMachines() ([]Machine, error) {
machines := []Machine{}
if err := h.db.Preload("AuthKey").
Preload("AuthKey.Namespace").
Preload("Namespace").
Where("registered").
Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
func containsAddresses(inputs []string, addrs MachineAddresses) bool { func containsAddresses(inputs []string, addrs MachineAddresses) bool {
for _, addr := range addrs.ToStringSlice() { for _, addr := range addrs.ToStringSlice() {
if containsString(inputs, addr) { if containsString(inputs, addr) {