headscale/utils.go

253 lines
5.5 KiB
Go
Raw Normal View History

2020-06-21 06:32:08 -04:00
// Code here is mostly taken from github.com/tailscale/tailscale
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headscale
import (
2021-10-30 10:29:03 -04:00
"context"
2020-06-21 06:32:08 -04:00
"encoding/json"
"fmt"
2021-10-30 10:29:03 -04:00
"net"
2021-08-13 05:33:19 -04:00
"strings"
2020-06-21 06:32:08 -04:00
"github.com/rs/zerolog/log"
"inet.af/netaddr"
2021-08-13 05:33:19 -04:00
"tailscale.com/tailcfg"
"tailscale.com/types/key"
2020-06-21 06:32:08 -04:00
)
2021-11-15 14:18:14 -05:00
const (
errCannotDecryptReponse = Error("cannot decrypt response")
errCouldNotAllocateIP = Error("could not find any suitable IP")
// These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported.
// https://github.com/tailscale/tailscale/tree/main/types/key
// nodePublicHexPrefix is the prefix used to identify a
// hex-encoded node public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
nodePublicHexPrefix = "nodekey:"
// machinePublicHexPrefix is the prefix used to identify a
// hex-encoded machine public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
machinePublicHexPrefix = "mkey:"
// discoPublicHexPrefix is the prefix used to identify a
// hex-encoded disco public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
discoPublicHexPrefix = "discokey:"
2021-11-15 14:18:14 -05:00
)
// MachinePublicKeyStripPrefix returns machineKey.String() with a leading
// machinePublicHexPrefix ("mkey:") removed, if present.
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
	text := machineKey.String()
	if strings.HasPrefix(text, machinePublicHexPrefix) {
		return text[len(machinePublicHexPrefix):]
	}

	return text
}
// NodePublicKeyStripPrefix returns nodeKey.String() with a leading
// nodePublicHexPrefix ("nodekey:") removed, if present.
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
	text := nodeKey.String()
	if strings.HasPrefix(text, nodePublicHexPrefix) {
		return text[len(nodePublicHexPrefix):]
	}

	return text
}
// DiscoPublicKeyStripPrefix returns discoKey.String() with a leading
// discoPublicHexPrefix ("discokey:") removed, if present.
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
	text := discoKey.String()
	if strings.HasPrefix(text, discoPublicHexPrefix) {
		return text[len(discoPublicHexPrefix):]
	}

	return text
}
// MachinePublicKeyEnsurePrefix returns machineKey unchanged if it already
// starts with machinePublicHexPrefix ("mkey:"). Otherwise it returns
// machineKey with that prefix prepended.
func MachinePublicKeyEnsurePrefix(machineKey string) string {
	if strings.HasPrefix(machineKey, machinePublicHexPrefix) {
		return machineKey
	}

	return machinePublicHexPrefix + machineKey
}
// NodePublicKeyEnsurePrefix returns nodeKey unchanged if it already starts
// with nodePublicHexPrefix ("nodekey:"). Otherwise it returns nodeKey with
// that prefix prepended.
func NodePublicKeyEnsurePrefix(nodeKey string) string {
	if strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
		return nodeKey
	}

	return nodePublicHexPrefix + nodeKey
}
// DiscoPublicKeyEnsurePrefix returns discoKey unchanged if it already starts
// with discoPublicHexPrefix ("discokey:"). Otherwise it returns discoKey with
// that prefix prepended.
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
	if strings.HasPrefix(discoKey, discoPublicHexPrefix) {
		return discoKey
	}

	return discoPublicHexPrefix + discoKey
}
2021-05-05 19:01:45 -04:00
// Error is a string-backed error type. Values of it can be declared as
// constants and compared with ==, following
// https://dave.cheney.net/2016/04/07/constant-errors.
type Error string

