refactor: extractSignedHeaders() handles headers removed by Go http server (#4054)

* refactor: extractSignedHeaders() handles headers removed by Go http server.
* Cleanup extractSignedHeaders() TestExtractSignedHeaders()
This commit is contained in:
Krishna Srinivas 2017-04-05 17:00:24 -07:00 committed by Harshavardhana
parent af82d27018
commit 1d99a560e3
6 changed files with 59 additions and 73 deletions

View File

@ -21,6 +21,7 @@ import (
"encoding/hex" "encoding/hex"
"net/http" "net/http"
"regexp" "regexp"
"strconv"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@ -119,7 +120,14 @@ func extractSignedHeaders(signedHeaders []string, r *http.Request) (http.Header,
// `host` will not be found in the headers, can be found in r.Host. // `host` will not be found in the headers, can be found in r.Host.
// but its alway necessary that the list of signed headers containing host in it. // but its alway necessary that the list of signed headers containing host in it.
val, ok := reqHeaders[http.CanonicalHeaderKey(header)] val, ok := reqHeaders[http.CanonicalHeaderKey(header)]
if !ok { if ok {
for _, enc := range val {
extractedSignedHeaders.Add(header, enc)
}
continue
}
switch header {
case "expect":
// Golang http server strips off 'Expect' header, if the // Golang http server strips off 'Expect' header, if the
// client sent this as part of signed headers we need to // client sent this as part of signed headers we need to
// handle otherwise we would see a signature mismatch. // handle otherwise we would see a signature mismatch.
@ -136,24 +144,23 @@ func extractSignedHeaders(signedHeaders []string, r *http.Request) (http.Header,
// be sent, for the time being keep this work around. // be sent, for the time being keep this work around.
// Adding a *TODO* to remove this later when Golang server // Adding a *TODO* to remove this later when Golang server
// doesn't filter out the 'Expect' header. // doesn't filter out the 'Expect' header.
if header == "expect" { extractedSignedHeaders.Set(header, "100-continue")
extractedSignedHeaders[header] = []string{"100-continue"} case "host":
continue // Go http server removes "host" from Request.Header
extractedSignedHeaders.Set(header, r.Host)
case "transfer-encoding":
// Go http server removes "host" from Request.Header
for _, enc := range r.TransferEncoding {
extractedSignedHeaders.Add(header, enc)
} }
// the "host" field will not be found in the header map, it can be found in req.Host. case "content-length":
// but its necessary to make sure that the "host" field exists in the list of signed parameters, // Signature-V4 spec excludes Content-Length from signed headers list for signature calculation.
// the check is done above. // But some clients deviate from this rule. Hence we consider Content-Length for signature
if header == "host" { // calculation to be compatible with such clients.
continue extractedSignedHeaders.Set(header, strconv.FormatInt(r.ContentLength, 10))
} default:
if header == "transfer-encoding" {
extractedSignedHeaders[header] = r.TransferEncoding
continue
}
// If not found continue, we will stop here.
return nil, ErrUnsignedHeaders return nil, ErrUnsignedHeaders
} }
extractedSignedHeaders[header] = val
} }
return extractedSignedHeaders, ErrNone return extractedSignedHeaders, ErrNone
} }

View File

