Add proper requestID for STS errors (#7245)

This commit is contained in:
Harshavardhana 2019-02-14 17:54:33 -08:00 committed by kannappanr
parent 396d78352d
commit 8f62935448
2 changed files with 28 additions and 24 deletions

View File

@ -22,12 +22,11 @@ import (
) )
// writeSTSErrorRespone writes error headers // writeSTSErrorRespone writes error headers
func writeSTSErrorResponse(w http.ResponseWriter, errorCode STSErrorCode) { func writeSTSErrorResponse(w http.ResponseWriter, err STSError) {
stsError := getSTSError(errorCode)
// Generate error response. // Generate error response.
stsErrorResponse := getSTSErrorResponse(stsError) stsErrorResponse := getSTSErrorResponse(err, w.Header().Get(responseRequestIDKey))
encodedErrorResponse := encodeResponse(stsErrorResponse) encodedErrorResponse := encodeResponse(stsErrorResponse)
writeResponse(w, stsError.HTTPStatusCode, encodedErrorResponse, mimeXML) writeResponse(w, err.HTTPStatusCode, encodedErrorResponse, mimeXML)
} }
// STSError structure // STSError structure
@ -64,9 +63,19 @@ const (
ErrSTSInternalError ErrSTSInternalError
) )
type stsErrorCodeMap map[STSErrorCode]STSError
func (e stsErrorCodeMap) ToSTSErr(errCode STSErrorCode) STSError {
apiErr, ok := e[errCode]
if !ok {
return e[ErrSTSInternalError]
}
return apiErr
}
// error code to STSError structure, these fields carry respective // error code to STSError structure, these fields carry respective
// descriptions for all the error responses. // descriptions for all the error responses.
var stsErrCodeResponse = map[STSErrorCode]STSError{ var stsErrCodes = stsErrorCodeMap{
ErrSTSMissingParameter: { ErrSTSMissingParameter: {
Code: "MissingParameter", Code: "MissingParameter",
Description: "A required parameter for the specified action is not supplied.", Description: "A required parameter for the specified action is not supplied.",
@ -109,17 +118,12 @@ var stsErrCodeResponse = map[STSErrorCode]STSError{
}, },
} }
// getSTSError provides STS Error for input STS error code. // getSTSErrorResponse gets in standard error and
func getSTSError(code STSErrorCode) STSError {
return stsErrCodeResponse[code]
}
// getErrorResponse gets in standard error and resource value and
// provides a encodable populated response values // provides a encodable populated response values
func getSTSErrorResponse(err STSError) STSErrorResponse { func getSTSErrorResponse(err STSError, requestID string) STSErrorResponse {
errRsp := STSErrorResponse{} errRsp := STSErrorResponse{}
errRsp.Error.Code = err.Code errRsp.Error.Code = err.Code
errRsp.Error.Message = err.Description errRsp.Error.Message = err.Description
errRsp.RequestID = "3L137" errRsp.RequestID = requestID
return errRsp return errRsp
} }

View File

@ -70,13 +70,13 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
// Parse the incoming form data. // Parse the incoming form data.
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return return
} }
if r.Form.Get("Version") != stsAPIVersion { if r.Form.Get("Version") != stsAPIVersion {
logger.LogIf(ctx, fmt.Errorf("Invalid STS API version %s, expecting %s", r.Form.Get("Version"), stsAPIVersion)) logger.LogIf(ctx, fmt.Errorf("Invalid STS API version %s, expecting %s", r.Form.Get("Version"), stsAPIVersion))
writeSTSErrorResponse(w, ErrSTSMissingParameter) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSMissingParameter))
return return
} }
@ -85,7 +85,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
case clientGrants, webIdentity: case clientGrants, webIdentity:
default: default:
logger.LogIf(ctx, fmt.Errorf("Unsupported action %s", action)) logger.LogIf(ctx, fmt.Errorf("Unsupported action %s", action))
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return return
} }
@ -93,14 +93,14 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
defer logger.AuditLog(w, r, action, nil) defer logger.AuditLog(w, r, action, nil)
if globalIAMValidators == nil { if globalIAMValidators == nil {
writeSTSErrorResponse(w, ErrSTSNotInitialized) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSNotInitialized))
return return
} }
v, err := globalIAMValidators.Get("jwt") v, err := globalIAMValidators.Get("jwt")
if err != nil { if err != nil {
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return return
} }
@ -115,17 +115,17 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
case validator.ErrTokenExpired: case validator.ErrTokenExpired:
switch action { switch action {
case clientGrants: case clientGrants:
writeSTSErrorResponse(w, ErrSTSClientGrantsExpiredToken) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSClientGrantsExpiredToken))
case webIdentity: case webIdentity:
writeSTSErrorResponse(w, ErrSTSWebIdentityExpiredToken) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSWebIdentityExpiredToken))
} }
return return
case validator.ErrInvalidDuration: case validator.ErrInvalidDuration:
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return return
} }
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return return
} }
@ -133,7 +133,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
cred, err := auth.GetNewCredentialsWithMetadata(m, secret) cred, err := auth.GetNewCredentialsWithMetadata(m, secret)
if err != nil { if err != nil {
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInternalError) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInternalError))
return return
} }
@ -154,7 +154,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
// Set the newly generated credentials. // Set the newly generated credentials.
if err = globalIAMSys.SetTempUser(cred.AccessKey, cred, policyName); err != nil { if err = globalIAMSys.SetTempUser(cred.AccessKey, cred, policyName); err != nil {
logger.LogIf(ctx, err) logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInternalError) writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInternalError))
return return
} }