Clean up logging and error handling in oidc

We should never expose errors via web, it gives attackers a lot of info
(Insert OWASP guide).

Also handle error that didnt separate not found gorm issue and other
errors.
This commit is contained in:
Kristoffer Dalby 2021-11-21 21:51:39 +00:00
parent fac33e46e1
commit fcd4d94927
No known key found for this signature in database
GPG Key ID: 09F62DC067465735
1 changed files with 48 additions and 14 deletions

62
oidc.go
View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"net/http"
"regexp"
@ -15,6 +16,7 @@ import (
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"gorm.io/gorm"
)
const (
@ -37,7 +39,10 @@ func (h *Headscale) initOIDC() error {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
log.Error().
Err(err).
Caller().
Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}
@ -69,8 +74,8 @@ func (h *Headscale) initOIDC() error {
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
mKeyStr := ctx.Param("mkey")
if mKeyStr == "" {
machineKeyStr := ctx.Param("mkey")
if machineKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params")
return
@ -78,7 +83,9 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil {
log.Error().Msg("could not read 16 bytes from rand")
log.Error().
Caller().
Msg("could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return
@ -87,7 +94,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -130,7 +137,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
log.Error().
Err(err).
Caller().
Msg("failed to verify id token")
ctx.String(http.StatusBadRequest, "Failed to verify id token")
return
}
@ -145,27 +156,31 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
log.Error().
Err(err).
Caller().
Msg("Failed to decode id token claims")
ctx.String(
http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err),
fmt.Sprintf("Failed to decode id token claims"),
)
return
}
// retrieve machinekey from state cache
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound {
if !machineKeyFound {
log.Error().
Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired")
return
}
mKeyStr, mKeyOK := mKeyIf.(string)
machineKey, machineKeyOK := machineKeyIf.(string)
if !mKeyOK {
if !machineKeyOK {
log.Error().Msg("could not get machine key from cache")
ctx.String(
http.StatusInternalServerError,
@ -176,7 +191,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
}
// retrieve machine information
machine, err := h.GetMachineByMachineKey(mKeyStr)
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
log.Error().Msg("machine key not found in database")
ctx.String(
@ -195,12 +210,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
log.Debug().Msg("Registering new machine after successful callback")
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
log.Error().
Msgf("could not create new namespace '%s'", claims.Email)
Err(err).
Caller().
Msgf("could not create new namespace '%s'", namespaceName)
ctx.String(
http.StatusInternalServerError,
"could not create new namespace",
@ -208,10 +225,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
} else if err != nil {
log.Error().
Caller().
Err(err).
Str("namespace", namespaceName).
Msg("could not find or create namespace")
ctx.String(
http.StatusInternalServerError,
"could not find or create namespace",
)
return
}
ip, err := h.getAvailableIP()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not get an IP from the pool")
ctx.String(
http.StatusInternalServerError,
"could not get an IP from the pool",
@ -242,6 +275,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
}
log.Error().
Caller().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", machine.Name).