From 9424dca9e47d9bf3f7689d0bd4fb7944bcfdee1c Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 28 Oct 2021 17:04:48 -0700 Subject: [PATCH] jwt: Improve allocations (#13532) Avoid string -> byte allocations. ``` BenchmarkParseJWTStandardClaims-32 3527152 343.2 ns/op 1489 B/op 21 allocs/op BenchmarkParseJWTStandardClaims-32 4713199 259.2 ns/op 706 B/op 16 allocs/op BenchmarkParseJWTMapClaims-32 2666668 448.7 ns/op 1883 B/op 32 allocs/op BenchmarkParseJWTMapClaims-32 3120709 377.1 ns/op 1227 B/op 28 allocs/op ``` --- internal/jwt/parser.go | 121 ++++++++++++++++++++++++++++++----------- 1 file changed, 88 insertions(+), 33 deletions(-) diff --git a/internal/jwt/parser.go b/internal/jwt/parser.go index f2807eb7a..7e49cc956 100644 --- a/internal/jwt/parser.go +++ b/internal/jwt/parser.go @@ -23,11 +23,12 @@ package jwt // borrowed under MIT License https://github.com/golang-jwt/jwt/blob/main/LICENSE import ( + "bytes" "crypto" "crypto/hmac" "encoding/base64" "fmt" - "strings" + "hash" "sync" "time" @@ -38,8 +39,9 @@ import ( // SigningMethodHMAC - Implements the HMAC-SHA family of signing methods signing methods // Expects key type of []byte for both signing and validation type SigningMethodHMAC struct { - Name string - Hash crypto.Hash + Name string + Hash crypto.Hash + HasherPool sync.Pool } // Specific instances for HS256, HS384, HS512 @@ -49,6 +51,8 @@ var ( SigningMethodHS512 *SigningMethodHMAC ) +const base64BufferSize = 8192 + var ( base64BufPool sync.Pool hmacSigners []*SigningMethodHMAC @@ -57,16 +61,49 @@ var ( func init() { base64BufPool = sync.Pool{ New: func() interface{} { - buf := make([]byte, 8192) + buf := make([]byte, base64BufferSize) return &buf }, } hmacSigners = []*SigningMethodHMAC{ - {"HS256", crypto.SHA256}, - {"HS384", crypto.SHA384}, - {"HS512", crypto.SHA512}, + {Name: "HS256", Hash: crypto.SHA256}, + {Name: "HS384", Hash: crypto.SHA384}, + {Name: "HS512", Hash: crypto.SHA512}, } + for i := range hmacSigners { + h := hmacSigners[i].Hash + hmacSigners[i].HasherPool.New = func() interface{} { + return h.New() + } + } +} + +// HashBorrower allows borrowing hashes and will keep track of them. +func (s *SigningMethodHMAC) HashBorrower() HashBorrower { + return HashBorrower{pool: &s.HasherPool, borrowed: make([]hash.Hash, 0, 2)} +} + +// HashBorrower keeps track of borrowed hashers and allows to return them all. +type HashBorrower struct { + pool *sync.Pool + borrowed []hash.Hash +} + +// Borrow a single hasher. +func (h *HashBorrower) Borrow() hash.Hash { + hasher := h.pool.Get().(hash.Hash) + h.borrowed = append(h.borrowed, hasher) + hasher.Reset() + return hasher +} + +// ReturnAll will return all borrowed hashes. +func (h *HashBorrower) ReturnAll() { + for _, hasher := range h.borrowed { + h.pool.Put(hasher) + } + h.borrowed = nil } // StandardClaims are basically standard claims with "accessKey" @@ -202,26 +239,35 @@ func ParseWithStandardClaims(tokenStr string, claims *StandardClaims, key []byte bufp := base64BufPool.Get().(*[]byte) defer base64BufPool.Put(bufp) - signer, err := ParseUnverifiedStandardClaims(tokenStr, claims, *bufp) + tokenBuf := base64BufPool.Get().(*[]byte) + defer base64BufPool.Put(tokenBuf) + + token := *tokenBuf + // Copy token to buffer, truncate to length. + token = token[:copy(token[:base64BufferSize], tokenStr)] + + signer, err := ParseUnverifiedStandardClaims(token, claims, *bufp) if err != nil { return err } - i := strings.LastIndex(tokenStr, ".") + i := bytes.LastIndexByte(token, '.') if i < 0 { return jwtgo.ErrSignatureInvalid } - n, err := base64Decode(tokenStr[i+1:], *bufp) + n, err := base64DecodeBytes(token[i+1:], *bufp) if err != nil { return err } - - hasher := hmac.New(signer.Hash.New, key) - hasher.Write([]byte(tokenStr[:i])) + borrow := signer.HashBorrower() + hasher := hmac.New(borrow.Borrow, key) + hasher.Write(token[:i]) if !hmac.Equal((*bufp)[:n], hasher.Sum(nil)) { + borrow.ReturnAll() return jwtgo.ErrSignatureInvalid } + borrow.ReturnAll() if claims.AccessKey == "" && claims.Subject == "" { return jwtgo.NewValidationError("accessKey/sub missing", @@ -245,15 +291,15 @@ type jwtHeader struct { // ever useful in cases where you know the signature is valid (because it has // been checked previously in the stack) and you want to extract values from // it. -func ParseUnverifiedStandardClaims(tokenString string, claims *StandardClaims, buf []byte) (*SigningMethodHMAC, error) { - if strings.Count(tokenString, ".") != 2 { +func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []byte) (*SigningMethodHMAC, error) { + if bytes.Count(token, []byte(".")) != 2 { return nil, jwtgo.ErrSignatureInvalid } - i := strings.Index(tokenString, ".") - j := strings.LastIndex(tokenString, ".") + i := bytes.IndexByte(token, '.') + j := bytes.LastIndexByte(token, '.') - n, err := base64Decode(tokenString[:i], buf) + n, err := base64DecodeBytes(token[:i], buf) if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } @@ -264,7 +310,7 @@ func ParseUnverifiedStandardClaims(tokenString string, claims *StandardClaims, b return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } - n, err = base64Decode(tokenString[i+1:j], buf) + n, err = base64DecodeBytes(token[i+1:j], buf) if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } @@ -294,17 +340,24 @@ func ParseWithClaims(tokenStr string, claims *MapClaims, fn func(*MapClaims) ([] bufp := base64BufPool.Get().(*[]byte) defer base64BufPool.Put(bufp) - signer, err := ParseUnverifiedMapClaims(tokenStr, claims, *bufp) + tokenBuf := base64BufPool.Get().(*[]byte) + defer base64BufPool.Put(tokenBuf) + + token := *tokenBuf + // Copy token to buffer, truncate to length. + token = token[:copy(token[:base64BufferSize], tokenStr)] + + signer, err := ParseUnverifiedMapClaims(token, claims, *bufp) if err != nil { return err } - i := strings.LastIndex(tokenStr, ".") + i := bytes.LastIndexByte(token, '.') if i < 0 { return jwtgo.ErrSignatureInvalid } - n, err := base64Decode(tokenStr[i+1:], *bufp) + n, err := base64DecodeBytes(token[i+1:], *bufp) if err != nil { return err } @@ -325,21 +378,23 @@ func ParseWithClaims(tokenStr string, claims *MapClaims, fn func(*MapClaims) ([] if err != nil { return err } - - hasher := hmac.New(signer.Hash.New, key) + borrow := signer.HashBorrower() + hasher := hmac.New(borrow.Borrow, key) hasher.Write([]byte(tokenStr[:i])) if !hmac.Equal((*bufp)[:n], hasher.Sum(nil)) { + borrow.ReturnAll() return jwtgo.ErrSignatureInvalid } + borrow.ReturnAll() // Signature is valid, lets validate the claims for // other fields such as expiry etc. return claims.Valid() } -// base64Decode returns the bytes represented by the base64 string s. -func base64Decode(s string, buf []byte) (int, error) { - return base64.RawURLEncoding.Decode(buf, []byte(s)) +// base64DecodeBytes returns the bytes represented by the base64 string s. +func base64DecodeBytes(b []byte, buf []byte) (int, error) { + return base64.RawURLEncoding.Decode(buf, b) } // ParseUnverifiedMapClaims - WARNING: Don't use this method unless you know what you're doing @@ -348,15 +403,15 @@ func base64Decode(s string, buf []byte) (int, error) { // ever useful in cases where you know the signature is valid (because it has // been checked previously in the stack) and you want to extract values from // it. -func ParseUnverifiedMapClaims(tokenString string, claims *MapClaims, buf []byte) (*SigningMethodHMAC, error) { - if strings.Count(tokenString, ".") != 2 { +func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*SigningMethodHMAC, error) { + if bytes.Count(token, []byte(".")) != 2 { return nil, jwtgo.ErrSignatureInvalid } - i := strings.Index(tokenString, ".") - j := strings.LastIndex(tokenString, ".") + i := bytes.IndexByte(token, '.') + j := bytes.LastIndexByte(token, '.') - n, err := base64Decode(tokenString[:i], buf) + n, err := base64DecodeBytes(token[:i], buf) if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } @@ -367,7 +422,7 @@ func ParseUnverifiedMapClaims(tokenString string, claims *MapClaims, buf []byte) return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } - n, err = base64Decode(tokenString[i+1:j], buf) + n, err = base64DecodeBytes(token[i+1:j], buf) if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} }