Initial work eliminating one/two letter variables

This commit is contained in:
Kristoffer Dalby 2021-11-14 20:32:03 +01:00
parent 53ed749f45
commit 471c0b4993
No known key found for this signature in database
GPG Key ID: 09F62DC067465735
19 changed files with 568 additions and 532 deletions

View File

@ -28,6 +28,9 @@ linters:
# In progress
- gocritic
# TODO: approve: ok, db, id
- varnamelen
# We should strive to enable these:
- testpackage
- stylecheck
@ -39,7 +42,6 @@ linters:
- gosec
- forbidigo
- dupl
- varnamelen
- makezero
- paralleltest

74
acls.go
View File

@ -41,18 +41,18 @@ func (h *Headscale) LoadACLPolicy(path string) error {
defer policyFile.Close()
var policy ACLPolicy
b, err := io.ReadAll(policyFile)
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return err
}
ast, err := hujson.Parse(b)
ast, err := hujson.Parse(policyBytes)
if err != nil {
return err
}
ast.Standardize()
b = ast.Pack()
err = json.Unmarshal(b, &policy)
policyBytes = ast.Pack()
err = json.Unmarshal(policyBytes, &policy)
if err != nil {
return err
}
@ -73,32 +73,32 @@ func (h *Headscale) LoadACLPolicy(path string) error {
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs {
if a.Action != "accept" {
for index, acl := range h.aclPolicy.ACLs {
if acl.Action != "accept" {
return nil, errorInvalidAction
}
r := tailcfg.FilterRule{}
filterRule := tailcfg.FilterRule{}
srcIPs := []string{}
for j, u := range a.Users {
srcs, err := h.generateACLPolicySrcIP(u)
for innerIndex, user := range acl.Users {
srcs, err := h.generateACLPolicySrcIP(user)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, User %d", i, j)
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
return nil, err
}
srcIPs = append(srcIPs, srcs...)
}
r.SrcIPs = srcIPs
filterRule.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports {
dests, err := h.generateACLPolicyDestPorts(d)
for innerIndex, ports := range acl.Ports {
dests, err := h.generateACLPolicyDestPorts(ports)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, Port %d", i, j)
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
return nil, err
}
@ -162,17 +162,17 @@ func (h *Headscale) generateACLPolicyDestPorts(
return dests, nil
}
func (h *Headscale) expandAlias(s string) ([]string, error) {
if s == "*" {
func (h *Headscale) expandAlias(alias string) ([]string, error) {
if alias == "*" {
return []string{"*"}, nil
}
if strings.HasPrefix(s, "group:") {
if _, ok := h.aclPolicy.Groups[s]; !ok {
if strings.HasPrefix(alias, "group:") {
if _, ok := h.aclPolicy.Groups[alias]; !ok {
return nil, errorInvalidGroup
}
ips := []string{}
for _, n := range h.aclPolicy.Groups[s] {
for _, n := range h.aclPolicy.Groups[alias] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errorInvalidNamespace
@ -185,8 +185,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil
}
if strings.HasPrefix(s, "tag:") {
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
if strings.HasPrefix(alias, "tag:") {
if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
return nil, errorInvalidTag
}
@ -197,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, err
}
ips := []string{}
for _, m := range machines {
for _, machine := range machines {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@ -211,8 +211,8 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
if s[4:] == t {
ips = append(ips, m.IPAddress)
if alias[4:] == t {
ips = append(ips, machine.IPAddress)
break
}
@ -223,7 +223,7 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil
}
n, err := h.GetNamespace(s)
n, err := h.GetNamespace(alias)
if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
@ -237,16 +237,16 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return ips, nil
}
if h, ok := h.aclPolicy.Hosts[s]; ok {
if h, ok := h.aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil
}
ip, err := netaddr.ParseIP(s)
ip, err := netaddr.ParseIP(alias)
if err == nil {
return []string{ip.String()}, nil
}
cidr, err := netaddr.ParseIPPrefix(s)
cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil {
return []string{cidr.String()}, nil
}
@ -254,25 +254,25 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, errorInvalidUserSection
}
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
if s == "*" {
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
if portsStr == "*" {
return &[]tailcfg.PortRange{
{First: PORT_RANGE_BEGIN, Last: PORT_RANGE_END},
}, nil
}
ports := []tailcfg.PortRange{}
for _, p := range strings.Split(s, ",") {
rang := strings.Split(p, "-")
for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
pi, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
port, err := strconv.ParseUint(rang[0], BASE_10, BIT_SIZE_16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(pi),
Last: uint16(pi),
First: uint16(port),
Last: uint16(port),
})
case EXPECTED_TOKEN_ITEMS:

View File

@ -41,37 +41,37 @@ type ACLTest struct {
}
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (h *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{}
hs := make(map[string]string)
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
newHosts := Hosts{}
hostIpPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data)
if err != nil {
return err
}
ast.Standardize()
data = ast.Pack()
err = json.Unmarshal(data, &hs)
err = json.Unmarshal(data, &hostIpPrefixMap)
if err != nil {
return err
}
for k, v := range hs {
if !strings.Contains(v, "/") {
v += "/32"
for host, prefixStr := range hostIpPrefixMap {
if !strings.Contains(prefixStr, "/") {
prefixStr += "/32"
}
prefix, err := netaddr.ParseIPPrefix(v)
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return err
}
hosts[k] = prefix
newHosts[host] = prefix
}
*h = hosts
*hosts = newHosts
return nil
}
// IsZero is perhaps a bit naive here.
func (p ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
func (policy ACLPolicy) IsZero() bool {
if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
return true
}

229
api.go
View File

@ -22,21 +22,25 @@ const RESERVED_RESPONSE_HEADER_SIZE = 4
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(c *gin.Context) {
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
func (h *Headscale) KeyHandler(ctx *gin.Context) {
ctx.Data(
http.StatusOK,
"text/plain; charset=utf-8",
[]byte(h.publicKey.HexString()),
)
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register.
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
mKeyStr := c.Query("key")
if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params")
func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
machineKeyStr := ctx.Query("key")
if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params")
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
@ -53,45 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
</body>
</html>
`, mKeyStr)))
`, machineKeyStr)))
}
// RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id.
func (h *Headscale) RegistrationHandler(c *gin.Context) {
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body)
machineKeyStr := ctx.Param("id")
machineKey, err := wgkey.ParseHex(machineKeyStr)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!")
ctx.String(http.StatusInternalServerError, "Sad!")
return
}
req := tailcfg.RegisterRequest{}
err = decode(body, &req, &mKey, h.privateKey)
err = decode(body, &req, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!")
ctx.String(http.StatusInternalServerError, "Very sad!")
return
}
now := time.Now().UTC()
m, err := h.GetMachineByMachineKey(mKey.HexString())
machine, err := h.GetMachineByMachineKey(machineKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{
Expiry: &time.Time{},
MachineKey: mKey.HexString(),
MachineKey: machineKey.HexString(),
Name: req.Hostinfo.Hostname,
}
if err := h.db.Create(&newMachine).Error; err != nil {
@ -99,16 +103,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).
machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
Inc()
return
}
m = &newMachine
machine = &newMachine
}
if !m.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, *m)
if !machine.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
return
}
@ -116,63 +120,63 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp := tailcfg.RegisterResponse{}
// We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client requested logout")
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&m)
machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
return
}
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
if m.Registered && m.Expiry.UTC().After(now) {
if machine.Registered && machine.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser()
resp.Login = *m.Namespace.toLogin()
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
respBody, err := encode(resp, &mKey, h.privateKey)
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc()
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
return
}
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
@ -180,15 +184,15 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The client has registered before, but has expired
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// When a client connects, it may request a specific expiry time in its
@ -197,51 +201,52 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry.
m.RequestedExpiry = &req.Expiry
machine.RequestedExpiry = &req.Expiry
h.db.Save(&m)
h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey)
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
Inc()
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
return
}
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
Inc()
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
machine.Expiry.UTC().After(now) {
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&m)
machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&machine)
resp.AuthURL = ""
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!")
ctx.String(http.StatusInternalServerError, "Extremely sad!")
return
}
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
return
}
@ -249,47 +254,47 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
mKey.HexString(),
machineKey.HexString(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// save the requested expiry time for retrieval later in the authentication flow
m.RequestedExpiry = &req.Expiry
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&m)
machine.RequestedExpiry = &req.Expiry
machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&machine)
respBody, err := encode(resp, &mKey, h.privateKey)
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
ctx.String(http.StatusInternalServerError, "")
return
}
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
}
func (h *Headscale) getMapResponse(
mKey wgkey.Key,
machineKey wgkey.Key,
req tailcfg.MapRequest,
m *Machine,
machine *Machine,
) ([]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Str("func", "getMapResponse").
@ -299,7 +304,7 @@ func (h *Headscale) getMapResponse(
return nil, err
}
peers, err := h.getPeers(m)
peers, err := h.getPeers(machine)
if err != nil {
log.Error().
Str("func", "getMapResponse").
@ -309,7 +314,7 @@ func (h *Headscale) getMapResponse(
return nil, err
}
profiles := getMapResponseUserProfiles(*m, peers)
profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
@ -324,7 +329,7 @@ func (h *Headscale) getMapResponse(
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*m,
*machine,
peers,
)
@ -351,12 +356,12 @@ func (h *Headscale) getMapResponse(
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
respBody, err = encode(resp, &mKey, h.privateKey)
respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
@ -370,24 +375,24 @@ func (h *Headscale) getMapResponse(
}
func (h *Headscale) getMapKeepAliveResponse(
mKey wgkey.Key,
req tailcfg.MapRequest,
machineKey wgkey.Key,
mapRequest tailcfg.MapRequest,
) ([]byte, error) {
resp := tailcfg.MapResponse{
mapResponse := tailcfg.MapResponse{
KeepAlive: true,
}
var respBody []byte
var err error
if req.Compress == "zstd" {
src, _ := json.Marshal(resp)
if mapRequest.Compress == "zstd" {
src, _ := json.Marshal(mapResponse)
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
respBody, err = encode(resp, &mKey, h.privateKey)
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
@ -400,22 +405,22 @@ func (h *Headscale) getMapKeepAliveResponse(
}
func (h *Headscale) handleAuthKey(
c *gin.Context,
ctx *gin.Context,
db *gorm.DB,
idKey wgkey.Key,
req tailcfg.RegisterRequest,
m Machine,
reqisterRequest tailcfg.RegisterRequest,
machine Machine,
) {
log.Debug().
Str("func", "handleAuthKey").
Str("machine", req.Hostinfo.Hostname).
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
Str("machine", reqisterRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
@ -423,21 +428,21 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
ctx.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
c.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
@ -445,32 +450,34 @@ func (h *Headscale) handleAuthKey(
log.Debug().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP()
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Failed to find an available IP")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
return
}
log.Info().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msgf("Assigning %s to %s", ip, m.Name)
Msgf("Assigning %s to %s", ip, machine.Name)
m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String()
m.NamespaceID = pak.NamespaceID
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
m.Registered = true
m.RegisterMethod = "authKey"
db.Save(&m)
machine.AuthKeyID = uint(pak.ID)
machine.IPAddress = ip.String()
machine.NamespaceID = pak.NamespaceID
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
HexString()
// we update it just in case
machine.Registered = true
machine.RegisterMethod = "authKey"
db.Save(&machine)
pak.Used = true
db.Save(&pak)
@ -481,21 +488,21 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
Inc()
c.String(http.StatusInternalServerError, "Extremely sad!")
ctx.String(http.StatusInternalServerError, "Extremely sad!")
return
}
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).
machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
Inc()
c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey")
}

140
app.go
View File

@ -169,7 +169,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errors.New("unsupported DB")
}
h := Headscale{
app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
@ -178,32 +178,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
aclRules: tailcfg.FilterAllowAll, // default allowall
}
err = h.initDB()
err = app.initDB()
if err != nil {
return nil, err
}
if cfg.OIDC.Issuer != "" {
err = h.initOIDC()
err = app.initOIDC()
if err != nil {
return nil, err
}
}
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains := generateMagicDNSRootDomains(
h.cfg.IPPrefix,
app.cfg.IPPrefix,
)
// we might have routes already from Split DNS
if h.cfg.DNSConfig.Routes == nil {
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
if app.cfg.DNSConfig.Routes == nil {
app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
}
for _, d := range magicDNSDomains {
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
}
}
return &h, nil
return &app, nil
}
// Redirect to our TLS url.
@ -229,35 +229,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return
}
for _, ns := range namespaces {
machines, err := h.ListMachinesInNamespace(ns.Name)
for _, namespace := range namespaces {
machines, err := h.ListMachinesInNamespace(namespace.Name)
if err != nil {
log.Error().
Err(err).
Str("namespace", ns.Name).
Str("namespace", namespace.Name).
Msg("Error listing machines in namespace")
return
}
for _, m := range machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
for _, machine := range machines {
if machine.AuthKey != nil && machine.LastSeen != nil &&
machine.AuthKey.Ephemeral &&
time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(m).Error
err = h.db.Unscoped().Delete(machine).Error
if err != nil {
log.Error().
Err(err).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("🤮 Cannot delete ephemeral machine from the database")
}
}
}
h.setLastStateChangeToNow(ns.Name)
h.setLastStateChangeToNow(namespace.Name)
}
}
@ -284,18 +286,18 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// with the "legacy" database-based client
// It is also neede for grpc-gateway to be able to connect to
// the server
p, _ := peer.FromContext(ctx)
client, _ := peer.FromContext(ctx)
log.Trace().
Caller().
Str("client_address", p.Addr.String()).
Str("client_address", client.Addr.String()).
Msg("Client is trying to authenticate")
md, ok := metadata.FromIncomingContext(ctx)
meta, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Error().
Caller().
Str("client_address", p.Addr.String()).
Str("client_address", client.Addr.String()).
Msg("Retrieving metadata is failed")
return ctx, status.Errorf(
@ -304,11 +306,11 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
)
}
authHeader, ok := md["authorization"]
authHeader, ok := meta["authorization"]
if !ok {
log.Error().
Caller().
Str("client_address", p.Addr.String()).
Str("client_address", client.Addr.String()).
Msg("Authorization token is not supplied")
return ctx, status.Errorf(
@ -322,7 +324,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
if !strings.HasPrefix(token, AUTH_PREFIX) {
log.Error().
Caller().
Str("client_address", p.Addr.String()).
Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(
@ -353,25 +355,25 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// return handler(ctx, req)
}
func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace().
Caller().
Str("client_address", c.ClientIP()).
Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked")
authHeader := c.GetHeader("authorization")
authHeader := ctx.GetHeader("authorization")
if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
log.Error().
Caller().
Str("client_address", c.ClientIP()).
Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
c.AbortWithStatus(http.StatusUnauthorized)
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
c.AbortWithStatus(http.StatusUnauthorized)
ctx.AbortWithStatus(http.StatusUnauthorized)
// TODO(kradalby): Implement API key backend
// Currently all traffic is unauthorized, this is intentional to allow
@ -438,9 +440,9 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully.
m := cmux.New(networkListener)
networkMutex := cmux.New(networkListener)
// Match gRPC requests here
grpcListener := m.MatchWithWriters(
grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldSendSettings(
"content-type",
@ -448,7 +450,7 @@ func (h *Headscale) Serve() error {
),
)
// Otherwise match regular http requests.
httpListener := m.Match(cmux.Any())
httpListener := networkMutex.Match(cmux.Any())
grpcGatewayMux := runtime.NewServeMux()
@ -471,33 +473,33 @@ func (h *Headscale) Serve() error {
return err
}
r := gin.Default()
router := gin.Default()
p := ginprometheus.NewPrometheus("gin")
p.Use(r)
prometheus := ginprometheus.NewPrometheus("gin")
prometheus.Use(router)
r.GET(
router.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
r.GET("/key", h.KeyHandler)
r.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler)
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
r.GET("/oidc/callback", h.OIDCCallback)
r.GET("/apple", h.AppleMobileConfig)
r.GET("/apple/:platform", h.ApplePlatformConfig)
r.GET("/swagger", SwaggerUI)
r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
router.GET("/key", h.KeyHandler)
router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleMobileConfig)
router.GET("/apple/:platform", h.ApplePlatformConfig)
router.GET("/swagger", SwaggerUI)
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
api := r.Group("/api")
api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
}
r.NoRoute(stdoutHandler)
router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP)
@ -514,7 +516,7 @@ func (h *Headscale) Serve() error {
httpServer := &http.Server{
Addr: h.cfg.Addr,
Handler: r,
Handler: router,
ReadTimeout: HTTP_READ_TIMEOUT,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to
@ -561,29 +563,29 @@ func (h *Headscale) Serve() error {
reflection.Register(grpcServer)
reflection.Register(grpcSocket)
g := new(errgroup.Group)
errorGroup := new(errgroup.Group)
g.Go(func() error { return grpcSocket.Serve(socketListener) })
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
g.Go(func() error { return grpcServer.Serve(grpcListener) })
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil {
g.Go(func() error {
errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl)
})
} else {
g.Go(func() error { return httpServer.Serve(httpListener) })
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
}
g.Go(func() error { return m.Serve() })
errorGroup.Go(func() error { return networkMutex.Serve() })
log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
return g.Wait()
return errorGroup.Wait()
}
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
@ -594,7 +596,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Msg("Listening with TLS but ServerURL does not start with https://")
}
m := autocert.Manager{
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
@ -609,7 +611,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443.
return m.TLSConfig(), nil
return certManager.TLSConfig(), nil
case "HTTP-01":
// Configuration via autocert with HTTP-01. This requires listening on
@ -617,11 +619,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// service, which can be configured to run on any other port.
go func() {
log.Fatal().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
return m.TLSConfig(), nil
return certManager.TLSConfig(), nil
default:
return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
@ -676,13 +678,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
}
}
func stdoutHandler(c *gin.Context) {
b, _ := io.ReadAll(c.Request.Body)
func stdoutHandler(ctx *gin.Context) {
body, _ := io.ReadAll(ctx.Request.Body)
log.Trace().
Interface("header", c.Request.Header).
Interface("proto", c.Request.Proto).
Interface("url", c.Request.URL).
Bytes("body", b).
Interface("header", ctx.Request.Header).
Interface("proto", ctx.Request.Proto).
Interface("url", ctx.Request.URL).
Bytes("body", body).
Msg("Request did not match")
}

View File

@ -12,8 +12,8 @@ import (
// AppleMobileConfig shows a simple message in the browser to point to the CLI
// Listens in /register.
func (h *Headscale) AppleMobileConfig(c *gin.Context) {
t := template.Must(template.New("apple").Parse(`
func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
appleTemplate := template.Must(template.New("apple").Parse(`
<html>
<body>
<h1>Apple configuration profiles</h1>
@ -67,12 +67,12 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
}
var payload bytes.Buffer
if err := t.Execute(&payload, config); err != nil {
if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error().
Str("handler", "AppleMobileConfig").
Err(err).
Msg("Could not render Apple index template")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple index template"),
@ -81,11 +81,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
platform := c.Param("platform")
func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
platform := ctx.Param("platform")
id, err := uuid.NewV4()
if err != nil {
@ -93,7 +93,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
@ -108,7 +108,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
@ -131,7 +131,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple macOS template")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"),
@ -145,7 +145,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"),
@ -154,7 +154,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
return
}
default:
c.Data(
ctx.Data(
http.StatusOK,
"text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"),
@ -175,7 +175,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple platform template"),
@ -184,7 +184,7 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
return
}
c.Data(
ctx.Data(
http.StatusOK,
"application/x-apple-aspen-config; charset=utf-8",
content.Bytes(),

View File

@ -167,10 +167,10 @@ var listNamespacesCmd = &cobra.Command{
return
}
d := pterm.TableData{{"ID", "Name", "Created"}}
tableData := pterm.TableData{{"ID", "Name", "Created"}}
for _, namespace := range response.GetNamespaces() {
d = append(
d,
tableData = append(
tableData,
[]string{
namespace.GetId(),
namespace.GetName(),
@ -178,7 +178,7 @@ var listNamespacesCmd = &cobra.Command{
},
)
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,

View File

@ -157,14 +157,14 @@ var listNodesCmd = &cobra.Command{
return
}
d, err := nodesToPtables(namespace, response.Machines)
tableData, err := nodesToPtables(namespace, response.Machines)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
@ -183,7 +183,7 @@ var deleteNodeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
id, err := cmd.Flags().GetInt("identifier")
identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(
err,
@ -199,7 +199,7 @@ var deleteNodeCmd = &cobra.Command{
defer conn.Close()
getRequest := &v1.GetMachineRequest{
MachineId: uint64(id),
MachineId: uint64(identifier),
}
getResponse, err := client.GetMachine(ctx, getRequest)
@ -217,7 +217,7 @@ var deleteNodeCmd = &cobra.Command{
}
deleteRequest := &v1.DeleteMachineRequest{
MachineId: uint64(id),
MachineId: uint64(identifier),
}
confirm := false
@ -280,7 +280,7 @@ func sharingWorker(
defer cancel()
defer conn.Close()
id, err := cmd.Flags().GetInt("identifier")
identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
@ -288,7 +288,7 @@ func sharingWorker(
}
machineRequest := &v1.GetMachineRequest{
MachineId: uint64(id),
MachineId: uint64(identifier),
}
machineResponse, err := client.GetMachine(ctx, machineRequest)
@ -402,7 +402,7 @@ func nodesToPtables(
currentNamespace string,
machines []*v1.Machine,
) (pterm.TableData, error) {
d := pterm.TableData{
tableData := pterm.TableData{
{
"ID",
"Name",
@ -448,8 +448,8 @@ func nodesToPtables(
// Shared into this namespace
namespace = pterm.LightYellow(machine.Namespace.Name)
}
d = append(
d,
tableData = append(
tableData,
[]string{
strconv.FormatUint(machine.Id, headscale.BASE_10),
machine.Name,
@ -463,5 +463,5 @@ func nodesToPtables(
)
}
return d, nil
return tableData, nil
}

View File

@ -45,7 +45,7 @@ var listPreAuthKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
n, err := cmd.Flags().GetString("namespace")
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{
defer conn.Close()
request := &v1.ListPreAuthKeysRequest{
Namespace: n,
Namespace: namespace,
}
response, err := client.ListPreAuthKeys(ctx, request)
@ -77,34 +77,34 @@ var listPreAuthKeys = &cobra.Command{
return
}
d := pterm.TableData{
tableData := pterm.TableData{
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
}
for _, k := range response.PreAuthKeys {
for _, key := range response.PreAuthKeys {
expiration := "-"
if k.GetExpiration() != nil {
expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
if key.GetExpiration() != nil {
expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
}
var reusable string
if k.GetEphemeral() {
if key.GetEphemeral() {
reusable = "N/A"
} else {
reusable = fmt.Sprintf("%v", k.GetReusable())
reusable = fmt.Sprintf("%v", key.GetReusable())
}
d = append(d, []string{
k.GetId(),
k.GetKey(),
tableData = append(tableData, []string{
key.GetId(),
key.GetKey(),
reusable,
strconv.FormatBool(k.GetEphemeral()),
strconv.FormatBool(k.GetUsed()),
strconv.FormatBool(key.GetEphemeral()),
strconv.FormatBool(key.GetUsed()),
expiration,
k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
})
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,

View File

@ -81,14 +81,14 @@ var listRoutesCmd = &cobra.Command{
return
}
d := routesToPtables(response.Routes)
tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
@ -162,14 +162,14 @@ omit the route you do not want to enable.
return
}
d := routesToPtables(response.Routes)
tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
@ -184,15 +184,15 @@ omit the route you do not want to enable.
// routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
tableData := pterm.TableData{{"Route", "Enabled"}}
for _, route := range routes.GetAdvertisedRoutes() {
enabled := isStringInSlice(routes.EnabledRoutes, route)
d = append(d, []string{route, strconv.FormatBool(enabled)})
tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
}
return d
return tableData
}
func isStringInSlice(strs []string, s string) bool {

View File

@ -318,7 +318,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg.OIDC.MatchMap = loadOIDCMatchMap()
h, err := headscale.NewHeadscale(cfg)
app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
}
@ -327,7 +327,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path"))
err = h.LoadACLPolicy(aclPath)
err = app.LoadACLPolicy(aclPath)
if err != nil {
log.Error().
Str("path", aclPath).
@ -336,7 +336,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}
}
return h, nil
return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {

6
dns.go
View File

@ -79,7 +79,7 @@ func generateMagicDNSRootDomains(
func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
m Machine,
machine Machine,
peers Machines,
) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig
@ -88,11 +88,11 @@ func getMapResponseDNSConfig(
dnsConfig = dnsConfigOrig.Clone()
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain),
fmt.Sprintf("%s.%s", machine.Namespace.Name, baseDomain),
)
namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace)
namespaceSet.Add(machine.Namespace)
for _, p := range peers {
namespaceSet.Add(p.Namespace)
}

View File

@ -56,21 +56,21 @@ type (
)
// For the time being this method is rather naive.
func (m Machine) isAlreadyRegistered() bool {
return m.Registered
func (machine Machine) isAlreadyRegistered() bool {
return machine.Registered
}
// isExpired returns whether the machine registration has expired.
func (m Machine) isExpired() bool {
return time.Now().UTC().After(*m.Expiry)
func (machine Machine) isExpired() bool {
return time.Now().UTC().After(*machine.Expiry)
}
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time.
func (h *Headscale) updateMachineExpiry(m *Machine) {
if m.isExpired() {
func (h *Headscale) updateMachineExpiry(machine *Machine) {
if machine.isExpired() {
now := time.Now().UTC()
maxExpiry := now.Add(
h.cfg.MaxMachineRegistrationDuration,
@ -80,31 +80,31 @@ func (h *Headscale) updateMachineExpiry(m *Machine) {
) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
if maxExpiry.Before(*m.RequestedExpiry) {
if maxExpiry.Before(*machine.RequestedExpiry) {
log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
m.Expiry = &maxExpiry
} else if m.RequestedExpiry.IsZero() {
machine.Expiry = &maxExpiry
} else if machine.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
m.Expiry = &defaultExpiry
machine.Expiry = &defaultExpiry
} else {
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
m.Expiry = m.RequestedExpiry
log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
machine.Expiry = machine.RequestedExpiry
}
h.db.Save(&m)
h.db.Save(&machine)
}
}
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Finding direct peers")
machines := Machines{}
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err
@ -114,22 +114,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found direct machines: %s", machines.String())
return machines, nil
}
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
func (h *Headscale) getShared(m *Machine) (Machines, error) {
func (h *Headscale) getShared(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Finding shared peers")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
@ -142,22 +142,22 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found shared peers: %s", peers.String())
return peers, nil
}
// getSharedTo fetches the machines of the namespaces this machine is shared in.
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
m.ID).Find(&sharedMachines).Error; err != nil {
machine.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
@ -176,14 +176,14 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil
}
func (h *Headscale) getPeers(m *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m)
func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
direct, err := h.getDirectPeers(machine)
if err != nil {
log.Error().
Caller().
@ -193,7 +193,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
return Machines{}, err
}
shared, err := h.getShared(m)
shared, err := h.getShared(machine)
if err != nil {
log.Error().
Caller().
@ -203,7 +203,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
return Machines{}, err
}
sharedTo, err := h.getSharedTo(m)
sharedTo, err := h.getSharedTo(machine)
if err != nil {
log.Error().
Caller().
@ -220,7 +220,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msgf("Found total peers: %s", peers.String())
return peers, nil
@ -262,9 +262,9 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
}
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
return nil, result.Error
}
@ -273,8 +273,8 @@ func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil {
func (h *Headscale) UpdateMachine(machine *Machine) error {
if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error
}
@ -282,16 +282,16 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
}
// DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m)
func (h *Headscale) DeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared {
return err
}
m.Registered = false
namespaceID := m.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case
if err := h.db.Delete(&m).Error; err != nil {
machine.Registered = false
namespaceID := machine.NamespaceID
h.db.Save(&machine) // we mark it as unregistered, just in case
if err := h.db.Delete(&machine).Error; err != nil {
return err
}
@ -299,14 +299,14 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
}
// HardDeleteMachine hard deletes a Machine from the database.
func (h *Headscale) HardDeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m)
func (h *Headscale) HardDeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(machine)
if err != nil && err != errorMachineNotShared {
return err
}
namespaceID := m.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
namespaceID := machine.NamespaceID
if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
return err
}
@ -314,10 +314,10 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error {
}
// GetHostInfo returns a Hostinfo struct for the machine.
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@ -330,17 +330,17 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return &hostinfo, nil
}
func (h *Headscale) isOutdated(m *Machine) bool {
if err := h.UpdateMachine(m); err != nil {
func (h *Headscale) isOutdated(machine *Machine) bool {
if err := h.UpdateMachine(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
sharedMachines, _ := h.getShared(m)
sharedMachines, _ := h.getShared(machine)
namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace.Name)
namespaceSet.Add(machine.Namespace.Name)
// Check if any of our shared namespaces has updates that we have
// not propagated.
@ -356,22 +356,22 @@ func (h *Headscale) isOutdated(m *Machine) bool {
lastChange := h.getLastStateChange(namespaces...)
log.Trace().
Caller().
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Str("machine", machine.Name).
Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name)
Msgf("Checking if %s is missing updates", machine.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
return machine.LastSuccessfulUpdate.Before(lastChange)
}
func (m Machine) String() string {
return m.Name
func (machine Machine) String() string {
return machine.Name
}
func (ms Machines) String() string {
temp := make([]string, len(ms))
func (machines Machines) String() string {
temp := make([]string, len(machines))
for index, machine := range ms {
for index, machine := range machines {
temp[index] = machine.Name
}
@ -379,24 +379,24 @@ func (ms Machines) String() string {
}
// TODO(kradalby): Remove when we have generics...
func (ms MachinesP) String() string {
temp := make([]string, len(ms))
func (machines MachinesP) String() string {
temp := make([]string, len(machines))
for index, machine := range ms {
for index, machine := range machines {
temp[index] = machine.Name
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
func (ms Machines) toNodes(
func (machines Machines) toNodes(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(ms))
nodes := make([]*tailcfg.Node, len(machines))
for index, machine := range ms {
for index, machine := range machines {
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil {
return nil, err
@ -410,23 +410,24 @@ func (ms Machines) toNodes(
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS.
func (m Machine) toNode(
func (machine Machine) toNode(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey)
nodeKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil {
return nil, err
}
mKey, err := wgkey.ParseHex(m.MachineKey)
machineKey, err := wgkey.ParseHex(machine.MachineKey)
if err != nil {
return nil, err
}
var discoKey tailcfg.DiscoKey
if m.DiscoKey != "" {
dKey, err := wgkey.ParseHex(m.DiscoKey)
if machine.DiscoKey != "" {
dKey, err := wgkey.ParseHex(machine.DiscoKey)
if err != nil {
return nil, err
}
@ -436,12 +437,12 @@ func (m Machine) toNode(
}
addrs := []netaddr.IPPrefix{}
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
if err != nil {
log.Trace().
Caller().
Str("ip", m.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
Str("ip", machine.IPAddress).
Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
return nil, err
}
@ -455,8 +456,8 @@ func (m Machine) toNode(
if includeRoutes {
routesStr := []string{}
if len(m.EnabledRoutes) != 0 {
allwIps, err := m.EnabledRoutes.MarshalJSON()
if len(machine.EnabledRoutes) != 0 {
allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@ -476,8 +477,8 @@ func (m Machine) toNode(
}
endpoints := []string{}
if len(m.Endpoints) != 0 {
be, err := m.Endpoints.MarshalJSON()
if len(machine.Endpoints) != 0 {
be, err := machine.Endpoints.MarshalJSON()
if err != nil {
return nil, err
}
@ -488,8 +489,8 @@ func (m Machine) toNode(
}
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if len(machine.HostInfo) != 0 {
hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@ -507,29 +508,34 @@ func (m Machine) toNode(
}
var keyExpiry time.Time
if m.Expiry != nil {
keyExpiry = *m.Expiry
if machine.Expiry != nil {
keyExpiry = *machine.Expiry
} else {
keyExpiry = time.Time{}
}
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
hostname = fmt.Sprintf(
"%s.%s.%s",
machine.Name,
machine.Namespace.Name,
baseDomain,
)
} else {
hostname = m.Name
hostname = machine.Name
}
n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
strconv.FormatUint(m.ID, BASE_10),
strconv.FormatUint(machine.ID, BASE_10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
User: tailcfg.UserID(m.NamespaceID),
Key: tailcfg.NodeKey(nKey),
User: tailcfg.UserID(machine.NamespaceID),
Key: tailcfg.NodeKey(nodeKey),
KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(mKey),
Machine: tailcfg.MachineKey(machineKey),
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
@ -537,68 +543,73 @@ func (m Machine) toNode(
DERP: derp,
Hostinfo: hostinfo,
Created: m.CreatedAt,
LastSeen: m.LastSeen,
Created: machine.CreatedAt,
LastSeen: machine.LastSeen,
KeepAlive: true,
MachineAuthorized: m.Registered,
MachineAuthorized: machine.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing},
}
return &n, nil
}
func (m *Machine) toProto() *v1.Machine {
machine := &v1.Machine{
Id: m.ID,
MachineKey: m.MachineKey,
func (machine *Machine) toProto() *v1.Machine {
machineProto := &v1.Machine{
Id: machine.ID,
MachineKey: machine.MachineKey,
NodeKey: m.NodeKey,
DiscoKey: m.DiscoKey,
IpAddress: m.IPAddress,
Name: m.Name,
Namespace: m.Namespace.toProto(),
NodeKey: machine.NodeKey,
DiscoKey: machine.DiscoKey,
IpAddress: machine.IPAddress,
Name: machine.Name,
Namespace: machine.Namespace.toProto(),
Registered: m.Registered,
Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
CreatedAt: timestamppb.New(m.CreatedAt),
CreatedAt: timestamppb.New(machine.CreatedAt),
}
if m.AuthKey != nil {
machine.PreAuthKey = m.AuthKey.toProto()
if machine.AuthKey != nil {
machineProto.PreAuthKey = machine.AuthKey.toProto()
}
if m.LastSeen != nil {
machine.LastSeen = timestamppb.New(*m.LastSeen)
if machine.LastSeen != nil {
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
}
if m.LastSuccessfulUpdate != nil {
machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
if machine.LastSuccessfulUpdate != nil {
machineProto.LastSuccessfulUpdate = timestamppb.New(
*machine.LastSuccessfulUpdate,
)
}
if m.Expiry != nil {
machine.Expiry = timestamppb.New(*m.Expiry)
if machine.Expiry != nil {
machineProto.Expiry = timestamppb.New(*machine.Expiry)
}
return machine
return machineProto
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
ns, err := h.GetNamespace(namespace)
func (h *Headscale) RegisterMachine(
key string,
namespaceName string,
) (*Machine, error) {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
mKey, err := wgkey.ParseHex(key)
machineKey, err := wgkey.ParseHex(key)
if err != nil {
return nil, err
}
m := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
machine := Machine{}
if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
@ -607,15 +618,15 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Attempting to register machine")
if m.isAlreadyRegistered() {
if machine.isAlreadyRegistered() {
err := errors.New("Machine already registered")
log.Error().
Caller().
Err(err).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Attempting to register machine")
return nil, err
@ -626,7 +637,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Error().
Caller().
Err(err).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Could not find IP for the new machine")
return nil, err
@ -634,27 +645,27 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Found IP for host")
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "cli"
h.db.Save(&m)
machine.IPAddress = ip.String()
machine.NamespaceID = namespace.ID
machine.Registered = true
machine.RegisterMethod = "cli"
h.db.Save(&machine)
log.Trace().
Caller().
Str("machine", m.Name).
Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Machine registered with the database")
return &m, nil
return &machine, nil
}
func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := m.GetHostInfo()
func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
@ -662,8 +673,8 @@ func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
return hostInfo.RoutableIPs, nil
}
func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := m.EnabledRoutes.MarshalJSON()
func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@ -686,13 +697,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
return routes, nil
}
func (m *Machine) IsRoutesEnabled(routeStr string) bool {
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := m.GetEnabledRoutes()
enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return false
}
@ -708,7 +719,7 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes.
func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr)
@ -719,7 +730,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes[index] = route
}
availableRoutes, err := m.GetAdvertisedRoutes()
availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return err
}
@ -728,7 +739,7 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
if !containsIpPrefix(availableRoutes, newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s",
m.Name,
machine.Name,
newRoute,
)
}
@ -739,10 +750,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID)
err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil {
return err
}
@ -750,13 +761,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return nil
}
func (m *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := m.GetAdvertisedRoutes()
func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return nil, err
}
enabledRoutes, err := m.GetEnabledRoutes()
enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return nil, err
}

View File

@ -32,12 +32,12 @@ type Namespace struct {
// CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
n := Namespace{}
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
namespace := Namespace{}
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, errorNamespaceExists
}
n.Name = name
if err := h.db.Create(&n).Error; err != nil {
namespace.Name = name
if err := h.db.Create(&namespace).Error; err != nil {
log.Error().
Str("func", "CreateNamespace").
Err(err).
@ -46,22 +46,22 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
return nil, err
}
return &n, nil
return &namespace, nil
}
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error {
n, err := h.GetNamespace(name)
namespace, err := h.GetNamespace(name)
if err != nil {
return errorNamespaceNotFound
}
m, err := h.ListMachinesInNamespace(name)
machines, err := h.ListMachinesInNamespace(name)
if err != nil {
return err
}
if len(m) > 0 {
if len(machines) > 0 {
return errorNamespaceNotEmptyOfNodes
}
@ -69,14 +69,14 @@ func (h *Headscale) DestroyNamespace(name string) error {
if err != nil {
return err
}
for _, p := range keys {
err = h.DestroyPreAuthKey(&p)
for _, key := range keys {
err = h.DestroyPreAuthKey(&key)
if err != nil {
return err
}
}
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
return result.Error
}
@ -86,7 +86,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error {
n, err := h.GetNamespace(oldName)
oldNamespace, err := h.GetNamespace(oldName)
if err != nil {
return err
}
@ -98,13 +98,13 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return err
}
n.Name = newName
oldNamespace.Name = newName
if result := h.db.Save(&n); result.Error != nil {
if result := h.db.Save(&oldNamespace); result.Error != nil {
return result.Error
}
err = h.RequestMapUpdates(n.ID)
err = h.RequestMapUpdates(oldNamespace.ID)
if err != nil {
return err
}
@ -114,15 +114,15 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
// GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
n := Namespace{}
if result := h.db.First(&n, "name = ?", name); errors.Is(
namespace := Namespace{}
if result := h.db.First(&namespace, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errorNamespaceNotFound
}
return &n, nil
return &namespace, nil
}
// ListNamespaces gets all the existing namespaces.
@ -137,13 +137,13 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
// ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
n, err := h.GetNamespace(name)
namespace, err := h.GetNamespace(name)
if err != nil {
return nil, err
}
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
return nil, err
}
@ -176,17 +176,18 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
}
// SetMachineNamespace assigns a Machine to a namespace.
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
n, err := h.GetNamespace(namespaceName)
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return err
}
m.NamespaceID = n.ID
h.db.Save(&m)
machine.NamespaceID = namespace.ID
h.db.Save(&machine)
return nil
}
// TODO(kradalby): Remove the need for this.
// RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{}
@ -194,8 +195,8 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
return err
}
v, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" {
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil || namespacesPendingUpdates == "" {
err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
@ -207,7 +208,7 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
return nil
}
names := []string{}
err = json.Unmarshal([]byte(v), &names)
err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
if err != nil {
err = h.setValue(
"namespaces_pending_updates",
@ -235,16 +236,16 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
}
func (h *Headscale) checkForNamespacesPendingUpdates() {
v, err := h.getValue("namespaces_pending_updates")
namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == "" {
if namespacesPendingUpdates == "" {
return
}
namespaces := []string{}
err = json.Unmarshal([]byte(v), &namespaces)
err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
if err != nil {
return
}
@ -255,11 +256,11 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Msg("Sending updates to nodes in namespacespace")
h.setLastStateChangeToNow(namespace)
}
newV, err := h.getValue("namespaces_pending_updates")
newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == newV { // only clear when no changes, so we notified everybody
if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "")
if err != nil {
log.Error().
@ -273,7 +274,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
}
func (n *Namespace) toUser() *tailcfg.User {
u := tailcfg.User{
user := tailcfg.User{
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
@ -283,11 +284,11 @@ func (n *Namespace) toUser() *tailcfg.User {
Created: time.Time{},
}
return &u
return &user
}
func (n *Namespace) toLogin() *tailcfg.Login {
l := tailcfg.Login{
login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
@ -295,14 +296,14 @@ func (n *Namespace) toLogin() *tailcfg.Login {
Domain: "headscale.net",
}
return &l
return &login
}
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace)
namespaceMap[m.Namespace.Name] = m.Namespace
for _, p := range peers {
namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
namespaceMap[machine.Namespace.Name] = machine.Namespace
for _, peer := range peers {
namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}

69
oidc.go
View File

@ -68,10 +68,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(c *gin.Context) {
mKeyStr := c.Param("mkey")
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
mKeyStr := ctx.Param("mkey")
if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params")
ctx.String(http.StatusBadRequest, "Wrong params")
return
}
@ -79,7 +79,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
b := make([]byte, RANDOM_BYTE_SIZE)
if _, err := rand.Read(b); err != nil {
log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return
}
@ -92,7 +92,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
c.Redirect(http.StatusFound, authUrl)
ctx.Redirect(http.StatusFound, authUrl)
}
// OIDCCallback handles the callback from the OIDC endpoint
@ -100,19 +100,19 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
func (h *Headscale) OIDCCallback(ctx *gin.Context) {
code := ctx.Query("code")
state := ctx.Query("state")
if code == "" || state == "" {
c.String(http.StatusBadRequest, "Wrong params")
ctx.String(http.StatusBadRequest, "Wrong params")
return
}
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token")
ctx.String(http.StatusBadRequest, "Could not exchange code for token")
return
}
@ -121,7 +121,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
c.String(http.StatusBadRequest, "Could not extract ID Token")
ctx.String(http.StatusBadRequest, "Could not extract ID Token")
return
}
@ -130,7 +130,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return
}
@ -145,7 +145,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
c.String(
ctx.String(
http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err),
)
@ -159,7 +159,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyFound {
log.Error().
Msg("requested machine state key expired before authorisation completed")
c.String(http.StatusBadRequest, "state has expired")
ctx.String(http.StatusBadRequest, "state has expired")
return
}
@ -167,16 +167,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if !mKeyOK {
log.Error().Msg("could not get machine key from cache")
c.String(http.StatusInternalServerError, "could not get machine key from cache")
ctx.String(
http.StatusInternalServerError,
"could not get machine key from cache",
)
return
}
// retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr)
machine, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil {
log.Error().Msg("machine key not found in database")
c.String(
ctx.String(
http.StatusInternalServerError,
"could not get machine info from database",
)
@ -186,19 +189,19 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
now := time.Now().UTC()
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new
if !m.Registered {
if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName)
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
ns, err = h.CreateNamespace(nsName)
namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
log.Error().
Msgf("could not create new namespace '%s'", claims.Email)
c.String(
ctx.String(
http.StatusInternalServerError,
"could not create new namespace",
)
@ -209,7 +212,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
ip, err := h.getAvailableIP()
if err != nil {
c.String(
ctx.String(
http.StatusInternalServerError,
"could not get an IP from the pool",
)
@ -217,17 +220,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return
}
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
machine.IPAddress = ip.String()
machine.NamespaceID = namespace.ID
machine.Registered = true
machine.RegisterMethod = "oidc"
machine.LastSuccessfulUpdate = &now
h.db.Save(&machine)
}
h.updateMachineExpiry(m)
h.updateMachineExpiry(machine)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
@ -243,9 +246,9 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", m.Name).
Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace")
c.String(
ctx.String(
http.StatusBadRequest,
"email from claim could not be mapped to a namespace",
)

View File

@ -233,7 +233,7 @@ func (h *Headscale) PollNetMapStream(
) {
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
c.Stream(func(w io.Writer) bool {
c.Stream(func(writer io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
@ -252,7 +252,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := w.Write(data)
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
@ -305,7 +305,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
@ -370,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(data)
_, err = writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").

View File

@ -18,13 +18,16 @@ type SharedMachine struct {
}
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID {
func (h *Headscale) AddSharedMachineToNamespace(
machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
return errorSameNamespace
}
sharedMachines := []SharedMachine{}
if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
return err
}
if len(sharedMachines) > 0 {
@ -32,10 +35,10 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
}
sharedMachine := SharedMachine{
MachineID: m.ID,
Machine: *m,
NamespaceID: ns.ID,
Namespace: *ns,
MachineID: machine.ID,
Machine: *machine,
NamespaceID: namespace.ID,
Namespace: *namespace,
}
h.db.Save(&sharedMachine)
@ -43,14 +46,17 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
}
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID {
func (h *Headscale) RemoveSharedMachineFromNamespace(
machine *Machine,
namespace *Namespace,
) error {
if machine.NamespaceID == namespace.ID {
// Can't unshare from primary namespace
return errorMachineNotShared
}
sharedMachine := SharedMachine{}
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).
result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
Unscoped().
Delete(&sharedMachine)
if result.Error != nil {
@ -61,7 +67,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return errorMachineNotShared
}
err := h.RequestMapUpdates(ns.ID)
err := h.RequestMapUpdates(namespace.ID)
if err != nil {
return err
}
@ -70,9 +76,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
}
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
sharedMachine := SharedMachine{}
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error
}

View File

@ -13,8 +13,8 @@ import (
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte
func SwaggerUI(c *gin.Context) {
t := template.Must(template.New("swagger").Parse(`
func SwaggerUI(ctx *gin.Context) {
swaggerTemplate := template.Must(template.New("swagger").Parse(`
<html>
<head>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
@ -47,12 +47,12 @@ func SwaggerUI(c *gin.Context) {
</html>`))
var payload bytes.Buffer
if err := t.Execute(&payload, struct{}{}); err != nil {
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error().
Caller().
Err(err).
Msg("Could not render Swagger")
c.Data(
ctx.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Swagger"),
@ -61,9 +61,9 @@ func SwaggerUI(c *gin.Context) {
return
}
c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
func SwaggerAPIv1(c *gin.Context) {
c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
func SwaggerAPIv1(ctx *gin.Context) {
ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
}

View File

@ -36,7 +36,7 @@ func decode(
func decodeMsg(
msg []byte,
v interface{},
output interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
@ -45,7 +45,7 @@ func decodeMsg(
return err
}
// fmt.Println(string(decrypted))
if err := json.Unmarshal(decrypted, v); err != nil {
if err := json.Unmarshal(decrypted, output); err != nil {
return fmt.Errorf("response: %v", err)
}
@ -78,13 +78,17 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
return encodeMsg(b, pubKey, privKey)
}
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
func encodeMsg(
payload []byte,
pubKey *wgkey.Key,
privKey *wgkey.Private,
) ([]byte, error) {
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err)
}
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
return msg, nil
}