From d00251c63e3eb8c93c0ea21acda1040af5971331 Mon Sep 17 00:00:00 2001 From: Adrien Raffin-Caboisse Date: Sun, 20 Feb 2022 21:24:02 +0100 Subject: [PATCH] fix(acls,machines): apply code review suggestions --- acls.go | 8 ++++---- acls_test.go | 2 +- machine.go | 37 ++++++++++++++++++++----------------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/acls.go b/acls.go index d5d39889..f9ed09da 100644 --- a/acls.go +++ b/acls.go @@ -204,7 +204,7 @@ func expandAlias( return ips, err } for _, n := range namespaces { - nodes := listMachinesInNamespace(machines, n) + nodes := filterMachinesByNamespace(machines, n) for _, node := range nodes { ips = append(ips, node.IPAddresses.ToStringSlice()...) } @@ -219,7 +219,7 @@ func expandAlias( return ips, err } for _, namespace := range owners { - machines := listMachinesInNamespace(machines, namespace) + machines := filterMachinesByNamespace(machines, namespace) for _, machine := range machines { if len(machine.HostInfo) == 0 { continue @@ -240,7 +240,7 @@ func expandAlias( } // if alias is a namespace - nodes := listMachinesInNamespace(machines, alias) + nodes := filterMachinesByNamespace(machines, alias) nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias) if err != nil { return ips, err @@ -357,7 +357,7 @@ func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { return &ports, nil } -func listMachinesInNamespace(machines []Machine, namespace string) []Machine { +func filterMachinesByNamespace(machines []Machine, namespace string) []Machine { out := []Machine{} for _, machine := range machines { if machine.Namespace.Name == namespace { diff --git a/acls_test.go b/acls_test.go index e68a01f6..32dd5726 100644 --- a/acls_test.go +++ b/acls_test.go @@ -687,7 +687,7 @@ func Test_listMachinesInNamespace(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if got := listMachinesInNamespace(test.args.machines, test.args.namespace); !reflect.DeepEqual( + if got := filterMachinesByNamespace(test.args.machines, test.args.namespace); !reflect.DeepEqual( got, test.want, ) { diff --git a/machine.go b/machine.go index ba677f15..4c984d64 100644 --- a/machine.go +++ b/machine.go @@ -142,6 +142,16 @@ func containsAddresses(inputs []string, addrs MachineAddresses) bool { return false } +// matchSourceAndDestinationWithRule will check if source is authorized to communicate with destination through +// the given rule. +func matchSourceAndDestinationWithRule(rule tailcfg.FilterRule, source Machine, destination Machine) bool { + var dst []string + for _, d := range rule.DstPorts { + dst = append(dst, d.IP) + } + return (containsAddresses(rule.SrcIPs, source.IPAddresses) && containsAddresses(dst, destination.IPAddresses)) || containsString(dst, "*") +} + // getFilteredByACLPeerss should return the list of peers authorized to be accessed from machine. func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { log.Trace(). @@ -149,14 +159,12 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { Str("machine", machine.Name). Msg("Finding peers filtered by ACLs") - machines := Machines{} - if err := h.db.Preload("Namespace").Where("machine_key <> ? AND registered", - machine.MachineKey).Find(&machines).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") - + machines, err := h.ListAllMachines() + if err != nil { + log.Error().Err(err).Msg("Error retrieving list of machines") return Machines{}, err } - mMachines := make(map[uint64]Machine) + peers := make(map[uint64]Machine) // Aclfilter peers here. We are itering through machines in all namespaces and search through the computed aclRules // for match between rule SrcIPs and DstPorts. If the rule is a match we allow the machine to be viewable. @@ -175,21 +183,16 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { // In order to do this we would need to be able to identify that node A want to talk to node B but that Node B doesn't know // how to talk to node A and then add the peering resource. - for _, mchn := range machines { + for _, peer := range machines { for _, rule := range h.aclRules { - var dst []string - for _, d := range rule.DstPorts { - dst = append(dst, d.IP) - } - if (containsAddresses(rule.SrcIPs, machine.IPAddresses) && (containsAddresses(dst, mchn.IPAddresses) || containsString(dst, "*"))) || - (containsAddresses(rule.SrcIPs, mchn.IPAddresses) && containsAddresses(dst, machine.IPAddresses)) { - mMachines[mchn.ID] = mchn + if matchSourceAndDestinationWithRule(rule, *machine, peer) || matchSourceAndDestinationWithRule(rule, peer, *machine) { + peers[peer.ID] = peer } } } - authorizedMachines := make([]Machine, 0, len(mMachines)) - for _, m := range mMachines { + authorizedMachines := make([]Machine, 0, len(peers)) + for _, m := range peers { authorizedMachines = append(authorizedMachines, m) } sort.Slice( @@ -200,7 +203,7 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). Str("machine", machine.Name). - Msgf("Found some machines: %s", machines.String()) + Msgf("Found some machines: %v", machines) return authorizedMachines, nil }