Initial work eliminating one/two letter variables
This commit is contained in:
parent
53ed749f45
commit
471c0b4993
|
@ -28,6 +28,9 @@ linters:
|
|||
# In progress
|
||||
- gocritic
|
||||
|
||||
# TODO: approve: ok, db, id
|
||||
- varnamelen
|
||||
|
||||
# We should strive to enable these:
|
||||
- testpackage
|
||||
- stylecheck
|
||||
|
@ -39,7 +42,6 @@ linters:
|
|||
- gosec
|
||||
- forbidigo
|
||||
- dupl
|
||||
- varnamelen
|
||||
- makezero
|
||||
- paralleltest
|
||||
|
||||
|
|
74
acls.go
74
acls.go
|
@ -41,18 +41,18 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
|||
defer policyFile.Close()
|
||||
|
||||
var policy ACLPolicy
|
||||
b, err := io.ReadAll(policyFile)
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ast, err := hujson.Parse(b)
|
||||
ast, err := hujson.Parse(policyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ast.Standardize()
|
||||
b = ast.Pack()
|
||||
err = json.Unmarshal(b, &policy)
|
||||
policyBytes = ast.Pack()
|
||||
err = json.Unmarshal(policyBytes, &policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -73,32 +73,32 @@ func (h *Headscale) LoadACLPolicy(path string) error {
|
|||
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||
rules := []tailcfg.FilterRule{}
|
||||
|
||||
for i, a := range h.aclPolicy.ACLs {
|
||||
if a.Action != "accept" {
|
||||
for index, acl := range h.aclPolicy.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
return nil, errorInvalidAction
|
||||
}
|
||||
|
||||
r := tailcfg.FilterRule{}
|
||||
filterRule := tailcfg.FilterRule{}
|
||||
|
||||
srcIPs := []string{}
|
||||
for j, u := range a.Users {
|
||||
srcs, err := h.generateACLPolicySrcIP(u)
|
||||
for innerIndex, user := range acl.Users {
|
||||
srcs, err := h.generateACLPolicySrcIP(user)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, User %d", i, j)
|
||||
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
srcIPs = append(srcIPs, srcs...)
|
||||
}
|
||||
r.SrcIPs = srcIPs
|
||||
filterRule.SrcIPs = srcIPs
|
||||
|
||||
destPorts := []tailcfg.NetPortRange{}
|
||||
for j, d := range a.Ports {
|
||||
dests, err := h.generateACLPolicyDestPorts(d)
|
||||
for innerIndex, ports := range acl.Ports {
|
||||
dests, err := h.generateACLPolicyDestPorts(ports)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, Port %d", i, j)
|
||||
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
@ -162,17 +162,17 @@ func (h *Headscale) generateACLPolicyDestPorts(
|
|||
return dests, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) expandAlias(s string) ([]string, error) {
|
||||
if s == "*" {
|
||||
func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
||||
if alias == "*" {
|
||||
return []string{"*"}, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "group:") {
|
||||
if _, ok := h.aclPolicy.Groups[s]; !ok {
|
||||
if strings.HasPrefix(alias, "group:") {
|
||||
if _, ok := h.aclPolicy.Groups[alias]; !ok {
|
||||
return nil, errorInvalidGroup
|
||||
}
|
||||
ips := []string{}
|
||||
for _, n := range h.aclPolicy.Groups[s] {
|
||||
for _, n := range h.aclPolicy.Groups[alias] {
|
||||
nodes, err := h.ListMachinesInNamespace(n)
|
||||
if err != nil {
|
||||
return nil, errorInvalidNamespace
|
||||
|
@ -185,8 +185,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "tag:") {
|
||||
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
|
||||
if strings.HasPrefix(alias, "tag:") {
|
||||
if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
|
||||
return nil, errorInvalidTag
|
||||
}
|
||||
|
||||
|
@ -197,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
ips := []string{}
|
||||
for _, m := range machines {
|
||||
for _, machine := range machines {
|
||||
hostinfo := tailcfg.Hostinfo{}
|
||||
if len(m.HostInfo) != 0 {
|
||||
hi, err := m.HostInfo.MarshalJSON()
|
||||
if len(machine.HostInfo) != 0 {
|
||||
hi, err := machine.HostInfo.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -211,8 +211,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
|
||||
// FIXME: Check TagOwners allows this
|
||||
for _, t := range hostinfo.RequestTags {
|
||||
if s[4:] == t {
|
||||
ips = append(ips, m.IPAddress)
|
||||
if alias[4:] == t {
|
||||
ips = append(ips, machine.IPAddress)
|
||||
|
||||
break
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
n, err := h.GetNamespace(s)
|
||||
n, err := h.GetNamespace(alias)
|
||||
if err == nil {
|
||||
nodes, err := h.ListMachinesInNamespace(n.Name)
|
||||
if err != nil {
|
||||
|
@ -237,16 +237,16 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
if h, ok := h.aclPolicy.Hosts[s]; ok {
|
||||
if h, ok := h.aclPolicy.Hosts[alias]; ok {
|
||||
return []string{h.String()}, nil
|
||||
}
|
||||
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
ip, err := netaddr.ParseIP(alias)
|
||||
if err == nil {
|
||||
return []string{ip.String()}, nil
|
||||
}
|
||||
|
||||
cidr, err := netaddr.ParseIPPrefix(s)
|
||||
cidr, err := netaddr.ParseIPPrefix(alias)
|
||||
if err == nil {
|
||||
return []string{cidr.String()}, nil
|
||||
}
|
||||
|
@ -254,25 +254,25 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
|
|||
return nil, errorInvalidUserSection
|
||||
}
|
||||
|
||||
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
|
||||
if s == "*" {
|
||||
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||
if portsStr == "*" {
|
||||
return &[]tailcfg.PortRange{
|
||||
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END},
|
||||
}, nil
|
||||
}
|
||||
|
||||
ports := []tailcfg.PortRange{}
|
||||
for _, p := range strings.Split(s, ",") {
|
||||
rang := strings.Split(p, "-")
|
||||
for _, portStr := range strings.Split(portsStr, ",") {
|
||||
rang := strings.Split(portStr, "-")
|
||||
switch len(rang) {
|
||||
case 1:
|
||||
pi, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
|
||||
port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports = append(ports, tailcfg.PortRange{
|
||||
First: uint16(pi),
|
||||
Last: uint16(pi),
|
||||
First: uint16(port),
|
||||
Last: uint16(port),
|
||||
})
|
||||
|
||||
case EXPECTED_TOKEN_ITEMS:
|
||||
|
|
|
@ -41,37 +41,37 @@ type ACLTest struct {
|
|||
}
|
||||
|
||||
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
|
||||
func (h *Hosts) UnmarshalJSON(data []byte) error {
|
||||
hosts := Hosts{}
|
||||
hs := make(map[string]string)
|
||||
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
|
||||
newHosts := Hosts{}
|
||||
hostIpPrefixMap := make(map[string]string)
|
||||
ast, err := hujson.Parse(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ast.Standardize()
|
||||
data = ast.Pack()
|
||||
err = json.Unmarshal(data, &hs)
|
||||
err = json.Unmarshal(data, &hostIpPrefixMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range hs {
|
||||
if !strings.Contains(v, "/") {
|
||||
v += "/32"
|
||||
for host, prefixStr := range hostIpPrefixMap {
|
||||
if !strings.Contains(prefixStr, "/") {
|
||||
prefixStr += "/32"
|
||||
}
|
||||
prefix, err := netaddr.ParseIPPrefix(v)
|
||||
prefix, err := netaddr.ParseIPPrefix(prefixStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hosts[k] = prefix
|
||||
newHosts[host] = prefix
|
||||
}
|
||||
*h = hosts
|
||||
*hosts = newHosts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsZero is perhaps a bit naive here.
|
||||
func (p ACLPolicy) IsZero() bool {
|
||||
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
|
||||
func (policy ACLPolicy) IsZero() bool {
|
||||
if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
229
api.go
229
api.go
|
@ -22,21 +22,25 @@ const RESERVED_RESPONSE_HEADER_SIZE = 4
|
|||
|
||||
// KeyHandler provides the Headscale pub key
|
||||
// Listens in /key.
|
||||
func (h *Headscale) KeyHandler(c *gin.Context) {
|
||||
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
|
||||
func (h *Headscale) KeyHandler(ctx *gin.Context) {
|
||||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"text/plain; charset=utf-8",
|
||||
[]byte(h.publicKey.HexString()),
|
||||
)
|
||||
}
|
||||
|
||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register.
|
||||
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
||||
mKeyStr := c.Query("key")
|
||||
if mKeyStr == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
|
||||
machineKeyStr := ctx.Query("key")
|
||||
if machineKeyStr == "" {
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
|
@ -53,45 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
|||
</body>
|
||||
</html>
|
||||
|
||||
`, mKeyStr)))
|
||||
`, machineKeyStr)))
|
||||
}
|
||||
|
||||
// RegistrationHandler handles the actual registration process of a machine
|
||||
// Endpoint /machine/:id.
|
||||
func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
body, _ := io.ReadAll(c.Request.Body)
|
||||
mKeyStr := c.Param("id")
|
||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||
body, _ := io.ReadAll(ctx.Request.Body)
|
||||
machineKeyStr := ctx.Param("id")
|
||||
machineKey, err := wgkey.ParseHex(machineKeyStr)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot parse machine key")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
c.String(http.StatusInternalServerError, "Sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Sad!")
|
||||
|
||||
return
|
||||
}
|
||||
req := tailcfg.RegisterRequest{}
|
||||
err = decode(body, &req, &mKey, h.privateKey)
|
||||
err = decode(body, &req, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
||||
c.String(http.StatusInternalServerError, "Very sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Very sad!")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
m, err := h.GetMachineByMachineKey(mKey.HexString())
|
||||
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||
newMachine := Machine{
|
||||
Expiry: &time.Time{},
|
||||
MachineKey: mKey.HexString(),
|
||||
MachineKey: machineKey.HexString(),
|
||||
Name: req.Hostinfo.Hostname,
|
||||
}
|
||||
if err := h.db.Create(&newMachine).Error; err != nil {
|
||||
|
@ -99,16 +103,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
m = &newMachine
|
||||
machine = &newMachine
|
||||
}
|
||||
|
||||
if !m.Registered && req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(c, h.db, mKey, req, *m)
|
||||
if !machine.Registered && req.Auth.AuthKey != "" {
|
||||
h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -116,63 +120,63 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
// We have the updated key!
|
||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
|
||||
log.Info().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client requested logout")
|
||||
|
||||
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||
h.db.Save(&m)
|
||||
machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
|
||||
h.db.Save(&machine)
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = false
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
resp.User = *machine.Namespace.toUser()
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if m.Registered && m.Expiry.UTC().After(now) {
|
||||
if machine.Registered && machine.Expiry.UTC().After(now) {
|
||||
// The machine registration is valid, respond with redirect to /map
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *m.Namespace.toUser()
|
||||
resp.Login = *m.Namespace.toLogin()
|
||||
resp.User = *machine.Namespace.toUser()
|
||||
resp.Login = *machine.Namespace.toLogin()
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
|
||||
Inc()
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -180,15 +184,15 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
// The client has registered before, but has expired
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
}
|
||||
|
||||
// When a client connects, it may request a specific expiry time in its
|
||||
|
@ -197,51 +201,52 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
// into two steps (which cant pass arbitrary data between them easily) and needs to be
|
||||
// retrieved again after the user has authenticated. After the authentication flow
|
||||
// completes, RequestedExpiry is copied into Expiry.
|
||||
m.RequestedExpiry = &req.Expiry
|
||||
machine.RequestedExpiry = &req.Expiry
|
||||
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
|
||||
Inc()
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
|
||||
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
|
||||
machine.Expiry.UTC().After(now) {
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||
h.db.Save(&m)
|
||||
machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||
h.db.Save(&machine)
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
resp.User = *machine.Namespace.toUser()
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -249,47 +254,47 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||
// The machine registration is new, redirect the client to the registration URL
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf(
|
||||
"%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
mKey.HexString(),
|
||||
machineKey.HexString(),
|
||||
)
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
|
||||
}
|
||||
|
||||
// save the requested expiry time for retrieval later in the authentication flow
|
||||
m.RequestedExpiry = &req.Expiry
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||
h.db.Save(&m)
|
||||
machine.RequestedExpiry = &req.Expiry
|
||||
machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
|
||||
h.db.Save(&machine)
|
||||
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapResponse(
|
||||
mKey wgkey.Key,
|
||||
machineKey wgkey.Key,
|
||||
req tailcfg.MapRequest,
|
||||
m *Machine,
|
||||
machine *Machine,
|
||||
) ([]byte, error) {
|
||||
log.Trace().
|
||||
Str("func", "getMapResponse").
|
||||
Str("machine", req.Hostinfo.Hostname).
|
||||
Msg("Creating Map response")
|
||||
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
|
@ -299,7 +304,7 @@ func (h *Headscale) getMapResponse(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := h.getPeers(m)
|
||||
peers, err := h.getPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
|
@ -309,7 +314,7 @@ func (h *Headscale) getMapResponse(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
profiles := getMapResponseUserProfiles(*m, peers)
|
||||
profiles := getMapResponseUserProfiles(*machine, peers)
|
||||
|
||||
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
|
||||
if err != nil {
|
||||
|
@ -324,7 +329,7 @@ func (h *Headscale) getMapResponse(
|
|||
dnsConfig := getMapResponseDNSConfig(
|
||||
h.cfg.DNSConfig,
|
||||
h.cfg.BaseDomain,
|
||||
*m,
|
||||
*machine,
|
||||
peers,
|
||||
)
|
||||
|
||||
|
@ -351,12 +356,12 @@ func (h *Headscale) getMapResponse(
|
|||
|
||||
encoder, _ := zstd.NewWriter(nil)
|
||||
srcCompressed := encoder.EncodeAll(src, nil)
|
||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
||||
respBody, err = encode(resp, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -370,24 +375,24 @@ func (h *Headscale) getMapResponse(
|
|||
}
|
||||
|
||||
func (h *Headscale) getMapKeepAliveResponse(
|
||||
mKey wgkey.Key,
|
||||
req tailcfg.MapRequest,
|
||||
machineKey wgkey.Key,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
) ([]byte, error) {
|
||||
resp := tailcfg.MapResponse{
|
||||
mapResponse := tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
var respBody []byte
|
||||
var err error
|
||||
if req.Compress == "zstd" {
|
||||
src, _ := json.Marshal(resp)
|
||||
if mapRequest.Compress == "zstd" {
|
||||
src, _ := json.Marshal(mapResponse)
|
||||
encoder, _ := zstd.NewWriter(nil)
|
||||
srcCompressed := encoder.EncodeAll(src, nil)
|
||||
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
|
||||
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
respBody, err = encode(resp, &mKey, h.privateKey)
|
||||
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -400,22 +405,22 @@ func (h *Headscale) getMapKeepAliveResponse(
|
|||
}
|
||||
|
||||
func (h *Headscale) handleAuthKey(
|
||||
c *gin.Context,
|
||||
ctx *gin.Context,
|
||||
db *gorm.DB,
|
||||
idKey wgkey.Key,
|
||||
req tailcfg.RegisterRequest,
|
||||
m Machine,
|
||||
reqisterRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
) {
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", req.Hostinfo.Hostname).
|
||||
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
|
||||
Str("machine", reqisterRequest.Hostinfo.Hostname).
|
||||
Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
|
||||
pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
resp.MachineAuthorized = false
|
||||
|
@ -423,21 +428,21 @@ func (h *Headscale) handleAuthKey(
|
|||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
||||
ctx.String(http.StatusInternalServerError, "")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
|
@ -445,32 +450,34 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Failed to find an available IP")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
}
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msgf("Assigning %s to %s", ip, m.Name)
|
||||
Msgf("Assigning %s to %s", ip, machine.Name)
|
||||
|
||||
m.AuthKeyID = uint(pak.ID)
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = pak.NamespaceID
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "authKey"
|
||||
db.Save(&m)
|
||||
machine.AuthKeyID = uint(pak.ID)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = pak.NamespaceID
|
||||
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
|
||||
HexString()
|
||||
// we update it just in case
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = "authKey"
|
||||
db.Save(&machine)
|
||||
|
||||
pak.Used = true
|
||||
db.Save(&pak)
|
||||
|
@ -481,21 +488,21 @@ func (h *Headscale) handleAuthKey(
|
|||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
|
||||
Inc()
|
||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
ctx.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).
|
||||
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
|
||||
Inc()
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Successfully authenticated via AuthKey")
|
||||
}
|
||||
|
|
140
app.go
140
app.go
|
@ -169,7 +169,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
|||
return nil, errors.New("unsupported DB")
|
||||
}
|
||||
|
||||
h := Headscale{
|
||||
app := Headscale{
|
||||
cfg: cfg,
|
||||
dbType: cfg.DBtype,
|
||||
dbString: dbString,
|
||||
|
@ -178,32 +178,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
|||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||
}
|
||||
|
||||
err = h.initDB()
|
||||
err = app.initDB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = h.initOIDC()
|
||||
err = app.initOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
magicDNSDomains := generateMagicDNSRootDomains(
|
||||
h.cfg.IPPrefix,
|
||||
app.cfg.IPPrefix,
|
||||
)
|
||||
// we might have routes already from Split DNS
|
||||
if h.cfg.DNSConfig.Routes == nil {
|
||||
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||
if app.cfg.DNSConfig.Routes == nil {
|
||||
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
|
||||
}
|
||||
for _, d := range magicDNSDomains {
|
||||
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||
app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return &h, nil
|
||||
return &app, nil
|
||||
}
|
||||
|
||||
// Redirect to our TLS url.
|
||||
|
@ -229,35 +229,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
|
|||
return
|
||||
}
|
||||
|
||||
for _, ns := range namespaces {
|
||||
machines, err := h.ListMachinesInNamespace(ns.Name)
|
||||
for _, namespace := range namespaces {
|
||||
machines, err := h.ListMachinesInNamespace(namespace.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("namespace", ns.Name).
|
||||
Str("namespace", namespace.Name).
|
||||
Msg("Error listing machines in namespace")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range machines {
|
||||
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
|
||||
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||
for _, machine := range machines {
|
||||
if machine.AuthKey != nil && machine.LastSeen != nil &&
|
||||
machine.AuthKey.Ephemeral &&
|
||||
time.Now().
|
||||
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||
log.Info().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
||||
err = h.db.Unscoped().Delete(m).Error
|
||||
err = h.db.Unscoped().Delete(machine).Error
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h.setLastStateChangeToNow(ns.Name)
|
||||
h.setLastStateChangeToNow(namespace.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -284,18 +286,18 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
// with the "legacy" database-based client
|
||||
// It is also neede for grpc-gateway to be able to connect to
|
||||
// the server
|
||||
p, _ := peer.FromContext(ctx)
|
||||
client, _ := peer.FromContext(ctx)
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", p.Addr.String()).
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg("Client is trying to authenticate")
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
meta, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", p.Addr.String()).
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg("Retrieving metadata is failed")
|
||||
|
||||
return ctx, status.Errorf(
|
||||
|
@ -304,11 +306,11 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
)
|
||||
}
|
||||
|
||||
authHeader, ok := md["authorization"]
|
||||
authHeader, ok := meta["authorization"]
|
||||
if !ok {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", p.Addr.String()).
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg("Authorization token is not supplied")
|
||||
|
||||
return ctx, status.Errorf(
|
||||
|
@ -322,7 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
if !strings.HasPrefix(token, AUTH_PREFIX) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", p.Addr.String()).
|
||||
Str("client_address", client.Addr.String()).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
|
||||
return ctx, status.Error(
|
||||
|
@ -353,25 +355,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||
// return handler(ctx, req)
|
||||
}
|
||||
|
||||
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
|
||||
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("client_address", c.ClientIP()).
|
||||
Str("client_address", ctx.ClientIP()).
|
||||
Msg("HTTP authentication invoked")
|
||||
|
||||
authHeader := c.GetHeader("authorization")
|
||||
authHeader := ctx.GetHeader("authorization")
|
||||
|
||||
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("client_address", c.ClientIP()).
|
||||
Str("client_address", ctx.ClientIP()).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
c.AbortWithStatus(http.StatusUnauthorized)
|
||||
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.AbortWithStatus(http.StatusUnauthorized)
|
||||
ctx.AbortWithStatus(http.StatusUnauthorized)
|
||||
|
||||
// TODO(kradalby): Implement API key backend
|
||||
// Currently all traffic is unauthorized, this is intentional to allow
|
||||
|
@ -438,9 +440,9 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
// Create the cmux object that will multiplex 2 protocols on the same port.
|
||||
// The two following listeners will be served on the same port below gracefully.
|
||||
m := cmux.New(networkListener)
|
||||
networkMutex := cmux.New(networkListener)
|
||||
// Match gRPC requests here
|
||||
grpcListener := m.MatchWithWriters(
|
||||
grpcListener := networkMutex.MatchWithWriters(
|
||||
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
|
||||
cmux.HTTP2MatchHeaderFieldSendSettings(
|
||||
"content-type",
|
||||
|
@ -448,7 +450,7 @@ func (h *Headscale) Serve() error {
|
|||
),
|
||||
)
|
||||
// Otherwise match regular http requests.
|
||||
httpListener := m.Match(cmux.Any())
|
||||
httpListener := networkMutex.Match(cmux.Any())
|
||||
|
||||
grpcGatewayMux := runtime.NewServeMux()
|
||||
|
||||
|
@ -471,33 +473,33 @@ func (h *Headscale) Serve() error {
|
|||
return err
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
router := gin.Default()
|
||||
|
||||
p := ginprometheus.NewPrometheus("gin")
|
||||
p.Use(r)
|
||||
prometheus := ginprometheus.NewPrometheus("gin")
|
||||
prometheus.Use(router)
|
||||
|
||||
r.GET(
|
||||
router.GET(
|
||||
"/health",
|
||||
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
|
||||
)
|
||||
r.GET("/key", h.KeyHandler)
|
||||
r.GET("/register", h.RegisterWebAPI)
|
||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
r.POST("/machine/:id", h.RegistrationHandler)
|
||||
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||
r.GET("/oidc/callback", h.OIDCCallback)
|
||||
r.GET("/apple", h.AppleMobileConfig)
|
||||
r.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||
r.GET("/swagger", SwaggerUI)
|
||||
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
||||
router.GET("/key", h.KeyHandler)
|
||||
router.GET("/register", h.RegisterWebAPI)
|
||||
router.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
router.POST("/machine/:id", h.RegistrationHandler)
|
||||
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||
router.GET("/oidc/callback", h.OIDCCallback)
|
||||
router.GET("/apple", h.AppleMobileConfig)
|
||||
router.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||
router.GET("/swagger", SwaggerUI)
|
||||
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
|
||||
|
||||
api := r.Group("/api")
|
||||
api := router.Group("/api")
|
||||
api.Use(h.httpAuthenticationMiddleware)
|
||||
{
|
||||
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
|
||||
}
|
||||
|
||||
r.NoRoute(stdoutHandler)
|
||||
router.NoRoute(stdoutHandler)
|
||||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
h.DERPMap = GetDERPMap(h.cfg.DERP)
|
||||
|
@ -514,7 +516,7 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
httpServer := &http.Server{
|
||||
Addr: h.cfg.Addr,
|
||||
Handler: r,
|
||||
Handler: router,
|
||||
ReadTimeout: HTTP_READ_TIMEOUT,
|
||||
// Go does not handle timeouts in HTTP very well, and there is
|
||||
// no good way to handle streaming timeouts, therefore we need to
|
||||
|
@ -561,29 +563,29 @@ func (h *Headscale) Serve() error {
|
|||
reflection.Register(grpcServer)
|
||||
reflection.Register(grpcSocket)
|
||||
|
||||
g := new(errgroup.Group)
|
||||
errorGroup := new(errgroup.Group)
|
||||
|
||||
g.Go(func() error { return grpcSocket.Serve(socketListener) })
|
||||
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
|
||||
|
||||
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
|
||||
g.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||
|
||||
if tlsConfig != nil {
|
||||
g.Go(func() error {
|
||||
errorGroup.Go(func() error {
|
||||
tlsl := tls.NewListener(httpListener, tlsConfig)
|
||||
|
||||
return httpServer.Serve(tlsl)
|
||||
})
|
||||
} else {
|
||||
g.Go(func() error { return httpServer.Serve(httpListener) })
|
||||
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||||
}
|
||||
|
||||
g.Go(func() error { return m.Serve() })
|
||||
errorGroup.Go(func() error { return networkMutex.Serve() })
|
||||
|
||||
log.Info().
|
||||
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
|
||||
|
||||
return g.Wait()
|
||||
return errorGroup.Wait()
|
||||
}
|
||||
|
||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
|
@ -594,7 +596,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
}
|
||||
|
||||
m := autocert.Manager{
|
||||
certManager := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
|
||||
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
|
||||
|
@ -609,7 +611,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
||||
// The RFC requires that the validation is done on port 443; in other words, headscale
|
||||
// must be reachable on port 443.
|
||||
return m.TLSConfig(), nil
|
||||
return certManager.TLSConfig(), nil
|
||||
|
||||
case "HTTP-01":
|
||||
// Configuration via autocert with HTTP-01. This requires listening on
|
||||
|
@ -617,11 +619,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||
// service, which can be configured to run on any other port.
|
||||
go func() {
|
||||
log.Fatal().
|
||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||
Msg("failed to set up a HTTP server")
|
||||
}()
|
||||
|
||||
return m.TLSConfig(), nil
|
||||
return certManager.TLSConfig(), nil
|
||||
|
||||
default:
|
||||
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
|
||||
|
@ -676,13 +678,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
|||
}
|
||||
}
|
||||
|
||||
func stdoutHandler(c *gin.Context) {
|
||||
b, _ := io.ReadAll(c.Request.Body)
|
||||
func stdoutHandler(ctx *gin.Context) {
|
||||
body, _ := io.ReadAll(ctx.Request.Body)
|
||||
|
||||
log.Trace().
|
||||
Interface("header", c.Request.Header).
|
||||
Interface("proto", c.Request.Proto).
|
||||
Interface("url", c.Request.URL).
|
||||
Bytes("body", b).
|
||||
Interface("header", ctx.Request.Header).
|
||||
Interface("proto", ctx.Request.Proto).
|
||||
Interface("url", ctx.Request.URL).
|
||||
Bytes("body", body).
|
||||
Msg("Request did not match")
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@ import (
|
|||
|
||||
// AppleMobileConfig shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register.
|
||||
func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
||||
t := template.Must(template.New("apple").Parse(`
|
||||
func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
|
||||
appleTemplate := template.Must(template.New("apple").Parse(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>Apple configuration profiles</h1>
|
||||
|
@ -67,12 +67,12 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
|||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
if err := t.Execute(&payload, config); err != nil {
|
||||
if err := appleTemplate.Execute(&payload, config); err != nil {
|
||||
log.Error().
|
||||
Str("handler", "AppleMobileConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple index template")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple index template"),
|
||||
|
@ -81,11 +81,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
}
|
||||
|
||||
func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
||||
platform := c.Param("platform")
|
||||
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
|
||||
platform := ctx.Param("platform")
|
||||
|
||||
id, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
|
@ -93,7 +93,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Failed not create UUID")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Failed to create UUID"),
|
||||
|
@ -108,7 +108,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Failed not create UUID")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Failed to create UUID"),
|
||||
|
@ -131,7 +131,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple macOS template")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple macOS template"),
|
||||
|
@ -145,7 +145,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple iOS template")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple iOS template"),
|
||||
|
@ -154,7 +154,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
default:
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Invalid platform, only ios and macos is supported"),
|
||||
|
@ -175,7 +175,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
Str("handler", "ApplePlatformConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple platform template")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Apple platform template"),
|
||||
|
@ -184,7 +184,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusOK,
|
||||
"application/x-apple-aspen-config; charset=utf-8",
|
||||
content.Bytes(),
|
||||
|
|
|
@ -167,10 +167,10 @@ var listNamespacesCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
d := pterm.TableData{{"ID", "Name", "Created"}}
|
||||
tableData := pterm.TableData{{"ID", "Name", "Created"}}
|
||||
for _, namespace := range response.GetNamespaces() {
|
||||
d = append(
|
||||
d,
|
||||
tableData = append(
|
||||
tableData,
|
||||
[]string{
|
||||
namespace.GetId(),
|
||||
namespace.GetName(),
|
||||
|
@ -178,7 +178,7 @@ var listNamespacesCmd = &cobra.Command{
|
|||
},
|
||||
)
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
|
|
@ -157,14 +157,14 @@ var listNodesCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
d, err := nodesToPtables(namespace, response.Machines)
|
||||
tableData, err := nodesToPtables(namespace, response.Machines)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
@ -183,7 +183,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
id, err := cmd.Flags().GetInt("identifier")
|
||||
identifier, err := cmd.Flags().GetInt("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
@ -199,7 +199,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||
defer conn.Close()
|
||||
|
||||
getRequest := &v1.GetMachineRequest{
|
||||
MachineId: uint64(id),
|
||||
MachineId: uint64(identifier),
|
||||
}
|
||||
|
||||
getResponse, err := client.GetMachine(ctx, getRequest)
|
||||
|
@ -217,7 +217,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
deleteRequest := &v1.DeleteMachineRequest{
|
||||
MachineId: uint64(id),
|
||||
MachineId: uint64(identifier),
|
||||
}
|
||||
|
||||
confirm := false
|
||||
|
@ -280,7 +280,7 @@ func sharingWorker(
|
|||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
id, err := cmd.Flags().GetInt("identifier")
|
||||
identifier, err := cmd.Flags().GetInt("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
|
||||
|
||||
|
@ -288,7 +288,7 @@ func sharingWorker(
|
|||
}
|
||||
|
||||
machineRequest := &v1.GetMachineRequest{
|
||||
MachineId: uint64(id),
|
||||
MachineId: uint64(identifier),
|
||||
}
|
||||
|
||||
machineResponse, err := client.GetMachine(ctx, machineRequest)
|
||||
|
@ -402,7 +402,7 @@ func nodesToPtables(
|
|||
currentNamespace string,
|
||||
machines []*v1.Machine,
|
||||
) (pterm.TableData, error) {
|
||||
d := pterm.TableData{
|
||||
tableData := pterm.TableData{
|
||||
{
|
||||
"ID",
|
||||
"Name",
|
||||
|
@ -448,8 +448,8 @@ func nodesToPtables(
|
|||
// Shared into this namespace
|
||||
namespace = pterm.LightYellow(machine.Namespace.Name)
|
||||
}
|
||||
d = append(
|
||||
d,
|
||||
tableData = append(
|
||||
tableData,
|
||||
[]string{
|
||||
strconv.FormatUint(machine.Id, headscale.BASE_10),
|
||||
machine.Name,
|
||||
|
@ -463,5 +463,5 @@ func nodesToPtables(
|
|||
)
|
||||
}
|
||||
|
||||
return d, nil
|
||||
return tableData, nil
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ var listPreAuthKeys = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
n, err := cmd.Flags().GetString("namespace")
|
||||
namespace, err := cmd.Flags().GetString("namespace")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
|
||||
|
||||
|
@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{
|
|||
defer conn.Close()
|
||||
|
||||
request := &v1.ListPreAuthKeysRequest{
|
||||
Namespace: n,
|
||||
Namespace: namespace,
|
||||
}
|
||||
|
||||
response, err := client.ListPreAuthKeys(ctx, request)
|
||||
|
@ -77,34 +77,34 @@ var listPreAuthKeys = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
d := pterm.TableData{
|
||||
tableData := pterm.TableData{
|
||||
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
|
||||
}
|
||||
for _, k := range response.PreAuthKeys {
|
||||
for _, key := range response.PreAuthKeys {
|
||||
expiration := "-"
|
||||
if k.GetExpiration() != nil {
|
||||
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
||||
if key.GetExpiration() != nil {
|
||||
expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
var reusable string
|
||||
if k.GetEphemeral() {
|
||||
if key.GetEphemeral() {
|
||||
reusable = "N/A"
|
||||
} else {
|
||||
reusable = fmt.Sprintf("%v", k.GetReusable())
|
||||
reusable = fmt.Sprintf("%v", key.GetReusable())
|
||||
}
|
||||
|
||||
d = append(d, []string{
|
||||
k.GetId(),
|
||||
k.GetKey(),
|
||||
tableData = append(tableData, []string{
|
||||
key.GetId(),
|
||||
key.GetKey(),
|
||||
reusable,
|
||||
strconv.FormatBool(k.GetEphemeral()),
|
||||
strconv.FormatBool(k.GetUsed()),
|
||||
strconv.FormatBool(key.GetEphemeral()),
|
||||
strconv.FormatBool(key.GetUsed()),
|
||||
expiration,
|
||||
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
|
|
@ -81,14 +81,14 @@ var listRoutesCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
d := routesToPtables(response.Routes)
|
||||
tableData := routesToPtables(response.Routes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
@ -162,14 +162,14 @@ omit the route you do not want to enable.
|
|||
return
|
||||
}
|
||||
|
||||
d := routesToPtables(response.Routes)
|
||||
tableData := routesToPtables(response.Routes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
|
@ -184,15 +184,15 @@ omit the route you do not want to enable.
|
|||
|
||||
// routesToPtables converts the list of routes to a nice table.
|
||||
func routesToPtables(routes *v1.Routes) pterm.TableData {
|
||||
d := pterm.TableData{{"Route", "Enabled"}}
|
||||
tableData := pterm.TableData{{"Route", "Enabled"}}
|
||||
|
||||
for _, route := range routes.GetAdvertisedRoutes() {
|
||||
enabled := isStringInSlice(routes.EnabledRoutes, route)
|
||||
|
||||
d = append(d, []string{route, strconv.FormatBool(enabled)})
|
||||
tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
|
||||
}
|
||||
|
||||
return d
|
||||
return tableData
|
||||
}
|
||||
|
||||
func isStringInSlice(strs []string, s string) bool {
|
||||
|
|
|
@ -318,7 +318,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
|
||||
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||
|
||||
h, err := headscale.NewHeadscale(cfg)
|
||||
app, err := headscale.NewHeadscale(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -327,7 +327,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
|
||||
if viper.GetString("acl_policy_path") != "" {
|
||||
aclPath := absPath(viper.GetString("acl_policy_path"))
|
||||
err = h.LoadACLPolicy(aclPath)
|
||||
err = app.LoadACLPolicy(aclPath)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("path", aclPath).
|
||||
|
@ -336,7 +336,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return h, nil
|
||||
return app, nil
|
||||
}
|
||||
|
||||
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
||||
|
|
6
dns.go
6
dns.go
|
@ -79,7 +79,7 @@ func generateMagicDNSRootDomains(
|
|||
func getMapResponseDNSConfig(
|
||||
dnsConfigOrig *tailcfg.DNSConfig,
|
||||
baseDomain string,
|
||||
m Machine,
|
||||
machine Machine,
|
||||
peers Machines,
|
||||
) *tailcfg.DNSConfig {
|
||||
var dnsConfig *tailcfg.DNSConfig
|
||||
|
@ -88,11 +88,11 @@ func getMapResponseDNSConfig(
|
|||
dnsConfig = dnsConfigOrig.Clone()
|
||||
dnsConfig.Domains = append(
|
||||
dnsConfig.Domains,
|
||||
fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain),
|
||||
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
|
||||
)
|
||||
|
||||
namespaceSet := set.New(set.ThreadSafe)
|
||||
namespaceSet.Add(m.Namespace)
|
||||
namespaceSet.Add(machine.Namespace)
|
||||
for _, p := range peers {
|
||||
namespaceSet.Add(p.Namespace)
|
||||
}
|
||||
|
|
299
machine.go
299
machine.go
|
@ -56,21 +56,21 @@ type (
|
|||
)
|
||||
|
||||
// For the time being this method is rather naive.
|
||||
func (m Machine) isAlreadyRegistered() bool {
|
||||
return m.Registered
|
||||
func (machine Machine) isAlreadyRegistered() bool {
|
||||
return machine.Registered
|
||||
}
|
||||
|
||||
// isExpired returns whether the machine registration has expired.
|
||||
func (m Machine) isExpired() bool {
|
||||
return time.Now().UTC().After(*m.Expiry)
|
||||
func (machine Machine) isExpired() bool {
|
||||
return time.Now().UTC().After(*machine.Expiry)
|
||||
}
|
||||
|
||||
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
|
||||
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
|
||||
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
|
||||
// expiry time.
|
||||
func (h *Headscale) updateMachineExpiry(m *Machine) {
|
||||
if m.isExpired() {
|
||||
func (h *Headscale) updateMachineExpiry(machine *Machine) {
|
||||
if machine.isExpired() {
|
||||
now := time.Now().UTC()
|
||||
maxExpiry := now.Add(
|
||||
h.cfg.MaxMachineRegistrationDuration,
|
||||
|
@ -80,31 +80,31 @@ func (h *Headscale) updateMachineExpiry(m *Machine) {
|
|||
) // calculate the default expiry
|
||||
|
||||
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
|
||||
if maxExpiry.Before(*m.RequestedExpiry) {
|
||||
if maxExpiry.Before(*machine.RequestedExpiry) {
|
||||
log.Debug().
|
||||
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
|
||||
m.Expiry = &maxExpiry
|
||||
} else if m.RequestedExpiry.IsZero() {
|
||||
machine.Expiry = &maxExpiry
|
||||
} else if machine.RequestedExpiry.IsZero() {
|
||||
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
|
||||
m.Expiry = &defaultExpiry
|
||||
machine.Expiry = &defaultExpiry
|
||||
} else {
|
||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
|
||||
m.Expiry = m.RequestedExpiry
|
||||
log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
|
||||
machine.Expiry = machine.RequestedExpiry
|
||||
}
|
||||
|
||||
h.db.Save(&m)
|
||||
h.db.Save(&machine)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
||||
func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finding direct peers")
|
||||
|
||||
machines := Machines{}
|
||||
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
|
||||
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
||||
machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
|
||||
return Machines{}, err
|
||||
|
@ -114,22 +114,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found direct machines: %s", machines.String())
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
|
||||
func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
||||
func (h *Headscale) getShared(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finding shared peers")
|
||||
|
||||
sharedMachines := []SharedMachine{}
|
||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
|
||||
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||
machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
|
@ -142,22 +142,22 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found shared peers: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// getSharedTo fetches the machines of the namespaces this machine is shared in.
|
||||
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
||||
func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Finding peers in namespaces this machine is shared with")
|
||||
|
||||
sharedMachines := []SharedMachine{}
|
||||
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
|
||||
m.ID).Find(&sharedMachines).Error; err != nil {
|
||||
machine.ID).Find(&sharedMachines).Error; err != nil {
|
||||
return Machines{}, err
|
||||
}
|
||||
|
||||
|
@ -176,14 +176,14 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found peers we are shared with: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
||||
direct, err := h.getDirectPeers(m)
|
||||
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||
direct, err := h.getDirectPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -193,7 +193,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
|||
return Machines{}, err
|
||||
}
|
||||
|
||||
shared, err := h.getShared(m)
|
||||
shared, err := h.getShared(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -203,7 +203,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
|||
return Machines{}, err
|
||||
}
|
||||
|
||||
sharedTo, err := h.getSharedTo(m)
|
||||
sharedTo, err := h.getSharedTo(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -220,7 +220,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msgf("Found total peers: %s", peers.String())
|
||||
|
||||
return peers, nil
|
||||
|
@ -262,9 +262,9 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
|||
}
|
||||
|
||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
|
||||
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
|
||||
m := Machine{}
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
|
@ -273,8 +273,8 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
|
|||
|
||||
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
||||
// and updates it with the latest data from the database.
|
||||
func (h *Headscale) UpdateMachine(m *Machine) error {
|
||||
if result := h.db.Find(m).First(&m); result.Error != nil {
|
||||
func (h *Headscale) UpdateMachine(machine *Machine) error {
|
||||
if result := h.db.Find(machine).First(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -282,16 +282,16 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
|
|||
}
|
||||
|
||||
// DeleteMachine softs deletes a Machine from the database.
|
||||
func (h *Headscale) DeleteMachine(m *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
||||
func (h *Headscale) DeleteMachine(machine *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||
if err != nil && err != errorMachineNotShared {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Registered = false
|
||||
namespaceID := m.NamespaceID
|
||||
h.db.Save(&m) // we mark it as unregistered, just in case
|
||||
if err := h.db.Delete(&m).Error; err != nil {
|
||||
machine.Registered = false
|
||||
namespaceID := machine.NamespaceID
|
||||
h.db.Save(&machine) // we mark it as unregistered, just in case
|
||||
if err := h.db.Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -299,14 +299,14 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
|
|||
}
|
||||
|
||||
// HardDeleteMachine hard deletes a Machine from the database.
|
||||
func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(m)
|
||||
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
|
||||
err := h.RemoveSharedMachineFromAllNamespaces(machine)
|
||||
if err != nil && err != errorMachineNotShared {
|
||||
return err
|
||||
}
|
||||
|
||||
namespaceID := m.NamespaceID
|
||||
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
|
||||
namespaceID := machine.NamespaceID
|
||||
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -314,10 +314,10 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
|||
}
|
||||
|
||||
// GetHostInfo returns a Hostinfo struct for the machine.
|
||||
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
||||
hostinfo := tailcfg.Hostinfo{}
|
||||
if len(m.HostInfo) != 0 {
|
||||
hi, err := m.HostInfo.MarshalJSON()
|
||||
if len(machine.HostInfo) != 0 {
|
||||
hi, err := machine.HostInfo.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -330,17 +330,17 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
|||
return &hostinfo, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) isOutdated(m *Machine) bool {
|
||||
if err := h.UpdateMachine(m); err != nil {
|
||||
func (h *Headscale) isOutdated(machine *Machine) bool {
|
||||
if err := h.UpdateMachine(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
}
|
||||
|
||||
sharedMachines, _ := h.getShared(m)
|
||||
sharedMachines, _ := h.getShared(machine)
|
||||
|
||||
namespaceSet := set.New(set.ThreadSafe)
|
||||
namespaceSet.Add(m.Namespace.Name)
|
||||
namespaceSet.Add(machine.Namespace.Name)
|
||||
|
||||
// Check if any of our shared namespaces has updates that we have
|
||||
// not propagated.
|
||||
|
@ -356,22 +356,22 @@ func (h *Headscale) isOutdated(m *Machine) bool {
|
|||
lastChange := h.getLastStateChange(namespaces...)
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||
Str("machine", machine.Name).
|
||||
Time("last_successful_update", *machine.LastSuccessfulUpdate).
|
||||
Time("last_state_change", lastChange).
|
||||
Msgf("Checking if %s is missing updates", m.Name)
|
||||
Msgf("Checking if %s is missing updates", machine.Name)
|
||||
|
||||
return m.LastSuccessfulUpdate.Before(lastChange)
|
||||
return machine.LastSuccessfulUpdate.Before(lastChange)
|
||||
}
|
||||
|
||||
func (m Machine) String() string {
|
||||
return m.Name
|
||||
func (machine Machine) String() string {
|
||||
return machine.Name
|
||||
}
|
||||
|
||||
func (ms Machines) String() string {
|
||||
temp := make([]string, len(ms))
|
||||
func (machines Machines) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range ms {
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Name
|
||||
}
|
||||
|
||||
|
@ -379,24 +379,24 @@ func (ms Machines) String() string {
|
|||
}
|
||||
|
||||
// TODO(kradalby): Remove when we have generics...
|
||||
func (ms MachinesP) String() string {
|
||||
temp := make([]string, len(ms))
|
||||
func (machines MachinesP) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range ms {
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Name
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
}
|
||||
|
||||
func (ms Machines) toNodes(
|
||||
func (machines Machines) toNodes(
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
includeRoutes bool,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
nodes := make([]*tailcfg.Node, len(ms))
|
||||
nodes := make([]*tailcfg.Node, len(machines))
|
||||
|
||||
for index, machine := range ms {
|
||||
for index, machine := range machines {
|
||||
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -410,23 +410,24 @@ func (ms Machines) toNodes(
|
|||
|
||||
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||
// as per the expected behaviour in the official SaaS.
|
||||
func (m Machine) toNode(
|
||||
func (machine Machine) toNode(
|
||||
baseDomain string,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
includeRoutes bool,
|
||||
) (*tailcfg.Node, error) {
|
||||
nKey, err := wgkey.ParseHex(m.NodeKey)
|
||||
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mKey, err := wgkey.ParseHex(m.MachineKey)
|
||||
|
||||
machineKey, err := wgkey.ParseHex(machine.MachineKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var discoKey tailcfg.DiscoKey
|
||||
if m.DiscoKey != "" {
|
||||
dKey, err := wgkey.ParseHex(m.DiscoKey)
|
||||
if machine.DiscoKey != "" {
|
||||
dKey, err := wgkey.ParseHex(machine.DiscoKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -436,12 +437,12 @@ func (m Machine) toNode(
|
|||
}
|
||||
|
||||
addrs := []netaddr.IPPrefix{}
|
||||
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
|
||||
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
|
||||
if err != nil {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("ip", m.IPAddress).
|
||||
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
|
||||
Str("ip", machine.IPAddress).
|
||||
Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
@ -455,8 +456,8 @@ func (m Machine) toNode(
|
|||
|
||||
if includeRoutes {
|
||||
routesStr := []string{}
|
||||
if len(m.EnabledRoutes) != 0 {
|
||||
allwIps, err := m.EnabledRoutes.MarshalJSON()
|
||||
if len(machine.EnabledRoutes) != 0 {
|
||||
allwIps, err := machine.EnabledRoutes.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -476,8 +477,8 @@ func (m Machine) toNode(
|
|||
}
|
||||
|
||||
endpoints := []string{}
|
||||
if len(m.Endpoints) != 0 {
|
||||
be, err := m.Endpoints.MarshalJSON()
|
||||
if len(machine.Endpoints) != 0 {
|
||||
be, err := machine.Endpoints.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -488,8 +489,8 @@ func (m Machine) toNode(
|
|||
}
|
||||
|
||||
hostinfo := tailcfg.Hostinfo{}
|
||||
if len(m.HostInfo) != 0 {
|
||||
hi, err := m.HostInfo.MarshalJSON()
|
||||
if len(machine.HostInfo) != 0 {
|
||||
hi, err := machine.HostInfo.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -507,29 +508,34 @@ func (m Machine) toNode(
|
|||
}
|
||||
|
||||
var keyExpiry time.Time
|
||||
if m.Expiry != nil {
|
||||
keyExpiry = *m.Expiry
|
||||
if machine.Expiry != nil {
|
||||
keyExpiry = *machine.Expiry
|
||||
} else {
|
||||
keyExpiry = time.Time{}
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
|
||||
hostname = fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
machine.Name,
|
||||
machine.Namespace.Name,
|
||||
baseDomain,
|
||||
)
|
||||
} else {
|
||||
hostname = m.Name
|
||||
hostname = machine.Name
|
||||
}
|
||||
|
||||
n := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
||||
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
strconv.FormatUint(m.ID, BASE_10),
|
||||
strconv.FormatUint(machine.ID, BASE_10),
|
||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||
Name: hostname,
|
||||
User: tailcfg.UserID(m.NamespaceID),
|
||||
Key: tailcfg.NodeKey(nKey),
|
||||
User: tailcfg.UserID(machine.NamespaceID),
|
||||
Key: tailcfg.NodeKey(nodeKey),
|
||||
KeyExpiry: keyExpiry,
|
||||
Machine: tailcfg.MachineKey(mKey),
|
||||
Machine: tailcfg.MachineKey(machineKey),
|
||||
DiscoKey: discoKey,
|
||||
Addresses: addrs,
|
||||
AllowedIPs: allowedIPs,
|
||||
|
@ -537,68 +543,73 @@ func (m Machine) toNode(
|
|||
DERP: derp,
|
||||
|
||||
Hostinfo: hostinfo,
|
||||
Created: m.CreatedAt,
|
||||
LastSeen: m.LastSeen,
|
||||
Created: machine.CreatedAt,
|
||||
LastSeen: machine.LastSeen,
|
||||
|
||||
KeepAlive: true,
|
||||
MachineAuthorized: m.Registered,
|
||||
MachineAuthorized: machine.Registered,
|
||||
Capabilities: []string{tailcfg.CapabilityFileSharing},
|
||||
}
|
||||
|
||||
return &n, nil
|
||||
}
|
||||
|
||||
func (m *Machine) toProto() *v1.Machine {
|
||||
machine := &v1.Machine{
|
||||
Id: m.ID,
|
||||
MachineKey: m.MachineKey,
|
||||
func (machine *Machine) toProto() *v1.Machine {
|
||||
machineProto := &v1.Machine{
|
||||
Id: machine.ID,
|
||||
MachineKey: machine.MachineKey,
|
||||
|
||||
NodeKey: m.NodeKey,
|
||||
DiscoKey: m.DiscoKey,
|
||||
IpAddress: m.IPAddress,
|
||||
Name: m.Name,
|
||||
Namespace: m.Namespace.toProto(),
|
||||
NodeKey: machine.NodeKey,
|
||||
DiscoKey: machine.DiscoKey,
|
||||
IpAddress: machine.IPAddress,
|
||||
Name: machine.Name,
|
||||
Namespace: machine.Namespace.toProto(),
|
||||
|
||||
Registered: m.Registered,
|
||||
Registered: machine.Registered,
|
||||
|
||||
// TODO(kradalby): Implement register method enum converter
|
||||
// RegisterMethod: ,
|
||||
|
||||
CreatedAt: timestamppb.New(m.CreatedAt),
|
||||
CreatedAt: timestamppb.New(machine.CreatedAt),
|
||||
}
|
||||
|
||||
if m.AuthKey != nil {
|
||||
machine.PreAuthKey = m.AuthKey.toProto()
|
||||
if machine.AuthKey != nil {
|
||||
machineProto.PreAuthKey = machine.AuthKey.toProto()
|
||||
}
|
||||
|
||||
if m.LastSeen != nil {
|
||||
machine.LastSeen = timestamppb.New(*m.LastSeen)
|
||||
if machine.LastSeen != nil {
|
||||
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
|
||||
}
|
||||
|
||||
if m.LastSuccessfulUpdate != nil {
|
||||
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
machineProto.LastSuccessfulUpdate = timestamppb.New(
|
||||
*machine.LastSuccessfulUpdate,
|
||||
)
|
||||
}
|
||||
|
||||
if m.Expiry != nil {
|
||||
machine.Expiry = timestamppb.New(*m.Expiry)
|
||||
if machine.Expiry != nil {
|
||||
machineProto.Expiry = timestamppb.New(*machine.Expiry)
|
||||
}
|
||||
|
||||
return machine
|
||||
return machineProto
|
||||
}
|
||||
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
|
||||
ns, err := h.GetNamespace(namespace)
|
||||
func (h *Headscale) RegisterMachine(
|
||||
key string,
|
||||
namespaceName string,
|
||||
) (*Machine, error) {
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mKey, err := wgkey.ParseHex(key)
|
||||
machineKey, err := wgkey.ParseHex(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := Machine{}
|
||||
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
|
||||
machine := Machine{}
|
||||
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
|
@ -607,15 +618,15 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Attempting to register machine")
|
||||
|
||||
if m.isAlreadyRegistered() {
|
||||
if machine.isAlreadyRegistered() {
|
||||
err := errors.New("Machine already registered")
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Attempting to register machine")
|
||||
|
||||
return nil, err
|
||||
|
@ -626,7 +637,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
|||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Could not find IP for the new machine")
|
||||
|
||||
return nil, err
|
||||
|
@ -634,27 +645,27 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Found IP for host")
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "cli"
|
||||
h.db.Save(&m)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = namespace.ID
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = "cli"
|
||||
h.db.Save(&machine)
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Machine registered with the database")
|
||||
|
||||
return &m, nil
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||
hostInfo, err := m.GetHostInfo()
|
||||
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
||||
hostInfo, err := machine.GetHostInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -662,8 +673,8 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
|
|||
return hostInfo.RoutableIPs, nil
|
||||
}
|
||||
|
||||
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||
data, err := m.EnabledRoutes.MarshalJSON()
|
||||
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
||||
data, err := machine.EnabledRoutes.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -686,13 +697,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := m.GetEnabledRoutes()
|
||||
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -708,7 +719,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
|
|||
|
||||
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
|
||||
// previous list of routes.
|
||||
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
||||
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
|
||||
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
|
||||
for index, routeStr := range routeStrs {
|
||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||
|
@ -719,7 +730,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
newRoutes[index] = route
|
||||
}
|
||||
|
||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
||||
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -728,7 +739,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
if !containsIpPrefix(availableRoutes, newRoute) {
|
||||
return fmt.Errorf(
|
||||
"route (%s) is not available on node %s",
|
||||
m.Name,
|
||||
machine.Name,
|
||||
newRoute,
|
||||
)
|
||||
}
|
||||
|
@ -739,10 +750,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
m.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&m)
|
||||
machine.EnabledRoutes = datatypes.JSON(routes)
|
||||
h.db.Save(&machine)
|
||||
|
||||
err = h.RequestMapUpdates(m.NamespaceID)
|
||||
err = h.RequestMapUpdates(machine.NamespaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -750,13 +761,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
|
||||
availableRoutes, err := m.GetAdvertisedRoutes()
|
||||
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
|
||||
availableRoutes, err := machine.GetAdvertisedRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enabledRoutes, err := m.GetEnabledRoutes()
|
||||
enabledRoutes, err := machine.GetEnabledRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -32,12 +32,12 @@ type Namespace struct {
|
|||
// CreateNamespace creates a new Namespace. Returns error if could not be created
|
||||
// or another namespace already exists.
|
||||
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
||||
n := Namespace{}
|
||||
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
|
||||
namespace := Namespace{}
|
||||
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
|
||||
return nil, errorNamespaceExists
|
||||
}
|
||||
n.Name = name
|
||||
if err := h.db.Create(&n).Error; err != nil {
|
||||
namespace.Name = name
|
||||
if err := h.db.Create(&namespace).Error; err != nil {
|
||||
log.Error().
|
||||
Str("func", "CreateNamespace").
|
||||
Err(err).
|
||||
|
@ -46,22 +46,22 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &n, nil
|
||||
return &namespace, nil
|
||||
}
|
||||
|
||||
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
|
||||
// not exist or if there are machines associated with it.
|
||||
func (h *Headscale) DestroyNamespace(name string) error {
|
||||
n, err := h.GetNamespace(name)
|
||||
namespace, err := h.GetNamespace(name)
|
||||
if err != nil {
|
||||
return errorNamespaceNotFound
|
||||
}
|
||||
|
||||
m, err := h.ListMachinesInNamespace(name)
|
||||
machines, err := h.ListMachinesInNamespace(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(m) > 0 {
|
||||
if len(machines) > 0 {
|
||||
return errorNamespaceNotEmptyOfNodes
|
||||
}
|
||||
|
||||
|
@ -69,14 +69,14 @@ func (h *Headscale) DestroyNamespace(name string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, p := range keys {
|
||||
err = h.DestroyPreAuthKey(&p)
|
||||
for _, key := range keys {
|
||||
err = h.DestroyPreAuthKey(&key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
|
||||
if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
|
|||
// RenameNamespace renames a Namespace. Returns error if the Namespace does
|
||||
// not exist or if another Namespace exists with the new name.
|
||||
func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
||||
n, err := h.GetNamespace(oldName)
|
||||
oldNamespace, err := h.GetNamespace(oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -98,13 +98,13 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
n.Name = newName
|
||||
oldNamespace.Name = newName
|
||||
|
||||
if result := h.db.Save(&n); result.Error != nil {
|
||||
if result := h.db.Save(&oldNamespace); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
err = h.RequestMapUpdates(n.ID)
|
||||
err = h.RequestMapUpdates(oldNamespace.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -114,15 +114,15 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
|
|||
|
||||
// GetNamespace fetches a namespace by name.
|
||||
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
|
||||
n := Namespace{}
|
||||
if result := h.db.First(&n, "name = ?", name); errors.Is(
|
||||
namespace := Namespace{}
|
||||
if result := h.db.First(&namespace, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, errorNamespaceNotFound
|
||||
}
|
||||
|
||||
return &n, nil
|
||||
return &namespace, nil
|
||||
}
|
||||
|
||||
// ListNamespaces gets all the existing namespaces.
|
||||
|
@ -137,13 +137,13 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
|
|||
|
||||
// ListMachinesInNamespace gets all the nodes in a given namespace.
|
||||
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
|
||||
n, err := h.GetNamespace(name)
|
||||
namespace, err := h.GetNamespace(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machines := []Machine{}
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
||||
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -176,17 +176,18 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
|
|||
}
|
||||
|
||||
// SetMachineNamespace assigns a Machine to a namespace.
|
||||
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
|
||||
n, err := h.GetNamespace(namespaceName)
|
||||
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.NamespaceID = n.ID
|
||||
h.db.Save(&m)
|
||||
machine.NamespaceID = namespace.ID
|
||||
h.db.Save(&machine)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): Remove the need for this.
|
||||
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
|
||||
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||
namespace := Namespace{}
|
||||
|
@ -194,8 +195,8 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
|||
return err
|
||||
}
|
||||
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil || v == "" {
|
||||
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil || namespacesPendingUpdates == "" {
|
||||
err = h.setValue(
|
||||
"namespaces_pending_updates",
|
||||
fmt.Sprintf(`["%s"]`, namespace.Name),
|
||||
|
@ -207,7 +208,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
|||
return nil
|
||||
}
|
||||
names := []string{}
|
||||
err = json.Unmarshal([]byte(v), &names)
|
||||
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
|
||||
if err != nil {
|
||||
err = h.setValue(
|
||||
"namespaces_pending_updates",
|
||||
|
@ -235,16 +236,16 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
|||
}
|
||||
|
||||
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v == "" {
|
||||
if namespacesPendingUpdates == "" {
|
||||
return
|
||||
}
|
||||
|
||||
namespaces := []string{}
|
||||
err = json.Unmarshal([]byte(v), &namespaces)
|
||||
err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -255,11 +256,11 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
|
|||
Msg("Sending updates to nodes in namespacespace")
|
||||
h.setLastStateChangeToNow(namespace)
|
||||
}
|
||||
newV, err := h.getValue("namespaces_pending_updates")
|
||||
newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v == newV { // only clear when no changes, so we notified everybody
|
||||
if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
|
||||
err = h.setValue("namespaces_pending_updates", "")
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -273,7 +274,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
|
|||
}
|
||||
|
||||
func (n *Namespace) toUser() *tailcfg.User {
|
||||
u := tailcfg.User{
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
|
@ -283,11 +284,11 @@ func (n *Namespace) toUser() *tailcfg.User {
|
|||
Created: time.Time{},
|
||||
}
|
||||
|
||||
return &u
|
||||
return &user
|
||||
}
|
||||
|
||||
func (n *Namespace) toLogin() *tailcfg.Login {
|
||||
l := tailcfg.Login{
|
||||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
|
@ -295,14 +296,14 @@ func (n *Namespace) toLogin() *tailcfg.Login {
|
|||
Domain: "headscale.net",
|
||||
}
|
||||
|
||||
return &l
|
||||
return &login
|
||||
}
|
||||
|
||||
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
|
||||
func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
|
||||
namespaceMap := make(map[string]Namespace)
|
||||
namespaceMap[m.Namespace.Name] = m.Namespace
|
||||
for _, p := range peers {
|
||||
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
|
||||
namespaceMap[machine.Namespace.Name] = machine.Namespace
|
||||
for _, peer := range peers {
|
||||
namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
|
||||
}
|
||||
|
||||
profiles := []tailcfg.UserProfile{}
|
||||
|
|
69
oidc.go
69
oidc.go
|
@ -68,10 +68,10 @@ func (h *Headscale) initOIDC() error {
|
|||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:mKey.
|
||||
func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||
mKeyStr := c.Param("mkey")
|
||||
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||
mKeyStr := ctx.Param("mkey")
|
||||
if mKeyStr == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
b := make([]byte, RANDOM_BYTE_SIZE)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
|
||||
|
||||
c.Redirect(http.StatusFound, authUrl)
|
||||
ctx.Redirect(http.StatusFound, authUrl)
|
||||
}
|
||||
|
||||
// OIDCCallback handles the callback from the OIDC endpoint
|
||||
|
@ -100,19 +100,19 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
|||
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
|
||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||
// Listens in /oidc/callback.
|
||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||
code := ctx.Query("code")
|
||||
state := ctx.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
c.String(http.StatusBadRequest, "Wrong params")
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||
ctx.String(http.StatusBadRequest, "Could not exchange code for token")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
||||
if !rawIDTokenOK {
|
||||
c.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||
ctx.String(http.StatusBadRequest, "Could not extract ID Token")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -130,7 +130,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
// Extract custom claims
|
||||
var claims IDTokenClaims
|
||||
if err = idToken.Claims(&claims); err != nil {
|
||||
c.String(
|
||||
ctx.String(
|
||||
http.StatusBadRequest,
|
||||
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
||||
)
|
||||
|
@ -159,7 +159,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
if !mKeyFound {
|
||||
log.Error().
|
||||
Msg("requested machine state key expired before authorisation completed")
|
||||
c.String(http.StatusBadRequest, "state has expired")
|
||||
ctx.String(http.StatusBadRequest, "state has expired")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -167,16 +167,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
if !mKeyOK {
|
||||
log.Error().Msg("could not get machine key from cache")
|
||||
c.String(http.StatusInternalServerError, "could not get machine key from cache")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get machine key from cache",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// retrieve machine information
|
||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
machine, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
if err != nil {
|
||||
log.Error().Msg("machine key not found in database")
|
||||
c.String(
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get machine info from database",
|
||||
)
|
||||
|
@ -186,19 +189,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
now := time.Now().UTC()
|
||||
|
||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
if !machine.Registered {
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
namespace, err = h.CreateNamespace(namespaceName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not create new namespace",
|
||||
)
|
||||
|
@ -209,7 +212,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get an IP from the pool",
|
||||
)
|
||||
|
@ -217,17 +220,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
machine.IPAddress = ip.String()
|
||||
machine.NamespaceID = namespace.ID
|
||||
machine.Registered = true
|
||||
machine.RegisterMethod = "oidc"
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&machine)
|
||||
}
|
||||
|
||||
h.updateMachineExpiry(m)
|
||||
h.updateMachineExpiry(machine)
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
|
@ -243,9 +246,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
|||
log.Error().
|
||||
Str("email", claims.Email).
|
||||
Str("username", claims.Username).
|
||||
Str("machine", m.Name).
|
||||
Str("machine", machine.Name).
|
||||
Msg("Email could not be mapped to a namespace")
|
||||
c.String(
|
||||
ctx.String(
|
||||
http.StatusBadRequest,
|
||||
"email from claim could not be mapped to a namespace",
|
||||
)
|
||||
|
|
8
poll.go
8
poll.go
|
@ -233,7 +233,7 @@ func (h *Headscale) PollNetMapStream(
|
|||
) {
|
||||
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
c.Stream(func(writer io.Writer) bool {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMapStream").
|
||||
Str("machine", m.Name).
|
||||
|
@ -252,7 +252,7 @@ func (h *Headscale) PollNetMapStream(
|
|||
Str("channel", "pollData").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending data received via pollData channel")
|
||||
_, err := w.Write(data)
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
|
@ -305,7 +305,7 @@ func (h *Headscale) PollNetMapStream(
|
|||
Str("channel", "keepAlive").
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending keep alive message")
|
||||
_, err := w.Write(data)
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
|
@ -370,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
|
|||
Err(err).
|
||||
Msg("Could not get the map update")
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMapStream").
|
||||
|
|
32
sharing.go
32
sharing.go
|
@ -18,13 +18,16 @@ type SharedMachine struct {
|
|||
}
|
||||
|
||||
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
|
||||
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
|
||||
if m.NamespaceID == ns.ID {
|
||||
func (h *Headscale) AddSharedMachineToNamespace(
|
||||
machine *Machine,
|
||||
namespace *Namespace,
|
||||
) error {
|
||||
if machine.NamespaceID == namespace.ID {
|
||||
return errorSameNamespace
|
||||
}
|
||||
|
||||
sharedMachines := []SharedMachine{}
|
||||
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
|
||||
if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if len(sharedMachines) > 0 {
|
||||
|
@ -32,10 +35,10 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
|
|||
}
|
||||
|
||||
sharedMachine := SharedMachine{
|
||||
MachineID: m.ID,
|
||||
Machine: *m,
|
||||
NamespaceID: ns.ID,
|
||||
Namespace: *ns,
|
||||
MachineID: machine.ID,
|
||||
Machine: *machine,
|
||||
NamespaceID: namespace.ID,
|
||||
Namespace: *namespace,
|
||||
}
|
||||
h.db.Save(&sharedMachine)
|
||||
|
||||
|
@ -43,14 +46,17 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
|
|||
}
|
||||
|
||||
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
|
||||
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
|
||||
if m.NamespaceID == ns.ID {
|
||||
func (h *Headscale) RemoveSharedMachineFromNamespace(
|
||||
machine *Machine,
|
||||
namespace *Namespace,
|
||||
) error {
|
||||
if machine.NamespaceID == namespace.ID {
|
||||
// Can't unshare from primary namespace
|
||||
return errorMachineNotShared
|
||||
}
|
||||
|
||||
sharedMachine := SharedMachine{}
|
||||
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).
|
||||
result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
|
||||
Unscoped().
|
||||
Delete(&sharedMachine)
|
||||
if result.Error != nil {
|
||||
|
@ -61,7 +67,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
|
|||
return errorMachineNotShared
|
||||
}
|
||||
|
||||
err := h.RequestMapUpdates(ns.ID)
|
||||
err := h.RequestMapUpdates(namespace.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -70,9 +76,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
|
|||
}
|
||||
|
||||
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
|
||||
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
|
||||
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
|
||||
sharedMachine := SharedMachine{}
|
||||
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
||||
if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
|
14
swagger.go
14
swagger.go
|
@ -13,8 +13,8 @@ import (
|
|||
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
|
||||
var apiV1JSON []byte
|
||||
|
||||
func SwaggerUI(c *gin.Context) {
|
||||
t := template.Must(template.New("swagger").Parse(`
|
||||
func SwaggerUI(ctx *gin.Context) {
|
||||
swaggerTemplate := template.Must(template.New("swagger").Parse(`
|
||||
<html>
|
||||
<head>
|
||||
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||
|
@ -47,12 +47,12 @@ func SwaggerUI(c *gin.Context) {
|
|||
</html>`))
|
||||
|
||||
var payload bytes.Buffer
|
||||
if err := t.Execute(&payload, struct{}{}); err != nil {
|
||||
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Could not render Swagger")
|
||||
c.Data(
|
||||
ctx.Data(
|
||||
http.StatusInternalServerError,
|
||||
"text/html; charset=utf-8",
|
||||
[]byte("Could not render Swagger"),
|
||||
|
@ -61,9 +61,9 @@ func SwaggerUI(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
|
||||
}
|
||||
|
||||
func SwaggerAPIv1(c *gin.Context) {
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
||||
func SwaggerAPIv1(ctx *gin.Context) {
|
||||
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
|
||||
}
|
||||
|
|
12
utils.go
12
utils.go
|
@ -36,7 +36,7 @@ func decode(
|
|||
|
||||
func decodeMsg(
|
||||
msg []byte,
|
||||
v interface{},
|
||||
output interface{},
|
||||
pubKey *wgkey.Key,
|
||||
privKey *wgkey.Private,
|
||||
) error {
|
||||
|
@ -45,7 +45,7 @@ func decodeMsg(
|
|||
return err
|
||||
}
|
||||
// fmt.Println(string(decrypted))
|
||||
if err := json.Unmarshal(decrypted, v); err != nil {
|
||||
if err := json.Unmarshal(decrypted, output); err != nil {
|
||||
return fmt.Errorf("response: %v", err)
|
||||
}
|
||||
|
||||
|
@ -78,13 +78,17 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
|
|||
return encodeMsg(b, pubKey, privKey)
|
||||
}
|
||||
|
||||
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
|
||||
func encodeMsg(
|
||||
payload []byte,
|
||||
pubKey *wgkey.Key,
|
||||
privKey *wgkey.Private,
|
||||
) ([]byte, error) {
|
||||
var nonce [24]byte
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
|
||||
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
|
||||
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue