Add proper requestID for STS errors (#7245)

This commit is contained in:
Harshavardhana 2019-02-14 17:54:33 -08:00 committed by kannappanr
parent 396d78352d
commit 8f62935448
2 changed files with 28 additions and 24 deletions

View File

@ -22,12 +22,11 @@ import (
)
// writeSTSErrorRespone writes error headers
func writeSTSErrorResponse(w http.ResponseWriter, errorCode STSErrorCode) {
stsError := getSTSError(errorCode)
func writeSTSErrorResponse(w http.ResponseWriter, err STSError) {
// Generate error response.
stsErrorResponse := getSTSErrorResponse(stsError)
stsErrorResponse := getSTSErrorResponse(err, w.Header().Get(responseRequestIDKey))
encodedErrorResponse := encodeResponse(stsErrorResponse)
writeResponse(w, stsError.HTTPStatusCode, encodedErrorResponse, mimeXML)
writeResponse(w, err.HTTPStatusCode, encodedErrorResponse, mimeXML)
}
// STSError structure
@ -64,9 +63,19 @@ const (
ErrSTSInternalError
)
type stsErrorCodeMap map[STSErrorCode]STSError
func (e stsErrorCodeMap) ToSTSErr(errCode STSErrorCode) STSError {
apiErr, ok := e[errCode]
if !ok {
return e[ErrSTSInternalError]
}
return apiErr
}
// error code to STSError structure, these fields carry respective
// descriptions for all the error responses.
var stsErrCodeResponse = map[STSErrorCode]STSError{
var stsErrCodes = stsErrorCodeMap{
ErrSTSMissingParameter: {
Code: "MissingParameter",
Description: "A required parameter for the specified action is not supplied.",
@ -109,17 +118,12 @@ var stsErrCodeResponse = map[STSErrorCode]STSError{
},
}
// getSTSError provides STS Error for input STS error code.
func getSTSError(code STSErrorCode) STSError {
return stsErrCodeResponse[code]
}
// getErrorResponse gets in standard error and resource value and
// getSTSErrorResponse gets in standard error and
// provides a encodable populated response values
func getSTSErrorResponse(err STSError) STSErrorResponse {
func getSTSErrorResponse(err STSError, requestID string) STSErrorResponse {
errRsp := STSErrorResponse{}
errRsp.Error.Code = err.Code
errRsp.Error.Message = err.Description
errRsp.RequestID = "3L137"
errRsp.RequestID = requestID
return errRsp
}

View File

@ -70,13 +70,13 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
// Parse the incoming form data.
if err := r.ParseForm(); err != nil {
logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return
}
if r.Form.Get("Version") != stsAPIVersion {
logger.LogIf(ctx, fmt.Errorf("Invalid STS API version %s, expecting %s", r.Form.Get("Version"), stsAPIVersion))
writeSTSErrorResponse(w, ErrSTSMissingParameter)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSMissingParameter))
return
}
@ -85,7 +85,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
case clientGrants, webIdentity:
default:
logger.LogIf(ctx, fmt.Errorf("Unsupported action %s", action))
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return
}
@ -93,14 +93,14 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
defer logger.AuditLog(w, r, action, nil)
if globalIAMValidators == nil {
writeSTSErrorResponse(w, ErrSTSNotInitialized)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSNotInitialized))
return
}
v, err := globalIAMValidators.Get("jwt")
if err != nil {
logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return
}
@ -115,17 +115,17 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
case validator.ErrTokenExpired:
switch action {
case clientGrants:
writeSTSErrorResponse(w, ErrSTSClientGrantsExpiredToken)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSClientGrantsExpiredToken))
case webIdentity:
writeSTSErrorResponse(w, ErrSTSWebIdentityExpiredToken)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSWebIdentityExpiredToken))
}
return
case validator.ErrInvalidDuration:
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return
}
logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInvalidParameterValue)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInvalidParameterValue))
return
}
@ -133,7 +133,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
cred, err := auth.GetNewCredentialsWithMetadata(m, secret)
if err != nil {
logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInternalError)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInternalError))
return
}
@ -154,7 +154,7 @@ func (sts *stsAPIHandlers) AssumeRoleWithJWT(w http.ResponseWriter, r *http.Requ
// Set the newly generated credentials.
if err = globalIAMSys.SetTempUser(cred.AccessKey, cred, policyName); err != nil {
logger.LogIf(ctx, err)
writeSTSErrorResponse(w, ErrSTSInternalError)
writeSTSErrorResponse(w, stsErrCodes.ToSTSErr(ErrSTSInternalError))
return
}