From a1e7e771cecfbcd988c340e1e8fd1d6cd5a5d467 Mon Sep 17 00:00:00 2001
From: Grigoriy Mikhalkin <grigoriymikhalkin@gmail.com>
Date: Sun, 7 Aug 2022 13:57:07 +0200
Subject: [PATCH] refactor OIDC callback aux functions

---
 oidc.go | 185 +++++++++++++++++++++++++++++---------------------------
 1 file changed, 96 insertions(+), 89 deletions(-)

diff --git a/oidc.go b/oidc.go
index 5509bd47..a385a921 100644
--- a/oidc.go
+++ b/oidc.go
@@ -21,6 +21,13 @@ import (
 
 const (
 	randomByteSize = 16
+
+	errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
+	errNoOIDCIDToken           = Error("could not extract ID Token for OIDC callback")
+	errOIDCAllowedDomains      = Error("authenticated principal does not match any allowed domain")
+	errOIDCAllowedUsers        = Error("authenticated principal does not match any allowed user")
+	errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed")
+	errOIDCMachineKeyMissing   = Error("could not get machine key from cache")
 )
 
 type IDTokenClaims struct {
@@ -136,18 +143,18 @@ func (h *Headscale) OIDCCallback(
 	writer http.ResponseWriter,
 	req *http.Request,
 ) {
-	code, state, ok := validateOIDCCallbackParams(writer, req)
-	if !ok {
+	code, state, err := validateOIDCCallbackParams(writer, req)
+	if err != nil {
 		return
 	}
 
-	rawIDToken, ok := h.getIDTokenForOIDCCallback(writer, code, state)
-	if !ok {
+	rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state)
+	if err != nil {
 		return
 	}
 
-	idToken, ok := h.verifyIDTokenForOIDCCallback(writer, rawIDToken)
-	if !ok {
+	idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken)
+	if err != nil {
 		return
 	}
 
@@ -158,43 +165,43 @@ func (h *Headscale) OIDCCallback(
 	// 	return
 	// }
 
-	claims, ok := extractIDTokenClaims(writer, idToken)
-	if !ok {
+	claims, err := extractIDTokenClaims(writer, idToken)
+	if err != nil {
 		return
 	}
 
-	if ok := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); !ok {
+	if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
 		return
 	}
 
-	if ok := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); !ok {
+	if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
 		return
 	}
 
-	machineKey, ok := h.validateMachineForOIDCCallback(writer, state, claims)
-	if !ok {
+	machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims)
+	if err != nil || machineExists {
 		return
 	}
 
-	namespaceName, ok := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain)
-	if !ok {
+	namespaceName, err := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain)
+	if err != nil {
 		return
 	}
 
 	// register the machine if it's new
 	log.Debug().Msg("Registering new machine after successful callback")
 
-	namespace, ok := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName)
-	if !ok {
+	namespace, err := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName)
+	if err != nil {
 		return
 	}
 
-	if ok := h.registerMachineForOIDCCallback(writer, namespace, machineKey); !ok {
+	if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil {
 		return
 	}
 
-	content, ok := renderOIDCCallbackTemplate(writer, claims)
-	if !ok {
+	content, err := renderOIDCCallbackTemplate(writer, claims)
+	if err != nil {
 		return
 	}
 
@@ -211,7 +218,7 @@ func (h *Headscale) OIDCCallback(
 func validateOIDCCallbackParams(
 	writer http.ResponseWriter,
 	req *http.Request,
-) (string, string, bool) {
+) (string, string, error) {
 	code := req.URL.Query().Get("code")
 	state := req.URL.Query().Get("state")
 
@@ -226,16 +233,16 @@ func validateOIDCCallbackParams(
 				Msg("Failed to write response")
 		}
 
-		return "", "", false
+		return "", "", errEmptyOIDCCallbackParams
 	}
 
-	return code, state, true
+	return code, state, nil
 }
 
 func (h *Headscale) getIDTokenForOIDCCallback(
 	writer http.ResponseWriter,
 	code, state string,
-) (string, bool) {
+) (string, error) {
 	oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
 	if err != nil {
 		log.Error().
@@ -244,15 +251,15 @@ func (h *Headscale) getIDTokenForOIDCCallback(
 			Msg("Could not exchange code for token")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusBadRequest)
-		_, err := writer.Write([]byte("Could not exchange code for token"))
-		if err != nil {
+		_, werr := writer.Write([]byte("Could not exchange code for token"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return "", false
+		return "", err
 	}
 
 	log.Trace().
@@ -273,16 +280,16 @@ func (h *Headscale) getIDTokenForOIDCCallback(
 				Msg("Failed to write response")
 		}
 
-		return "", false
+		return "", errNoOIDCIDToken
 	}
 
-	return rawIDToken, true
+	return rawIDToken, nil
 }
 
 func (h *Headscale) verifyIDTokenForOIDCCallback(
 	writer http.ResponseWriter,
 	rawIDToken string,
-) (*oidc.IDToken, bool) {
+) (*oidc.IDToken, error) {
 	verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
 	idToken, err := verifier.Verify(context.Background(), rawIDToken)
 	if err != nil {
@@ -292,24 +299,24 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
 			Msg("failed to verify id token")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusBadRequest)
-		_, err := writer.Write([]byte("Failed to verify id token"))
-		if err != nil {
+		_, werr := writer.Write([]byte("Failed to verify id token"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, err
 	}
 
-	return idToken, true
+	return idToken, nil
 }
 
 func extractIDTokenClaims(
 	writer http.ResponseWriter,
 	idToken *oidc.IDToken,
-) (*IDTokenClaims, bool) {
+) (*IDTokenClaims, error) {
 	var claims IDTokenClaims
 	if err := idToken.Claims(claims); err != nil {
 		log.Error().
@@ -318,18 +325,18 @@ func extractIDTokenClaims(
 			Msg("Failed to decode id token claims")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusBadRequest)
-		_, err := writer.Write([]byte("Failed to decode id token claims"))
-		if err != nil {
+		_, werr := writer.Write([]byte("Failed to decode id token claims"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, err
 	}
 
-	return &claims, true
+	return &claims, nil
 }
 
 // validateOIDCAllowedDomains checks that if AllowedDomains is provided,
@@ -338,7 +345,7 @@ func validateOIDCAllowedDomains(
 	writer http.ResponseWriter,
 	allowedDomains []string,
 	claims *IDTokenClaims,
-) bool {
+) error {
 	if len(allowedDomains) > 0 {
 		if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
 			!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
@@ -353,11 +360,11 @@ func validateOIDCAllowedDomains(
 					Msg("Failed to write response")
 			}
 
-			return false
+			return errOIDCAllowedDomains
 		}
 	}
 
-	return true
+	return nil
 }
 
 // validateOIDCAllowedUsers checks that if AllowedUsers is provided,
@@ -366,7 +373,7 @@ func validateOIDCAllowedUsers(
 	writer http.ResponseWriter,
 	allowedUsers []string,
 	claims *IDTokenClaims,
-) bool {
+) error {
 	if len(allowedUsers) > 0 &&
 		!IsStringInSlice(allowedUsers, claims.Email) {
 		log.Error().Msg("authenticated principal does not match any allowed user")
@@ -380,10 +387,10 @@ func validateOIDCAllowedUsers(
 				Msg("Failed to write response")
 		}
 
-		return false
+		return errOIDCAllowedUsers
 	}
 
-	return true
+	return nil
 }
 
 // validateMachine retrieves machine information if it exist
@@ -394,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 	writer http.ResponseWriter,
 	state string,
 	claims *IDTokenClaims,
-) (*key.MachinePublic, bool) {
+) (*key.MachinePublic, bool, error) {
 	// retrieve machinekey from state cache
 	machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
 	if !machineKeyFound {
@@ -410,7 +417,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, false, errOIDCInvalidMachineState
 	}
 
 	var machineKey key.MachinePublic
@@ -423,15 +430,15 @@ func (h *Headscale) validateMachineForOIDCCallback(
 			Msg("could not parse machine public key")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusBadRequest)
-		_, err := writer.Write([]byte("could not parse public key"))
-		if err != nil {
+		_, werr := writer.Write([]byte("could not parse public key"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, false, err
 	}
 
 	if !machineKeyOK {
@@ -446,7 +453,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, false, errOIDCMachineKeyMissing
 	}
 
 	// retrieve machine information if it exist
@@ -469,7 +476,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 				Msg("Failed to refresh machine")
 			http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError)
 
-			return nil, false
+			return nil, true, err
 		}
 
 		var content bytes.Buffer
@@ -485,15 +492,15 @@ func (h *Headscale) validateMachineForOIDCCallback(
 
 			writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 			writer.WriteHeader(http.StatusInternalServerError)
-			_, err := writer.Write([]byte("Could not render OIDC callback template"))
-			if err != nil {
+			_, werr := writer.Write([]byte("Could not render OIDC callback template"))
+			if werr != nil {
 				log.Error().
 					Caller().
-					Err(err).
+					Err(werr).
 					Msg("Failed to write response")
 			}
 
-			return nil, false
+			return nil, true, err
 		}
 
 		writer.Header().Set("Content-Type", "text/html; charset=utf-8")
@@ -506,17 +513,17 @@ func (h *Headscale) validateMachineForOIDCCallback(
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, true, nil
 	}
 
-	return &machineKey, true
+	return &machineKey, false, nil
 }
 
 func getNamespaceName(
 	writer http.ResponseWriter,
 	claims *IDTokenClaims,
 	stripEmaildomain bool,
-) (string, bool) {
+) (string, error) {
 	namespaceName, err := NormalizeToFQDNRules(
 		claims.Email,
 		stripEmaildomain,
@@ -525,24 +532,24 @@ func getNamespaceName(
 		log.Error().Err(err).Caller().Msgf("couldn't normalize email")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusInternalServerError)
-		_, err := writer.Write([]byte("couldn't normalize email"))
-		if err != nil {
+		_, werr := writer.Write([]byte("couldn't normalize email"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return "", false
+		return "", err
 	}
 
-	return namespaceName, true
+	return namespaceName, nil
 }
 
 func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 	writer http.ResponseWriter,
 	namespaceName string,
-) (*Namespace, bool) {
+) (*Namespace, error) {
 	namespace, err := h.GetNamespace(namespaceName)
 	if errors.Is(err, errNamespaceNotFound) {
 		namespace, err = h.CreateNamespace(namespaceName)
@@ -554,15 +561,15 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 				Msgf("could not create new namespace '%s'", namespaceName)
 			writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 			writer.WriteHeader(http.StatusInternalServerError)
-			_, err := writer.Write([]byte("could not create namespace"))
-			if err != nil {
+			_, werr := writer.Write([]byte("could not create namespace"))
+			if werr != nil {
 				log.Error().
 					Caller().
-					Err(err).
+					Err(werr).
 					Msg("Failed to write response")
 			}
 
-			return nil, false
+			return nil, err
 		}
 	} else if err != nil {
 		log.Error().
@@ -572,25 +579,25 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 			Msg("could not find or create namespace")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusInternalServerError)
-		_, err := writer.Write([]byte("could not find or create namespace"))
-		if err != nil {
+		_, werr := writer.Write([]byte("could not find or create namespace"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, err
 	}
 
-	return namespace, true
+	return namespace, nil
 }
 
 func (h *Headscale) registerMachineForOIDCCallback(
 	writer http.ResponseWriter,
 	namespace *Namespace,
 	machineKey *key.MachinePublic,
-) bool {
+) error {
 	machineKeyStr := MachinePublicKeyStripPrefix(*machineKey)
 
 	if _, err := h.RegisterMachineFromAuthCallback(
@@ -604,24 +611,24 @@ func (h *Headscale) registerMachineForOIDCCallback(
 			Msg("could not register machine")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusInternalServerError)
-		_, err := writer.Write([]byte("could not register machine"))
-		if err != nil {
+		_, werr := writer.Write([]byte("could not register machine"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return false
+		return err
 	}
 
-	return true
+	return nil
 }
 
 func renderOIDCCallbackTemplate(
 	writer http.ResponseWriter,
 	claims *IDTokenClaims,
-) (*bytes.Buffer, bool) {
+) (*bytes.Buffer, error) {
 	var content bytes.Buffer
 	if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
 		User: claims.Email,
@@ -635,16 +642,16 @@ func renderOIDCCallbackTemplate(
 
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusInternalServerError)
-		_, err := writer.Write([]byte("Could not render OIDC callback template"))
-		if err != nil {
+		_, werr := writer.Write([]byte("Could not render OIDC callback template"))
+		if werr != nil {
 			log.Error().
 				Caller().
-				Err(err).
+				Err(werr).
 				Msg("Failed to write response")
 		}
 
-		return nil, false
+		return nil, err
 	}
 
-	return &content, true
+	return &content, nil
 }