change service account embedded policy size limit (#19840)

Bonus: trim-off all the unnecessary spaces to allow
for real 2048 characters in policies for STS handlers
and re-use the code in all STS handlers.
This commit is contained in:
Harshavardhana 2024-05-30 11:10:41 -07:00 committed by GitHub
parent 4af31e654b
commit 8f93e81afb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 67 additions and 82 deletions

View File

@ -2371,7 +2371,7 @@ func (store *IAMStoreSys) UpdateServiceAccount(ctx context.Context, accessKey st
return updatedAt, err return updatedAt, err
} }
if len(policyBuf) > 2048 { if len(policyBuf) > maxSVCSessionPolicySize {
return updatedAt, errSessionPolicyTooLarge return updatedAt, errSessionPolicyTooLarge
} }

View File

@ -78,6 +78,10 @@ const (
inheritedPolicyType = "inherited-policy" inheritedPolicyType = "inherited-policy"
) )
const (
maxSVCSessionPolicySize = 4096
)
// IAMSys - config system. // IAMSys - config system.
type IAMSys struct { type IAMSys struct {
// Need to keep them here to keep alignment - ref: https://golang.org/pkg/sync/atomic/#pkg-note-BUG // Need to keep them here to keep alignment - ref: https://golang.org/pkg/sync/atomic/#pkg-note-BUG
@ -977,7 +981,7 @@ func (sys *IAMSys) NewServiceAccount(ctx context.Context, parentUser string, gro
if err != nil { if err != nil {
return auth.Credentials{}, time.Time{}, err return auth.Credentials{}, time.Time{}, err
} }
if len(policyBuf) > 2048 { if len(policyBuf) > maxSVCSessionPolicySize {
return auth.Credentials{}, time.Time{}, errSessionPolicyTooLarge return auth.Credentials{}, time.Time{}, errSessionPolicyTooLarge
} }
} }

View File

@ -22,9 +22,11 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -82,8 +84,50 @@ const (
// Role Claim key // Role Claim key
roleArnClaim = "roleArn" roleArnClaim = "roleArn"
// maximum supported STS session policy size
maxSTSSessionPolicySize = 2048
) )
type stsClaims map[string]interface{}
func (c stsClaims) populateSessionPolicy(form url.Values) error {
if len(form) == 0 {
return nil
}
sessionPolicyStr := form.Get(stsPolicy)
if len(sessionPolicyStr) == 0 {
return nil
}
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
return err
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
return errors.New("Version cannot be empty expecting '2012-10-17'")
}
policyBuf, err := json.Marshal(sessionPolicy)
if err != nil {
return err
}
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html
// The plain text that you use for both inline and managed session
// policies shouldn't exceed maxSTSSessionPolicySize characters.
if len(policyBuf) > maxSTSSessionPolicySize {
return errSessionPolicyTooLarge
}
c[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString(policyBuf)
return nil
}
// stsAPIHandlers implements and provides http handlers for AWS STS API. // stsAPIHandlers implements and provides http handlers for AWS STS API.
type stsAPIHandlers struct{} type stsAPIHandlers struct{}
@ -212,7 +256,7 @@ func getTokenSigningKey() (string, error) {
func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) { func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
ctx := newContext(r, w, "AssumeRole") ctx := newContext(r, w, "AssumeRole")
claims := make(map[string]interface{}) claims := stsClaims{}
defer logger.AuditLog(ctx, w, r, claims) defer logger.AuditLog(ctx, w, r, claims)
// Check auth here (otherwise r.Form will have unexpected values from // Check auth here (otherwise r.Form will have unexpected values from
@ -249,29 +293,11 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
return return
} }
sessionPolicyStr := r.Form.Get(stsPolicy) if err := claims.populateSessionPolicy(r.Form); err != nil {
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
// The plain text that you use for both inline and managed session
// policies shouldn't exceed 2048 characters.
if len(sessionPolicyStr) > 2048 {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, errSessionPolicyTooLarge)
return return
} }
if len(sessionPolicyStr) > 0 {
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Version cannot be empty expecting '2012-10-17'"))
return
}
}
duration, err := openid.GetDefaultExpiration(r.Form.Get(stsDurationSeconds)) duration, err := openid.GetDefaultExpiration(r.Form.Get(stsDurationSeconds))
if err != nil { if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err) writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
@ -288,10 +314,6 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(sessionPolicyStr) > 0 {
claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr))
}
secret, err := getTokenSigningKey() secret, err := getTokenSigningKey()
if err != nil { if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err) writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err)
@ -342,7 +364,7 @@ func (sts *stsAPIHandlers) AssumeRole(w http.ResponseWriter, r *http.Request) {
func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Request) { func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Request) {
ctx := newContext(r, w, "AssumeRoleSSOCommon") ctx := newContext(r, w, "AssumeRoleSSOCommon")
claims := make(map[string]interface{}) claims := stsClaims{}
defer logger.AuditLog(ctx, w, r, claims) defer logger.AuditLog(ctx, w, r, claims)
// Parse the incoming form data. // Parse the incoming form data.
@ -449,31 +471,11 @@ func (sts *stsAPIHandlers) AssumeRoleWithSSO(w http.ResponseWriter, r *http.Requ
claims[iamPolicyClaimNameOpenID()] = policyName claims[iamPolicyClaimNameOpenID()] = policyName
} }
sessionPolicyStr := r.Form.Get(stsPolicy) if err := claims.populateSessionPolicy(r.Form); err != nil {
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
// The plain text that you use for both inline and managed session
// policies shouldn't exceed 2048 characters.
if len(sessionPolicyStr) > 2048 {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Session policy should not exceed 2048 characters"))
return return
} }
if len(sessionPolicyStr) > 0 {
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Invalid session policy version"))
return
}
claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr))
}
secret, err := getTokenSigningKey() secret, err := getTokenSigningKey()
if err != nil { if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err) writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err)
@ -612,7 +614,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithClientGrants(w http.ResponseWriter, r *
func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *http.Request) { func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *http.Request) {
ctx := newContext(r, w, "AssumeRoleWithLDAPIdentity") ctx := newContext(r, w, "AssumeRoleWithLDAPIdentity")
claims := make(map[string]interface{}) claims := stsClaims{}
defer logger.AuditLog(ctx, w, r, claims, stsLDAPPassword) defer logger.AuditLog(ctx, w, r, claims, stsLDAPPassword)
// Parse the incoming form data. // Parse the incoming form data.
@ -643,29 +645,11 @@ func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *
return return
} }
sessionPolicyStr := r.Form.Get(stsPolicy) if err := claims.populateSessionPolicy(r.Form); err != nil {
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
// The plain text that you use for both inline and managed session
// policies shouldn't exceed 2048 characters.
if len(sessionPolicyStr) > 2048 {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Session policy should not exceed 2048 characters"))
return return
} }
if len(sessionPolicyStr) > 0 {
sessionPolicy, err := policy.ParseConfig(bytes.NewReader([]byte(sessionPolicyStr)))
if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, err)
return
}
// Version in policy must not be empty
if sessionPolicy.Version == "" {
writeSTSErrorResponse(ctx, w, ErrSTSInvalidParameterValue, fmt.Errorf("Version needs to be specified in session policy"))
return
}
}
if !globalIAMSys.Initialized() { if !globalIAMSys.Initialized() {
writeSTSErrorResponse(ctx, w, ErrSTSIAMNotInitialized, errIAMNotInitialized) writeSTSErrorResponse(ctx, w, ErrSTSIAMNotInitialized, errIAMNotInitialized)
return return
@ -708,10 +692,6 @@ func (sts *stsAPIHandlers) AssumeRoleWithLDAPIdentity(w http.ResponseWriter, r *
claims[ldapAttribPrefix+attrib] = value claims[ldapAttribPrefix+attrib] = value
} }
if len(sessionPolicyStr) > 0 {
claims[policy.SessionPolicyName] = base64.StdEncoding.EncodeToString([]byte(sessionPolicyStr))
}
secret, err := getTokenSigningKey() secret, err := getTokenSigningKey()
if err != nil { if err != nil {
writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err) writeSTSErrorResponse(ctx, w, ErrSTSInternalError, err)

View File

@ -133,7 +133,7 @@ const (
) )
// Validate - validates the id_token. // Validate - validates the id_token.
func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims jwtgo.MapClaims) error { func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken, dsecs string, claims map[string]interface{}) error {
jp := new(jwtgo.Parser) jp := new(jwtgo.Parser)
jp.ValidMethods = []string{ jp.ValidMethods = []string{
"RS256", "RS384", "RS512", "RS256", "RS384", "RS512",
@ -156,14 +156,15 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
return fmt.Errorf("Role %s does not exist", arn) return fmt.Errorf("Role %s does not exist", arn)
} }
jwtToken, err := jp.ParseWithClaims(token, &claims, keyFuncCallback) mclaims := jwtgo.MapClaims(claims)
jwtToken, err := jp.ParseWithClaims(token, &mclaims, keyFuncCallback)
if err != nil { if err != nil {
// Re-populate the public key in-case the JWKS // Re-populate the public key in-case the JWKS
// pubkeys are refreshed // pubkeys are refreshed
if err = r.PopulatePublicKey(arn); err != nil { if err = r.PopulatePublicKey(arn); err != nil {
return err return err
} }
jwtToken, err = jwtgo.ParseWithClaims(token, &claims, keyFuncCallback) jwtToken, err = jwtgo.ParseWithClaims(token, &mclaims, keyFuncCallback)
if err != nil { if err != nil {
return err return err
} }
@ -173,11 +174,11 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
return ErrTokenExpired return ErrTokenExpired
} }
if err = updateClaimsExpiry(dsecs, claims); err != nil { if err = updateClaimsExpiry(dsecs, mclaims); err != nil {
return err return err
} }
if err = r.updateUserinfoClaims(ctx, arn, accessToken, claims); err != nil { if err = r.updateUserinfoClaims(ctx, arn, accessToken, mclaims); err != nil {
return err return err
} }
@ -190,7 +191,7 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
// array of case sensitive strings. In the common special case // array of case sensitive strings. In the common special case
// when there is one audience, the aud value MAY be a single // when there is one audience, the aud value MAY be a single
// case sensitive // case sensitive
audValues, ok := policy.GetValuesFromClaims(claims, audClaim) audValues, ok := policy.GetValuesFromClaims(mclaims, audClaim)
if !ok { if !ok {
return errors.New("STS JWT Token has `aud` claim invalid, `aud` must match configured OpenID Client ID") return errors.New("STS JWT Token has `aud` claim invalid, `aud` must match configured OpenID Client ID")
} }
@ -204,7 +205,7 @@ func (r *Config) Validate(ctx context.Context, arn arn.ARN, token, accessToken,
// be included even when the authorized party is the same // be included even when the authorized party is the same
// as the sole audience. The azp value is a case sensitive // as the sole audience. The azp value is a case sensitive
// string containing a StringOrURI value // string containing a StringOrURI value
azpValues, ok := policy.GetValuesFromClaims(claims, azpClaim) azpValues, ok := policy.GetValuesFromClaims(mclaims, azpClaim)
if !ok { if !ok {
return errors.New("STS JWT Token has `azp` claim invalid, `azp` must match configured OpenID Client ID") return errors.New("STS JWT Token has `azp` claim invalid, `azp` must match configured OpenID Client ID")
} }