Work in progress in rule generation

This commit is contained in:
Juan Font 2021-07-03 17:31:32 +02:00
parent bbd6a67c46
commit 136aab9dc8
4 changed files with 169 additions and 38 deletions

101
acls.go
View File

@ -1,30 +1,119 @@
package headscale
import (
"fmt"
"io"
"os"
"strings"
"github.com/tailscale/hujson"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
const errorInvalidPolicy = Error("invalid policy")
const errorEmptyPolicy = Error("empty policy")
const errorInvalidAction = Error("invalid action")
const errorInvalidUserSection = Error("invalid user section")
const errorInvalidGroup = Error("invalid group")
func (h *Headscale) ParsePolicy(path string) (*ACLPolicy, error) {
func (h *Headscale) LoadPolicy(path string) error {
policyFile, err := os.Open(path)
if err != nil {
return nil, err
return err
}
defer policyFile.Close()
var policy ACLPolicy
b, err := io.ReadAll(policyFile)
if err != nil {
return nil, err
return err
}
err = hujson.Unmarshal(b, &policy)
if policy.IsZero() {
return nil, errorInvalidPolicy
return errorEmptyPolicy
}
return &policy, err
h.aclPolicy = &policy
return err
}
func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs {
if a.Action != "accept" {
return nil, errorInvalidAction
}
r := tailcfg.FilterRule{}
srcIPs := []string{}
for j, u := range a.Users {
fmt.Printf("acl %d, user %d: ", i, j)
srcs, err := h.generateAclPolicySrcIP(u)
fmt.Printf(" -> %s\n", err)
if err != nil {
return nil, err
}
srcIPs = append(srcIPs, *srcs...)
}
r.SrcIPs = srcIPs
}
return &rules, nil
}
func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) {
if u == "*" {
fmt.Printf("%s -> wildcard", u)
return &[]string{"*"}, nil
}
if strings.HasPrefix(u, "group:") {
fmt.Printf("%s -> group", u)
if _, ok := h.aclPolicy.Groups[u]; !ok {
return nil, errorInvalidGroup
}
return nil, nil
}
if strings.HasPrefix(u, "tag:") {
fmt.Printf("%s -> tag", u)
return nil, nil
}
n, err := h.GetNamespace(u)
if err == nil {
fmt.Printf("%s -> namespace %s", u, n.Name)
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
return nil, err
}
ips := []string{}
for _, n := range *nodes {
ips = append(ips, n.IPAddress)
}
return &ips, nil
}
if h, ok := h.aclPolicy.Hosts[u]; ok {
fmt.Printf("%s -> host %s", u, h)
return &[]string{h.String()}, nil
}
ip, err := netaddr.ParseIP(u)
if err == nil {
fmt.Printf(" %s -> ip %s", u, ip)
return &[]string{ip.String()}, nil
}
cidr, err := netaddr.ParseIPPrefix(u)
if err == nil {
fmt.Printf("%s -> cidr %s", u, cidr)
return &[]string{cidr.String()}, nil
}
fmt.Printf("%s: cannot be mapped to anything\n", u)
return nil, errorInvalidUserSection
}

View File

@ -5,29 +5,65 @@ import (
)
func (s *Suite) TestWrongPath(c *check.C) {
_, err := h.ParsePolicy("asdfg")
err := h.LoadPolicy("asdfg")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestBrokenHuJson(c *check.C) {
_, err := h.ParsePolicy("./tests/acls/broken.hujson")
err := h.LoadPolicy("./tests/acls/broken.hujson")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
_, err := h.ParsePolicy("./tests/acls/invalid.hujson")
err := h.LoadPolicy("./tests/acls/invalid.hujson")
c.Assert(err, check.NotNil)
c.Assert(err, check.Equals, errorInvalidPolicy)
c.Assert(err, check.Equals, errorEmptyPolicy)
}
func (s *Suite) TestValidCheckHosts(c *check.C) {
p, err := h.ParsePolicy("./tests/acls/acl_policy_1.hujson")
func (s *Suite) TestParseHosts(c *check.C) {
var hs Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`))
c.Assert(hs, check.NotNil)
c.Assert(err, check.IsNil)
c.Assert(p, check.NotNil)
c.Assert(p.IsZero(), check.Equals, false)
hosts, err := p.GetHosts()
c.Assert(err, check.IsNil)
c.Assert(*hosts, check.HasLen, 2)
}
func (s *Suite) TestParseInvalidCIDR(c *check.C) {
var hs Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
c.Assert(hs, check.IsNil)
c.Assert(err, check.NotNil)
}
func (s *Suite) TestCheckLoaded(c *check.C) {
err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson")
c.Assert(err, check.IsNil)
c.Assert(h.aclPolicy, check.NotNil)
}
func (s *Suite) TestValidCheckParsedHosts(c *check.C) {
err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson")
c.Assert(err, check.IsNil)
c.Assert(h.aclPolicy, check.NotNil)
c.Assert(h.aclPolicy.IsZero(), check.Equals, false)
c.Assert(h.aclPolicy.Hosts, check.HasLen, 2)
}
func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
err := h.LoadPolicy("./tests/acls/acl_policy_invalid.hujson")
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil)
}
func (s *Suite) TestRuleGeneration(c *check.C) {
err := h.LoadPolicy("./tests/acls/acl_policy_1.hujson")
c.Assert(err, check.IsNil)
rules, err := h.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
}

View File

@ -3,6 +3,7 @@ package headscale
import (
"strings"
"github.com/tailscale/hujson"
"inet.af/netaddr"
)
@ -22,12 +23,9 @@ type ACL struct {
type Groups map[string][]string
type Hosts map[string]string
type Hosts map[string]netaddr.IPPrefix
type TagOwners struct {
TagMontrealWebserver []string `json:"tag:montreal-webserver"`
TagAPIServer []string `json:"tag:api-server"`
}
type TagOwners map[string][]string
type ACLTest struct {
User string `json:"User"`
@ -35,6 +33,27 @@ type ACLTest struct {
Deny []string `json:"Deny,omitempty"`
}
func (h *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{}
hs := make(map[string]string)
err := hujson.Unmarshal(data, &hs)
if err != nil {
return err
}
for k, v := range hs {
if !strings.Contains(v, "/") {
v = v + "/32"
}
prefix, err := netaddr.ParseIPPrefix(v)
if err != nil {
return err
}
hosts[k] = prefix
}
*h = hosts
return nil
}
// IsZero is perhaps a bit naive here
func (p ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
@ -42,18 +61,3 @@ func (p ACLPolicy) IsZero() bool {
}
return false
}
func (p ACLPolicy) GetHosts() (*map[string]netaddr.IPPrefix, error) {
hosts := make(map[string]netaddr.IPPrefix)
for k, v := range p.Hosts {
if !strings.Contains(v, "/") {
v = v + "/32"
}
prefix, err := netaddr.ParseIPPrefix(v)
if err != nil {
return nil, err
}
hosts[k] = prefix
}
return &hosts, nil
}

2
app.go
View File

@ -49,6 +49,8 @@ type Headscale struct {
publicKey *wgkey.Key
privateKey *wgkey.Private
aclPolicy *ACLPolicy
pollMu sync.Mutex
clientsPolling map[uint64]chan []byte // this is by all means a hackity hack
}