diff --git a/cmd/sts-handlers.go b/cmd/sts-handlers.go index fbda9f08a..b7eee8d46 100644 --- a/cmd/sts-handlers.go +++ b/cmd/sts-handlers.go @@ -361,7 +361,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Requ } // Validate JWT; check clientID in claims matches the one associated with the roleArn - if err := globalOpenIDConfig.Validate(roleArn, token, accessToken, r.Form.Get(stsDurationSeconds), claims); err != nil { + if err := globalOpenIDConfig.Validate(r.Context(), roleArn, token, accessToken, r.Form.Get(stsDurationSeconds), claims); err != nil { switch err { case openid.ErrTokenExpired: switch action { diff --git a/internal/config/identity/openid/jwt.go b/internal/config/identity/openid/jwt.go index 0bebed12a..97444f215 100644 --- a/internal/config/identity/openid/jwt.go +++ b/internal/config/identity/openid/jwt.go @@ -18,6 +18,7 @@ package openid import ( + "context" "encoding/json" "errors" "fmt" @@ -140,7 +141,7 @@ const ( ) // Validate - validates the id_token. -func (r *Config) Validate(arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error { +func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error { jp := new(jwtgo.Parser) jp.ValidMethods = []string{ "RS256", "RS384", "RS512", @@ -184,7 +185,7 @@ func (r *Config) Validate(arn arn.ARN, token, accessToken, dsecs string, claims return err } - if err = r.updateUserinfoClaims(arn, accessToken, claims); err != nil { + if err = r.updateUserinfoClaims(ctx, arn, accessToken, claims); err != nil { return err } @@ -223,7 +224,7 @@ func (r *Config) Validate(arn arn.ARN, token, accessToken, dsecs string, claims return nil } -func (r *Config) updateUserinfoClaims(arn arn.ARN, accessToken string, claims map[string]interface{}) error { +func (r *Config) updateUserinfoClaims(ctx context.Context, arn arn.ARN, accessToken string, claims map[string]interface{}) error { pCfg, ok := r.arnProviderCfgsMap[arn] // If claim user info is enabled, get claims from userInfo // and overwrite them with the claims from JWT. @@ -231,7 +232,7 @@ func (r *Config) updateUserinfoClaims(arn arn.ARN, accessToken string, claims ma if accessToken == "" { return errors.New("access_token is mandatory if user_info claim is enabled") } - uclaims, err := pCfg.UserInfo(accessToken, r.transport) + uclaims, err := pCfg.UserInfo(ctx, accessToken, r.transport) if err != nil { return err } diff --git a/internal/config/identity/openid/jwt_test.go b/internal/config/identity/openid/jwt_test.go index 752cf1ea7..e7d41a625 100644 --- a/internal/config/identity/openid/jwt_test.go +++ b/internal/config/identity/openid/jwt_test.go @@ -19,6 +19,7 @@ package openid import ( "bytes" + "context" "encoding/base64" "encoding/json" "fmt" @@ -148,7 +149,7 @@ func TestJWTHMACType(t *testing.T) { } var claims jwtgo.MapClaims - if err = cfg.Validate(DummyRoleARN, token, "", "", claims); err != nil { + if err = cfg.Validate(context.Background(), DummyRoleARN, token, "", "", claims); err != nil { t.Fatal(err) } } @@ -200,7 +201,7 @@ func TestJWT(t *testing.T) { } var claims jwtgo.MapClaims - if err = cfg.Validate(DummyRoleARN, u.Query().Get("Token"), "", "", claims); err == nil { + if err = cfg.Validate(context.Background(), DummyRoleARN, u.Query().Get("Token"), "", "", claims); err == nil { t.Fatal(err) } } diff --git a/internal/config/identity/openid/providercfg.go b/internal/config/identity/openid/providercfg.go index 91296e484..968af6ed2 100644 --- a/internal/config/identity/openid/providercfg.go +++ b/internal/config/identity/openid/providercfg.go @@ -18,6 +18,7 @@ package openid import ( + "context" "encoding/json" "errors" "fmt" @@ -108,23 +109,25 @@ func (p *providerCfg) GetRoleArn() string { // claims as part of the normal oauth2 flow, instead rely // on service providers making calls to IDP to fetch additional // claims available from the UserInfo endpoint -func (p *providerCfg) UserInfo(accessToken string, transport http.RoundTripper) (map[string]interface{}, error) { +func (p *providerCfg) UserInfo(ctx context.Context, accessToken string, transport http.RoundTripper) (map[string]interface{}, error) { if p.JWKS.URL == nil || p.JWKS.URL.String() == "" { return nil, errors.New("openid not configured") } - client := &http.Client{ - Transport: transport, - } - req, err := http.NewRequest(http.MethodPost, p.DiscoveryDoc.UserInfoEndpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.DiscoveryDoc.UserInfoEndpoint, nil) if err != nil { return nil, err } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") if accessToken != "" { req.Header.Set("Authorization", "Bearer "+accessToken) } + client := &http.Client{ + Transport: transport, + } + resp, err := client.Do(req) if err != nil { return nil, err @@ -140,10 +143,8 @@ func (p *providerCfg) UserInfo(accessToken string, transport http.RoundTripper) return nil, errors.New(resp.Status) } - dec := json.NewDecoder(resp.Body) claims := map[string]interface{}{} - - if err = dec.Decode(&claims); err != nil { + if err = json.NewDecoder(resp.Body).Decode(&claims); err != nil { // uncomment this for debugging when needed. // reqBytes, _ := httputil.DumpRequest(req, false) // fmt.Println(string(reqBytes))