fix: comply with RFC6750 UserInfo endpoint requirements (#16592)

This commit is contained in:
Harshavardhana 2023-02-10 08:50:25 -08:00 committed by GitHub
parent 72daccd468
commit d65debb6bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 15 deletions

View File

@ -361,7 +361,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Requ
}
// Validate JWT; check clientID in claims matches the one associated with the roleArn
if err := globalOpenIDConfig.Validate(roleArn, token, accessToken, r.Form.Get(stsDurationSeconds), claims); err != nil {
if err := globalOpenIDConfig.Validate(r.Context(), roleArn, token, accessToken, r.Form.Get(stsDurationSeconds), claims); err != nil {
switch err {
case openid.ErrTokenExpired:
switch action {

View File

@ -18,6 +18,7 @@
package openid
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -140,7 +141,7 @@ const (
)
// Validate - validates the id_token.
func (r *Config) Validate(arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error {
func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error {
jp := new(jwtgo.Parser)
jp.ValidMethods = []string{
"RS256", "RS384", "RS512",
@ -184,7 +185,7 @@ func (r *Config) Validate(arn arn.ARN, token, accessToken, dsecs string, claims
return err
}
if err = r.updateUserinfoClaims(arn, accessToken, claims); err != nil {
if err = r.updateUserinfoClaims(ctx, arn, accessToken, claims); err != nil {
return err
}
@ -223,7 +224,7 @@ func (r *Config) Validate(arn arn.ARN, token, accessToken, dsecs string, claims
return nil
}
func (r *Config) updateUserinfoClaims(arn arn.ARN, accessToken string, claims map[string]interface{}) error {
func (r *Config) updateUserinfoClaims(ctx context.Context, arn arn.ARN, accessToken string, claims map[string]interface{}) error {
pCfg, ok := r.arnProviderCfgsMap[arn]
// If claim user info is enabled, get claims from userInfo
// and overwrite them with the claims from JWT.
@ -231,7 +232,7 @@ func (r *Config) updateUserinfoClaims(arn arn.ARN, accessToken string, claims ma
if accessToken == "" {
return errors.New("access_token is mandatory if user_info claim is enabled")
}
uclaims, err := pCfg.UserInfo(accessToken, r.transport)
uclaims, err := pCfg.UserInfo(ctx, accessToken, r.transport)
if err != nil {
return err
}

View File

@ -19,6 +19,7 @@ package openid
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
@ -148,7 +149,7 @@ func TestJWTHMACType(t *testing.T) {
}
var claims jwtgo.MapClaims
if err = cfg.Validate(DummyRoleARN, token, "", "", claims); err != nil {
if err = cfg.Validate(context.Background(), DummyRoleARN, token, "", "", claims); err != nil {
t.Fatal(err)
}
}
@ -200,7 +201,7 @@ func TestJWT(t *testing.T) {
}
var claims jwtgo.MapClaims
if err = cfg.Validate(DummyRoleARN, u.Query().Get("Token"), "", "", claims); err == nil {
if err = cfg.Validate(context.Background(), DummyRoleARN, u.Query().Get("Token"), "", "", claims); err == nil {
t.Fatal(err)
}
}

View File

@ -18,6 +18,7 @@
package openid
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -108,23 +109,25 @@ func (p *providerCfg) GetRoleArn() string {
// claims as part of the normal oauth2 flow, instead rely
// on service providers making calls to IDP to fetch additional
// claims available from the UserInfo endpoint
func (p *providerCfg) UserInfo(accessToken string, transport http.RoundTripper) (map[string]interface{}, error) {
func (p *providerCfg) UserInfo(ctx context.Context, accessToken string, transport http.RoundTripper) (map[string]interface{}, error) {
if p.JWKS.URL == nil || p.JWKS.URL.String() == "" {
return nil, errors.New("openid not configured")
}
client := &http.Client{
Transport: transport,
}
req, err := http.NewRequest(http.MethodPost, p.DiscoveryDoc.UserInfoEndpoint, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.DiscoveryDoc.UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if accessToken != "" {
req.Header.Set("Authorization", "Bearer "+accessToken)
}
client := &http.Client{
Transport: transport,
}
resp, err := client.Do(req)
if err != nil {
return nil, err
@ -140,10 +143,8 @@ func (p *providerCfg) UserInfo(accessToken string, transport http.RoundTripper)
return nil, errors.New(resp.Status)
}
dec := json.NewDecoder(resp.Body)
claims := map[string]interface{}{}
if err = dec.Decode(&claims); err != nil {
if err = json.NewDecoder(resp.Body).Decode(&claims); err != nil {
// uncomment this for debugging when needed.
// reqBytes, _ := httputil.DumpRequest(req, false)
// fmt.Println(string(reqBytes))