Rule generation kinda working, missing tests

This commit is contained in:
Juan Font 2021-07-04 12:35:18 +02:00
parent 136aab9dc8
commit 07e95393b3
2 changed files with 185 additions and 22 deletions

190
acls.go
View File

@ -1,11 +1,15 @@
package headscale package headscale
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"strconv"
"strings" "strings"
"github.com/davecgh/go-spew/spew"
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -15,6 +19,9 @@ const errorEmptyPolicy = Error("empty policy")
const errorInvalidAction = Error("invalid action") const errorInvalidAction = Error("invalid action")
const errorInvalidUserSection = Error("invalid user section") const errorInvalidUserSection = Error("invalid user section")
const errorInvalidGroup = Error("invalid group") const errorInvalidGroup = Error("invalid group")
const errorInvalidTag = Error("invalid tag")
const errorInvalidNamespace = Error("invalid namespace")
const errorInvalidPortFormat = Error("invalid port format")
func (h *Headscale) LoadPolicy(path string) error { func (h *Headscale) LoadPolicy(path string) error {
policyFile, err := os.Open(path) policyFile, err := os.Open(path)
@ -59,33 +66,143 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
} }
r.SrcIPs = srcIPs r.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports {
fmt.Printf("acl %d, port %d: ", i, j)
dests, err := h.generateAclPolicyDestPorts(d)
fmt.Printf(" -> %s\n", err)
if err != nil {
return nil, err
}
destPorts = append(destPorts, *dests...)
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: srcIPs,
DstPorts: destPorts,
})
} }
// fmt.Println(rules)
spew.Dump(rules)
return &rules, nil return &rules, nil
} }
func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) { func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) {
if u == "*" { return h.expandAlias(u)
fmt.Printf("%s -> wildcard", u) }
func (h *Headscale) generateAclPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":")
if len(tokens) < 2 || len(tokens) > 3 {
return nil, errorInvalidPortFormat
}
var alias string
// We can have here stuff like:
// git-server:*
// 192.168.1.0/24:22
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
if len(tokens) == 2 {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
expanded, err := h.expandAlias(alias)
if err != nil {
return nil, err
}
ports, err := h.expandPorts(tokens[len(tokens)-1])
if err != nil {
return nil, err
}
dests := []tailcfg.NetPortRange{}
for _, d := range *expanded {
for _, p := range *ports {
pr := tailcfg.NetPortRange{
IP: d,
Ports: p,
}
dests = append(dests, pr)
}
}
return &dests, nil
}
func (h *Headscale) expandAlias(s string) (*[]string, error) {
if s == "*" {
fmt.Printf("%s -> wildcard", s)
return &[]string{"*"}, nil return &[]string{"*"}, nil
} }
if strings.HasPrefix(u, "group:") { if strings.HasPrefix(s, "group:") {
fmt.Printf("%s -> group", u) fmt.Printf("%s -> group", s)
if _, ok := h.aclPolicy.Groups[u]; !ok { if _, ok := h.aclPolicy.Groups[s]; !ok {
return nil, errorInvalidGroup return nil, errorInvalidGroup
} }
return nil, nil ips := []string{}
for _, n := range h.aclPolicy.Groups[s] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errorInvalidNamespace
}
for _, node := range *nodes {
ips = append(ips, node.IPAddress)
}
}
return &ips, nil
} }
if strings.HasPrefix(u, "tag:") { if strings.HasPrefix(s, "tag:") {
fmt.Printf("%s -> tag", u) fmt.Printf("%s -> tag", s)
return nil, nil if _, ok := h.aclPolicy.TagOwners[s]; !ok {
return nil, errorInvalidTag
}
// This will have HORRIBLE performance.
// We need to change the data model to better store tags
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
machines := []Machine{}
if err = db.Where("registered").Find(&machines).Error; err != nil {
log.Printf("Error accessing db: %s", err)
return nil, err
}
ips := []string{}
for _, m := range machines {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(hi, &hostinfo)
if err != nil {
return nil, err
}
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
if s[4:] == t {
ips = append(ips, m.IPAddress)
break
}
}
}
}
return &ips, nil
} }
n, err := h.GetNamespace(u) n, err := h.GetNamespace(s)
if err == nil { if err == nil {
fmt.Printf("%s -> namespace %s", u, n.Name) fmt.Printf("%s -> namespace %s", s, n.Name)
nodes, err := h.ListMachinesInNamespace(n.Name) nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -97,23 +214,60 @@ func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) {
return &ips, nil return &ips, nil
} }
if h, ok := h.aclPolicy.Hosts[u]; ok { if h, ok := h.aclPolicy.Hosts[s]; ok {
fmt.Printf("%s -> host %s", u, h) fmt.Printf("%s -> host %s", s, h)
return &[]string{h.String()}, nil return &[]string{h.String()}, nil
} }
ip, err := netaddr.ParseIP(u) ip, err := netaddr.ParseIP(s)
if err == nil { if err == nil {
fmt.Printf(" %s -> ip %s", u, ip) fmt.Printf(" %s -> ip %s", s, ip)
return &[]string{ip.String()}, nil return &[]string{ip.String()}, nil
} }
cidr, err := netaddr.ParseIPPrefix(u) cidr, err := netaddr.ParseIPPrefix(s)
if err == nil { if err == nil {
fmt.Printf("%s -> cidr %s", u, cidr) fmt.Printf("%s -> cidr %s", s, cidr)
return &[]string{cidr.String()}, nil return &[]string{cidr.String()}, nil
} }
fmt.Printf("%s: cannot be mapped to anything\n", u) fmt.Printf("%s: cannot be mapped to anything\n", s)
return nil, errorInvalidUserSection return nil, errorInvalidUserSection
} }
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
if s == "*" {
return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
}
ports := []tailcfg.PortRange{}
for _, p := range strings.Split(s, ",") {
rang := strings.Split(p, "-")
if len(rang) == 1 {
pi, err := strconv.ParseUint(rang[0], 10, 16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(pi),
Last: uint16(pi),
})
} else if len(rang) == 2 {
start, err := strconv.ParseUint(rang[0], 10, 16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], 10, 16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(start),
Last: uint16(last),
})
} else {
return nil, errorInvalidPortFormat
}
}
return &ports, nil
}

View File

@ -58,12 +58,21 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
func (s *Suite) TestRuleGeneration(c *check.C) { func (s *Suite) TestBasicRule(c *check.C) {
err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson") err := h.LoadPolicy("./tests/acls/acl_policy_basic_1.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := h.generateACLRules() rules, err := h.generateACLRules()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.IsNil)
} }
// func (s *Suite) TestRuleGeneration(c *check.C) {
// err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson")
// c.Assert(err, check.IsNil)
// rules, err := h.generateACLRules()
// c.Assert(err, check.IsNil)
// c.Assert(rules, check.NotNil)
// }