mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-11 20:23:18 -05:00
feat: add tampered request test for pkce feature
This commit is contained in:
parent
f356d08ec9
commit
360d1afe19
@ -13,6 +13,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -34,9 +35,13 @@ const (
|
|||||||
dockerContextPath = "../."
|
dockerContextPath = "../."
|
||||||
hsicOIDCMockHashLength = 6
|
hsicOIDCMockHashLength = 6
|
||||||
defaultAccessTTL = 10 * time.Minute
|
defaultAccessTTL = 10 * time.Minute
|
||||||
|
nodeStateRunning = "Running"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errStatusCodeNotOK = errors.New("status code not OK")
|
var (
|
||||||
|
errStatusCodeNotOK = errors.New("status code not OK")
|
||||||
|
ErrOIDCClientCount = errors.New("client count must be 1 for OIDC scenario")
|
||||||
|
)
|
||||||
|
|
||||||
type AuthOIDCScenario struct {
|
type AuthOIDCScenario struct {
|
||||||
*Scenario
|
*Scenario
|
||||||
@ -617,12 +622,128 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
|||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
status, err := client.Status()
|
status, err := client.Status()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
if status.BackendState != "Running" {
|
if status.BackendState != nodeStateRunning {
|
||||||
t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState)
|
t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tamperVerifierTransport struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tamperVerifierTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
log.Printf("RoundTrip: %s %s", req.Method, req.URL.String())
|
||||||
|
|
||||||
|
// For POST requests, tamper with form data
|
||||||
|
if req.Method == http.MethodPost {
|
||||||
|
log.Printf("Processing POST request")
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error parsing form: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if verifier := req.Form.Get("code_challenge"); verifier != "" {
|
||||||
|
log.Printf("Found POST verifier: %s", verifier)
|
||||||
|
// Tamper with the verifier
|
||||||
|
req.Form.Set("code_challenge", verifier+"_tampered")
|
||||||
|
log.Printf("Modified POST verifier to: %s", req.Form.Get("code_challenge"))
|
||||||
|
// Update request body with modified form
|
||||||
|
req.Body = io.NopCloser(strings.NewReader(req.Form.Encode()))
|
||||||
|
req.ContentLength = int64(len(req.Form.Encode()))
|
||||||
|
} else {
|
||||||
|
log.Printf("No code_challenge found in POST form data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For GET requests, tamper with URL query parameters
|
||||||
|
if req.Method == http.MethodGet {
|
||||||
|
log.Printf("Processing GET request")
|
||||||
|
q := req.URL.Query()
|
||||||
|
if verifier := q.Get("code_challenge"); verifier != "" {
|
||||||
|
log.Printf("Found GET verifier: %s", verifier)
|
||||||
|
q.Set("code_challenge", verifier+"_tampered")
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
log.Printf("Modified URL to: %s", req.URL.String())
|
||||||
|
} else {
|
||||||
|
log.Printf("No code_challenge found in GET query params")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward the request with the tampered verifier
|
||||||
|
resp, err := t.base.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("RoundTrip error: %v", err)
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Printf("Response status: %s", resp.Status)
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDCAuthenticationWithPKCEVerifierTampering(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
scenario := AuthOIDCScenario{
|
||||||
|
Scenario: baseScenario,
|
||||||
|
}
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
// Single user with one node for testing PKCE flow
|
||||||
|
spec := map[string]int{
|
||||||
|
"user1": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockusers := []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||||
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a transport that modifies the PKCE verifier in transit
|
||||||
|
baseTransport := &http.Transport{
|
||||||
|
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
tamperTransport := &tamperVerifierTransport{
|
||||||
|
base: baseTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnvWithHTTPModifier(
|
||||||
|
spec,
|
||||||
|
func(cli *http.Client) {
|
||||||
|
cli.Transport = tamperTransport
|
||||||
|
},
|
||||||
|
hsic.WithTestName("oidcauthpkce"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithHostnameAsServerURL(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected authentication to fail due to PKCE verifier tampering, but it succeeded")
|
||||||
|
} else {
|
||||||
|
log.Printf("auth got error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
users map[string]int,
|
users map[string]int,
|
||||||
opts ...hsic.Option,
|
opts ...hsic.Option,
|
||||||
@ -643,7 +764,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
|||||||
// This is because the MockOIDC server can only serve login
|
// This is because the MockOIDC server can only serve login
|
||||||
// requests based on a queue it has been given on startup.
|
// requests based on a queue it has been given on startup.
|
||||||
// We currently only populates it with one login request per user.
|
// We currently only populates it with one login request per user.
|
||||||
return fmt.Errorf("client count must be 1 for OIDC scenario.")
|
return ErrOIDCClientCount
|
||||||
}
|
}
|
||||||
log.Printf("creating user %s with %d clients", userName, clientCount)
|
log.Printf("creating user %s with %d clients", userName, clientCount)
|
||||||
err = s.CreateUser(userName)
|
err = s.CreateUser(userName)
|
||||||
@ -665,6 +786,49 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AuthOIDCScenario) CreateHeadscaleEnvWithHTTPModifier(
|
||||||
|
users map[string]int,
|
||||||
|
httpModifier func(*http.Client),
|
||||||
|
opts ...hsic.Option,
|
||||||
|
) error {
|
||||||
|
headscale, err := s.Headscale(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = headscale.WaitForRunning()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for userName, clientCount := range users {
|
||||||
|
if clientCount != 1 {
|
||||||
|
// OIDC scenario only supports one client per user.
|
||||||
|
// This is because the MockOIDC server can only serve login
|
||||||
|
// requests based on a queue it has been given on startup.
|
||||||
|
// We currently only populates it with one login request per user.
|
||||||
|
return ErrOIDCClientCount
|
||||||
|
}
|
||||||
|
log.Printf("creating user %s with %d clients", userName, clientCount)
|
||||||
|
err = s.CreateUser(userName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.runTailscaleUpWithModifier(userName, headscale.GetEndpoint(), httpModifier)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
|
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
|
||||||
port, err := dockertestutil.RandomFreeHostPort()
|
port, err := dockertestutil.RandomFreeHostPort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -774,7 +938,7 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
|||||||
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
|
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
|
loginURL.Host = headscale.GetIP() + ":8080"
|
||||||
loginURL.Scheme = "http"
|
loginURL.Scheme = "http"
|
||||||
|
|
||||||
if len(headscale.GetCert()) > 0 {
|
if len(headscale.GetCert()) > 0 {
|
||||||
@ -782,6 +946,7 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
insecureTransport := &http.Transport{
|
insecureTransport := &http.Transport{
|
||||||
|
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -848,6 +1013,98 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
|||||||
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AuthOIDCScenario) runTailscaleUpWithModifier(
|
||||||
|
userStr string,
|
||||||
|
loginServer string,
|
||||||
|
httpClientModifier func(*http.Client),
|
||||||
|
) error {
|
||||||
|
headscale, err := s.Headscale()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("running tailscale up for user %s", userStr)
|
||||||
|
if user, ok := s.users[userStr]; ok {
|
||||||
|
for _, client := range user.Clients {
|
||||||
|
c := client
|
||||||
|
err := func() error {
|
||||||
|
status, err := c.Status()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("%s failed to get status: %s", c.Hostname(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.BackendState == nodeStateRunning {
|
||||||
|
log.Printf("%s is already running", c.Hostname())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("%s running tailscale up", c.Hostname())
|
||||||
|
|
||||||
|
loginURL, err := c.LoginWithURL(loginServer)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
loginURL.Host = headscale.GetIP() + ":8080"
|
||||||
|
loginURL.Scheme = "http"
|
||||||
|
|
||||||
|
if len(headscale.GetCert()) > 0 {
|
||||||
|
loginURL.Scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
insecureTransport := &http.Transport{
|
||||||
|
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
||||||
|
|
||||||
|
log.Printf("%s logging in with url", c.Hostname())
|
||||||
|
httpClient := &http.Client{Transport: insecureTransport}
|
||||||
|
|
||||||
|
// Allow the test to modify the HTTP client
|
||||||
|
if httpClientModifier != nil {
|
||||||
|
httpClientModifier(httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf(
|
||||||
|
"%s failed to login using url %s: %s",
|
||||||
|
c.Hostname(),
|
||||||
|
loginURL,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
log.Printf("body: %s", body)
|
||||||
|
|
||||||
|
return errStatusCodeNotOK
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) Shutdown() {
|
func (s *AuthOIDCScenario) Shutdown() {
|
||||||
err := s.pool.Purge(s.mockOIDC)
|
err := s.pool.Purge(s.mockOIDC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user