// Error implements the error interface by returning the underlying string.
func (e Error) Error() string {
	msg := string(e)

	return msg
}
2021-11-13 03:36:45 -05:00
func decode(
msg []byte,
output interface{},
pubKey *key.MachinePublic,
privKey *key.MachinePrivate,
2021-11-13 03:36:45 -05:00
) error {
log.Trace().Int("length", len(msg)).Msg("Trying to decrypt")
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
if !ok {
return errCannotDecryptReponse
2020-06-21 06:32:08 -04:00
}
if err := json.Unmarshal(decrypted, output); err != nil {
2021-11-15 14:18:14 -05:00
return err
2020-06-21 06:32:08 -04:00
}
2021-11-14 10:46:09 -05:00
2020-06-21 06:32:08 -04:00
return nil
}
func encode(
v interface{},
pubKey *key.MachinePublic,
privKey *key.MachinePrivate,
) ([]byte, error) {
2020-06-21 06:32:08 -04:00
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
2021-08-13 05:33:19 -04:00
return privKey.SealTo(*pubKey, b), nil
2020-06-21 06:32:08 -04:00
}
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipPrefix := h.cfg.IPPrefix
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
// Get the first IP in our prefix
ip := ipPrefix.IP()
2020-06-21 06:32:08 -04:00
for {
if !ipPrefix.Contains(ip) {
2021-11-15 14:18:14 -05:00
return nil, errCouldNotAllocateIP
2020-06-21 06:32:08 -04:00
}
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
// is typically called network or broadcast. Lets avoid them and continue
// to look when we get one of those traditionally reserved IPs.
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
2021-11-14 10:46:09 -05:00
continue
}
if ip.IsZero() &&
ip.IsLoopback() {
ip = ip.Next()
2021-11-14 10:46:09 -05:00
continue
2020-06-21 06:32:08 -04:00
}
if !containsIPs(usedIps, ip) {
return &ip, nil
2020-06-21 06:32:08 -04:00
}
ip = ip.Next()
2020-06-21 06:32:08 -04:00
}
}
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
var addresses []string
h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
ips := make([]netaddr.IP, len(addresses))
for index, addr := range addresses {
if addr != "" {
ip, err := netaddr.ParseIP(addr)
if err != nil {
2021-11-15 14:18:14 -05:00
return nil, fmt.Errorf("failed to parse ip from database: %w", err)
}
ips[index] = ip
}
}
return ips, nil
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
2020-06-21 06:32:08 -04:00
}
}
2021-02-21 16:11:27 -05:00
return false
2020-06-21 06:32:08 -04:00
}
2021-08-13 05:33:19 -04:00
// tailNodesToString renders the node names as "[ name1, name2 ](count)",
// where count is the number of nodes. Every element of nodes is dereferenced,
// so none may be nil.
func tailNodesToString(nodes []*tailcfg.Node) string {
	names := make([]string, 0, len(nodes))
	for _, node := range nodes {
		names = append(names, node.Name)
	}

	joined := strings.Join(names, ", ")

	return fmt.Sprintf("[ %s ](%d)", joined, len(names))
}
func tailMapResponseToString(resp tailcfg.MapResponse) string {
2021-11-13 03:36:45 -05:00
return fmt.Sprintf(
"{ Node: %s, Peers: %s }",
resp.Node.Name,
tailNodesToString(resp.Peers),
)
2021-08-13 05:33:19 -04:00
}
2021-10-30 10:29:03 -04:00
// GrpcSocketDialer opens a connection to the Unix domain socket at addr.
// It uses a zero-value net.Dialer, so ctx governs cancellation and deadlines.
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
	dialer := &net.Dialer{}

	return dialer.DialContext(ctx, "unix", addr)
}
2021-11-04 18:17:44 -04:00
// ipPrefixToString returns the String() form of every prefix, in input order.
// The result is always non-nil, even for empty input.
func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
	out := make([]string, 0, len(prefixes))
	for _, p := range prefixes {
		out = append(out, p.String())
	}

	return out
}
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
2021-11-04 18:17:44 -04:00
result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes {
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return []netaddr.IPPrefix{}, err
}
result[index] = prefix
}
return result, nil
}
func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
2021-11-04 18:17:44 -04:00
for _, p := range prefixes {
if prefix == p {
return true
}
}
return false
}