mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-25 13:45:52 -05:00
Add tests to verify "Hosts" aliases in ACL (#1304)
This commit is contained in:
parent
681c86cc95
commit
ceeef40cdf
60
acls.go
60
acls.go
@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/yaml.v3"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -165,16 +166,22 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s
|
||||
aclCachePeerMap := make(map[string]map[string]struct{})
|
||||
for _, rule := range rules {
|
||||
for _, srcIP := range rule.SrcIPs {
|
||||
if data, ok := aclCachePeerMap[srcIP]; ok {
|
||||
for _, dstPort := range rule.DstPorts {
|
||||
data[dstPort.IP] = struct{}{}
|
||||
for _, ip := range expandACLPeerAddr(srcIP) {
|
||||
if data, ok := aclCachePeerMap[ip]; ok {
|
||||
for _, dstPort := range rule.DstPorts {
|
||||
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
|
||||
data[dstIP] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dstPortsMap := make(map[string]struct{}, len(rule.DstPorts))
|
||||
for _, dstPort := range rule.DstPorts {
|
||||
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
|
||||
dstPortsMap[dstIP] = struct{}{}
|
||||
}
|
||||
}
|
||||
aclCachePeerMap[ip] = dstPortsMap
|
||||
}
|
||||
} else {
|
||||
dstPortsMap := make(map[string]struct{}, len(rule.DstPorts))
|
||||
for _, dstPort := range rule.DstPorts {
|
||||
dstPortsMap[dstPort.IP] = struct{}{}
|
||||
}
|
||||
aclCachePeerMap[srcIP] = dstPortsMap
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -184,6 +191,41 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s
|
||||
return aclCachePeerMap
|
||||
}
|
||||
|
||||
// expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into
|
||||
// something our cache logic can look up, which is "*" or single IP addresses.
|
||||
// This is probably quite inefficient, but it is a result of
|
||||
// "make it work, then make it fast", and a lot of the ACL stuff does not
|
||||
// work, but people have tried to make it fast.
|
||||
func expandACLPeerAddr(srcIP string) []string {
|
||||
if ip, err := netip.ParseAddr(srcIP); err == nil {
|
||||
return []string{ip.String()}
|
||||
}
|
||||
|
||||
if cidr, err := netip.ParsePrefix(srcIP); err == nil {
|
||||
addrs := []string{}
|
||||
|
||||
ipRange := netipx.RangeOfPrefix(cidr)
|
||||
|
||||
from := ipRange.From()
|
||||
too := ipRange.To()
|
||||
|
||||
if from == too {
|
||||
return []string{from.String()}
|
||||
}
|
||||
|
||||
for from != too {
|
||||
addrs = append(addrs, from.String())
|
||||
|
||||
from = from.Next()
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
// probably "*" or other string based "IP"
|
||||
return []string{srcIP}
|
||||
}
|
||||
|
||||
func generateACLRules(
|
||||
machines []Machine,
|
||||
aclPolicy ACLPolicy,
|
||||
|
64
acls_test.go
64
acls_test.go
@ -1556,3 +1556,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_expandACLPeerAddr(t *testing.T) {
|
||||
type args struct {
|
||||
srcIP string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "asterix",
|
||||
args: args{
|
||||
srcIP: "*",
|
||||
},
|
||||
want: []string{"*"},
|
||||
},
|
||||
{
|
||||
name: "ip",
|
||||
args: args{
|
||||
srcIP: "10.0.0.1",
|
||||
},
|
||||
want: []string{"10.0.0.1"},
|
||||
},
|
||||
{
|
||||
name: "ip/32",
|
||||
args: args{
|
||||
srcIP: "10.0.0.1/32",
|
||||
},
|
||||
want: []string{"10.0.0.1"},
|
||||
},
|
||||
{
|
||||
name: "ip/30",
|
||||
args: args{
|
||||
srcIP: "10.0.0.1/30",
|
||||
},
|
||||
want: []string{
|
||||
"10.0.0.0",
|
||||
"10.0.0.1",
|
||||
"10.0.0.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ip/28",
|
||||
args: args{
|
||||
srcIP: "192.168.0.128/28",
|
||||
},
|
||||
want: []string{
|
||||
"192.168.0.128", "192.168.0.129", "192.168.0.130",
|
||||
"192.168.0.131", "192.168.0.132", "192.168.0.133",
|
||||
"192.168.0.134", "192.168.0.135", "192.168.0.136",
|
||||
"192.168.0.137", "192.168.0.138", "192.168.0.139",
|
||||
"192.168.0.140", "192.168.0.141", "192.168.0.142",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package integration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -439,3 +440,214 @@ func TestACLAllowStarDst(t *testing.T) {
|
||||
err = scenario.Shutdown()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// This test aims to cover cases where individual hosts are allowed and denied
|
||||
// access based on their assigned hostname
|
||||
// https://github.com/juanfont/headscale/issues/941
|
||||
|
||||
// ACL = [{
|
||||
// "DstPorts": [{
|
||||
// "Bits": null,
|
||||
// "IP": "100.64.0.3/32",
|
||||
// "Ports": {
|
||||
// "First": 0,
|
||||
// "Last": 65535
|
||||
// }
|
||||
// }],
|
||||
// "SrcIPs": ["*"]
|
||||
// }, {
|
||||
//
|
||||
// "DstPorts": [{
|
||||
// "Bits": null,
|
||||
// "IP": "100.64.0.2/32",
|
||||
// "Ports": {
|
||||
// "First": 0,
|
||||
// "Last": 65535
|
||||
// }
|
||||
// }],
|
||||
// "SrcIPs": ["100.64.0.1/32"]
|
||||
// }]
|
||||
//
|
||||
// ACL Cache Map= {
|
||||
// "*": {
|
||||
// "100.64.0.3/32": {}
|
||||
// },
|
||||
// "100.64.0.1/32": {
|
||||
// "100.64.0.2/32": {}
|
||||
// }
|
||||
// }
|
||||
func TestACLNamedHostsCanReach(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
headscale.ACLPolicy{
|
||||
Hosts: headscale.Hosts{
|
||||
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
||||
"test3": netip.MustParsePrefix("100.64.0.3/32"),
|
||||
},
|
||||
ACLs: []headscale.ACL{
|
||||
// Everyone can curl test3
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"test3:*"},
|
||||
},
|
||||
// test1 can curl test2
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test1"},
|
||||
Destinations: []string{"test2:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// Since user/users dont matter here, we basically expect that some clients
|
||||
// will be assigned these ips and that we can pick them up for our own use.
|
||||
test1ip := netip.MustParseAddr("100.64.0.1")
|
||||
test1, err := scenario.FindTailscaleClientByIP(test1ip)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test1fqdn, err := test1.FQDN()
|
||||
assert.NoError(t, err)
|
||||
test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String())
|
||||
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
|
||||
|
||||
test2ip := netip.MustParseAddr("100.64.0.2")
|
||||
test2, err := scenario.FindTailscaleClientByIP(test2ip)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test2fqdn, err := test2.FQDN()
|
||||
assert.NoError(t, err)
|
||||
test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String())
|
||||
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
||||
|
||||
test3ip := netip.MustParseAddr("100.64.0.3")
|
||||
test3, err := scenario.FindTailscaleClientByIP(test3ip)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test3fqdn, err := test3.FQDN()
|
||||
assert.NoError(t, err)
|
||||
test3ipURL := fmt.Sprintf("http://%s/etc/hostname", test3ip.String())
|
||||
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
|
||||
|
||||
// test1 can query test3
|
||||
result, err := test1.Curl(test3ipURL)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err = test1.Curl(test3fqdnURL)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test2 can query test3
|
||||
result, err = test2.Curl(test3ipURL)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err = test2.Curl(test3fqdnURL)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test3 cannot query test1
|
||||
result, err = test3.Curl(test1ipURL)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
|
||||
result, err = test3.Curl(test1fqdnURL)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
|
||||
// test3 cannot query test2
|
||||
result, err = test3.Curl(test2ipURL)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
|
||||
result, err = test3.Curl(test2fqdnURL)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
|
||||
// test1 can query test2
|
||||
result, err = test1.Curl(test2ipURL)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err = test1.Curl(test2fqdnURL)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test2 cannot query test1
|
||||
result, err = test2.Curl(test1ipURL)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
|
||||
result, err = test2.Curl(test1fqdnURL)
|
||||
assert.Empty(t, result)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = scenario.Shutdown()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestACLNamedHostsCanReachBySubnet is the same as
|
||||
// TestACLNamedHostsCanReach, but it tests if we expand a
|
||||
// full CIDR correctly. All routes should work.
|
||||
func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
headscale.ACLPolicy{
|
||||
Hosts: headscale.Hosts{
|
||||
"all": netip.MustParsePrefix("100.64.0.0/24"),
|
||||
},
|
||||
ACLs: []headscale.ACL{
|
||||
// Everyone can curl test3
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"all:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test that user1 can visit all user2
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
assert.NoError(t, err)
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||
|
||||
result, err := client.Curl(url)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that user2 can visit all user1
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
fqdn, err := peer.FQDN()
|
||||
assert.NoError(t, err)
|
||||
|
||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||
|
||||
result, err := client.Curl(url)
|
||||
assert.Len(t, result, 13)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = scenario.Shutdown()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
17
machine.go
17
machine.go
@ -170,13 +170,14 @@ func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines)
|
||||
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||
func filterMachinesByACL(
|
||||
machine *Machine,
|
||||
machines []Machine,
|
||||
machines Machines,
|
||||
lock *sync.RWMutex,
|
||||
aclPeerCacheMap map[string]map[string]struct{},
|
||||
) Machines {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("self", machine.Hostname).
|
||||
Str("input", machines.String()).
|
||||
Msg("Finding peers filtered by ACLs")
|
||||
|
||||
peers := make(map[uint64]Machine)
|
||||
@ -263,7 +264,7 @@ func filterMachinesByACL(
|
||||
|
||||
lock.RUnlock()
|
||||
|
||||
authorizedPeers := make([]Machine, 0, len(peers))
|
||||
authorizedPeers := make(Machines, 0, len(peers))
|
||||
for _, m := range peers {
|
||||
authorizedPeers = append(authorizedPeers, m)
|
||||
}
|
||||
@ -274,8 +275,9 @@ func filterMachinesByACL(
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("Found some machines: %v", machines)
|
||||
Str("self", machine.Hostname).
|
||||
Str("peers", authorizedPeers.String()).
|
||||
Msg("Authorized peers")
|
||||
|
||||
return authorizedPeers
|
||||
}
|
||||
@ -335,8 +337,9 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("Found total peers: %s", peers.String())
|
||||
Str("self", machine.Hostname).
|
||||
Str("peers", peers.String()).
|
||||
Msg("Peers returned to caller")
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user