@ -141,8 +141,9 @@ func TestExtractSignedHeaders(t *testing.T) {
expectedContentSha256 := "1234abcd" expectedContentSha256 := "1234abcd"
expectedTime := UTCNow().Format(iso8601Format) expectedTime := UTCNow().Format(iso8601Format)
expectedTransferEncoding := "gzip" expectedTransferEncoding := "gzip"
expectedExpect := "100-continue"
r, err := http.NewRequest("GET", "http://localhost", nil) r, err := http.NewRequest("GET", "http://play.minio.io:9000", nil)
if err != nil { if err != nil {
t.Fatal("Unable to create http.Request :", err) t.Fatal("Unable to create http.Request :", err)
} }
@ -150,9 +151,8 @@ func TestExtractSignedHeaders(t *testing.T) {
// Creating input http header. // Creating input http header.
inputHeader := r.Header inputHeader := r.Header
inputHeader.Set(signedHeaders[0], expectedHost) inputHeader.Set("x-amz-content-sha256", expectedContentSha256)
inputHeader.Set(signedHeaders[1], expectedContentSha256) inputHeader.Set("x-amz-date", expectedTime)
inputHeader.Set(signedHeaders[2], expectedTime)
// calling the function being tested. // calling the function being tested.
extractedSignedHeaders, errCode := extractSignedHeaders(signedHeaders, r) extractedSignedHeaders, errCode := extractSignedHeaders(signedHeaders, r)
if errCode != ErrNone { if errCode != ErrNone {
@ -160,24 +160,24 @@ func TestExtractSignedHeaders(t *testing.T) {
} }
// "x-amz-content-sha256" header value from the extracted result. // "x-amz-content-sha256" header value from the extracted result.
extractedContentSha256 := extractedSignedHeaders[signedHeaders[1]] extractedContentSha256 := extractedSignedHeaders.Get("x-amz-content-sha256")
// "host" header value from the extracted result. // "host" header value from the extracted result.
extractedHost := extractedSignedHeaders[signedHeaders[0]] extractedHost := extractedSignedHeaders.Get("host")
// "x-amz-date" header from the extracted result. // "x-amz-date" header from the extracted result.
extractedDate := extractedSignedHeaders[signedHeaders[2]] extractedDate := extractedSignedHeaders.Get("x-amz-date")
// extracted `expect` header. // extracted `expect` header.
extractedExpect := extractedSignedHeaders["expect"][0] extractedExpect := extractedSignedHeaders.Get("expect")
extractedTransferEncoding := extractedSignedHeaders["transfer-encoding"][0] extractedTransferEncoding := extractedSignedHeaders.Get("transfer-encoding")
if expectedHost != extractedHost[0] { if expectedHost != extractedHost {
t.Errorf("host header mismatch: expected `%s`, got `%s`", expectedHost, extractedHost) t.Errorf("host header mismatch: expected `%s`, got `%s`", expectedHost, extractedHost)
} }
// assert the result with the expected value. // assert the result with the expected value.
if expectedContentSha256 != extractedContentSha256[0] { if expectedContentSha256 != extractedContentSha256 {
t.Errorf("x-amz-content-sha256 header mismatch: expected `%s`, got `%s`", expectedContentSha256, extractedContentSha256) t.Errorf("x-amz-content-sha256 header mismatch: expected `%s`, got `%s`", expectedContentSha256, extractedContentSha256)
} }
if expectedTime != extractedDate[0] { if expectedTime != extractedDate {
t.Errorf("x-amz-date header mismatch: expected `%s`, got `%s`", expectedTime, extractedDate) t.Errorf("x-amz-date header mismatch: expected `%s`, got `%s`", expectedTime, extractedDate)
} }
if extractedTransferEncoding != expectedTransferEncoding { if extractedTransferEncoding != expectedTransferEncoding {
@ -185,12 +185,12 @@ func TestExtractSignedHeaders(t *testing.T) {
} }
// Since the list of signed headers value contained `expect`, the default value of `100-continue` will be added to extracted signed headers. // Since the list of signed headers value contained `expect`, the default value of `100-continue` will be added to extracted signed headers.
if extractedExpect != "100-continue" { if extractedExpect != expectedExpect {
t.Errorf("expect header incorrect value: expected `%s`, got `%s`", "100-continue", extractedExpect) t.Errorf("expect header incorrect value: expected `%s`, got `%s`", expectedExpect, extractedExpect)
} }
// case where the headers doesn't contain the one of the signed header in the signed headers list. // case where the headers don't contain the one of the signed header in the signed headers list.
signedHeaders = append(signedHeaders, " X-Amz-Credential") signedHeaders = append(signedHeaders, "X-Amz-Credential")
// expected to fail with `ErrUnsignedHeaders`. // expected to fail with `ErrUnsignedHeaders`.
_, errCode = extractSignedHeaders(signedHeaders, r) _, errCode = extractSignedHeaders(signedHeaders, r)
if errCode != ErrUnsignedHeaders { if errCode != ErrUnsignedHeaders {
@ -198,7 +198,7 @@ func TestExtractSignedHeaders(t *testing.T) {
} }
// case where the list of signed headers doesn't contain the host field. // case where the list of signed headers doesn't contain the host field.
signedHeaders = signedHeaders[1:] signedHeaders = signedHeaders[2:5]
// expected to fail with `ErrUnsignedHeaders`. // expected to fail with `ErrUnsignedHeaders`.
_, errCode = extractSignedHeaders(signedHeaders, r) _, errCode = extractSignedHeaders(signedHeaders, r)
if errCode != ErrUnsignedHeaders { if errCode != ErrUnsignedHeaders {

View File

@ -46,25 +46,19 @@ const (
) )
// getCanonicalHeaders generate a list of request headers with their values // getCanonicalHeaders generate a list of request headers with their values
func getCanonicalHeaders(signedHeaders http.Header, host string) string { func getCanonicalHeaders(signedHeaders http.Header) string {
var headers []string var headers []string
vals := make(http.Header) vals := make(http.Header)
for k, vv := range signedHeaders { for k, vv := range signedHeaders {
headers = append(headers, strings.ToLower(k)) headers = append(headers, strings.ToLower(k))
vals[strings.ToLower(k)] = vv vals[strings.ToLower(k)] = vv
} }
headers = append(headers, presignedHostHeader)
sort.Strings(headers) sort.Strings(headers)
var buf bytes.Buffer var buf bytes.Buffer
for _, k := range headers { for _, k := range headers {
buf.WriteString(k) buf.WriteString(k)
buf.WriteByte(':') buf.WriteByte(':')
switch {
case k == presignedHostHeader:
buf.WriteString(host)
fallthrough
default:
for idx, v := range vals[k] { for idx, v := range vals[k] {
if idx > 0 { if idx > 0 {
buf.WriteByte(',') buf.WriteByte(',')
@ -73,7 +67,6 @@ func getCanonicalHeaders(signedHeaders http.Header, host string) string {
} }
buf.WriteByte('\n') buf.WriteByte('\n')
} }
}
return buf.String() return buf.String()
} }
@ -83,7 +76,6 @@ func getSignedHeaders(signedHeaders http.Header) string {
for k := range signedHeaders { for k := range signedHeaders {
headers = append(headers, strings.ToLower(k)) headers = append(headers, strings.ToLower(k))
} }
headers = append(headers, presignedHostHeader)
sort.Strings(headers) sort.Strings(headers)
return strings.Join(headers, ";") return strings.Join(headers, ";")
} }
@ -98,14 +90,14 @@ func getSignedHeaders(signedHeaders http.Header) string {
// <SignedHeaders>\n // <SignedHeaders>\n
// <HashedPayload> // <HashedPayload>
// //
func getCanonicalRequest(extractedSignedHeaders http.Header, payload, queryStr, urlPath, method, host string) string { func getCanonicalRequest(extractedSignedHeaders http.Header, payload, queryStr, urlPath, method string) string {
rawQuery := strings.Replace(queryStr, "+", "%20", -1) rawQuery := strings.Replace(queryStr, "+", "%20", -1)
encodedPath := getURLEncodedName(urlPath) encodedPath := getURLEncodedName(urlPath)
canonicalRequest := strings.Join([]string{ canonicalRequest := strings.Join([]string{
method, method,
encodedPath, encodedPath,
rawQuery, rawQuery,
getCanonicalHeaders(extractedSignedHeaders, host), getCanonicalHeaders(extractedSignedHeaders),
getSignedHeaders(extractedSignedHeaders), getSignedHeaders(extractedSignedHeaders),
payload, payload,
}, "\n") }, "\n")
@ -304,7 +296,7 @@ func doesPresignedSignatureMatch(hashedPayload string, r *http.Request, region s
/// Verify finally if signature is same. /// Verify finally if signature is same.
// Get canonical request. // Get canonical request.
presignedCanonicalReq := getCanonicalRequest(extractedSignedHeaders, hashedPayload, encodedQuery, req.URL.Path, req.Method, req.Host) presignedCanonicalReq := getCanonicalRequest(extractedSignedHeaders, hashedPayload, encodedQuery, req.URL.Path, req.Method)
// Get string to sign from canonical request. // Get string to sign from canonical request.
presignedStringToSign := getStringToSign(presignedCanonicalReq, t, pSignValues.Credential.getScope()) presignedStringToSign := getStringToSign(presignedCanonicalReq, t, pSignValues.Credential.getScope())
@ -346,19 +338,6 @@ func doesSignatureMatch(hashedPayload string, r *http.Request, region string) AP
return ErrContentSHA256Mismatch return ErrContentSHA256Mismatch
} }
header := req.Header
// Signature-V4 spec excludes Content-Length from signed headers list for signature calculation.
// But some clients deviate from this rule. Hence we consider Content-Length for signature
// calculation to be compatible with such clients.
for _, h := range signV4Values.SignedHeaders {
if h == "content-length" {
header = cloneHeader(req.Header)
header.Set("content-length", strconv.FormatInt(r.ContentLength, 10))
break
}
}
// Extract all the signed headers along with its values. // Extract all the signed headers along with its values.
extractedSignedHeaders, errCode := extractSignedHeaders(signV4Values.SignedHeaders, r) extractedSignedHeaders, errCode := extractSignedHeaders(signV4Values.SignedHeaders, r)
if errCode != ErrNone { if errCode != ErrNone {
@ -401,8 +380,7 @@ func doesSignatureMatch(hashedPayload string, r *http.Request, region string) AP
queryStr := req.URL.Query().Encode() queryStr := req.URL.Query().Encode()
// Get canonical request. // Get canonical request.
canonicalRequest := getCanonicalRequest(extractedSignedHeaders, hashedPayload, queryStr, req.URL.Path, req.Method, req.Host) canonicalRequest := getCanonicalRequest(extractedSignedHeaders, hashedPayload, queryStr, req.URL.Path, req.Method)
// Get string to sign from canonical request. // Get string to sign from canonical request.
stringToSign := getStringToSign(canonicalRequest, t, signV4Values.Credential.getScope()) stringToSign := getStringToSign(canonicalRequest, t, signV4Values.Credential.getScope())

View File

@ -133,7 +133,7 @@ func calculateSeedSignature(r *http.Request) (signature string, date time.Time,
queryStr := req.URL.Query().Encode() queryStr := req.URL.Query().Encode()
// Get canonical request. // Get canonical request.
canonicalRequest := getCanonicalRequest(extractedSignedHeaders, payload, queryStr, req.URL.Path, req.Method, req.Host) canonicalRequest := getCanonicalRequest(extractedSignedHeaders, payload, queryStr, req.URL.Path, req.Method)
// Get string to sign from canonical request. // Get string to sign from canonical request.
stringToSign := getStringToSign(canonicalRequest, date, signV4Values.Credential.getScope()) stringToSign := getStringToSign(canonicalRequest, date, signV4Values.Credential.getScope())

View File

@ -889,11 +889,12 @@ func preSignV4(req *http.Request, accessKeyID, secretAccessKey string, expires i
query.Set("X-Amz-Credential", credential) query.Set("X-Amz-Credential", credential)
query.Set("X-Amz-Content-Sha256", unsignedPayload) query.Set("X-Amz-Content-Sha256", unsignedPayload)
// Headers are empty, since "host" is the only header required to be signed for Presigned URLs. // "host" is the only header required to be signed for Presigned URLs.
var extractedSignedHeaders http.Header extractedSignedHeaders := make(http.Header)
extractedSignedHeaders.Set("host", req.Host)
queryStr := strings.Replace(query.Encode(), "+", "%20", -1) queryStr := strings.Replace(query.Encode(), "+", "%20", -1)
canonicalRequest := getCanonicalRequest(extractedSignedHeaders, unsignedPayload, queryStr, req.URL.Path, req.Method, req.Host) canonicalRequest := getCanonicalRequest(extractedSignedHeaders, unsignedPayload, queryStr, req.URL.Path, req.Method)
stringToSign := getStringToSign(canonicalRequest, date, scope) stringToSign := getStringToSign(canonicalRequest, date, scope)
signingKey := getSigningKey(secretAccessKey, date, region) signingKey := getSigningKey(secretAccessKey, date, region)
signature := getSignature(signingKey, stringToSign) signature := getSignature(signingKey, stringToSign)

View File

@ -868,10 +868,10 @@ func presignedGet(host, bucket, object string, expiry int64) string {
path := "/" + path.Join(bucket, object) path := "/" + path.Join(bucket, object)
// Headers are empty, since "host" is the only header required to be signed for Presigned URLs. // "host" is the only header required to be signed for Presigned URLs.
var extractedSignedHeaders http.Header extractedSignedHeaders := make(http.Header)
extractedSignedHeaders.Set("host", host)
canonicalRequest := getCanonicalRequest(extractedSignedHeaders, unsignedPayload, query, path, "GET", host) canonicalRequest := getCanonicalRequest(extractedSignedHeaders, unsignedPayload, query, path, "GET")
stringToSign := getStringToSign(canonicalRequest, date, getScope(date, region)) stringToSign := getStringToSign(canonicalRequest, date, getScope(date, region))
signingKey := getSigningKey(secretKey, date, region) signingKey := getSigningKey(secretKey, date, region)
signature := getSignature(signingKey, stringToSign) signature := getSignature(signingKey, stringToSign)