jwt: Improve allocations (#13532)

Avoid string -> byte allocations.

```
BenchmarkParseJWTStandardClaims-32       3527152           343.2 ns/op      1489 B/op         21 allocs/op
BenchmarkParseJWTStandardClaims-32       4713199           259.2 ns/op       706 B/op         16 allocs/op

BenchmarkParseJWTMapClaims-32        2666668           448.7 ns/op      1883 B/op         32 allocs/op
BenchmarkParseJWTMapClaims-32        3120709           377.1 ns/op      1227 B/op         28 allocs/op
```
This commit is contained in:
Klaus Post 2021-10-28 17:04:48 -07:00 committed by GitHub
parent db84bb9bd3
commit 9424dca9e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,11 +23,12 @@ package jwt
// borrowed under MIT License https://github.com/golang-jwt/jwt/blob/main/LICENSE // borrowed under MIT License https://github.com/golang-jwt/jwt/blob/main/LICENSE
import ( import (
"bytes"
"crypto" "crypto"
"crypto/hmac" "crypto/hmac"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"strings" "hash"
"sync" "sync"
"time" "time"
@ -38,8 +39,9 @@ import (
// SigningMethodHMAC - Implements the HMAC-SHA family of signing methods signing methods // SigningMethodHMAC - Implements the HMAC-SHA family of signing methods signing methods
// Expects key type of []byte for both signing and validation // Expects key type of []byte for both signing and validation
type SigningMethodHMAC struct { type SigningMethodHMAC struct {
Name string Name string
Hash crypto.Hash Hash crypto.Hash
HasherPool sync.Pool
} }
// Specific instances for HS256, HS384, HS512 // Specific instances for HS256, HS384, HS512
@ -49,6 +51,8 @@ var (
SigningMethodHS512 *SigningMethodHMAC SigningMethodHS512 *SigningMethodHMAC
) )
const base64BufferSize = 8192
var ( var (
base64BufPool sync.Pool base64BufPool sync.Pool
hmacSigners []*SigningMethodHMAC hmacSigners []*SigningMethodHMAC
@ -57,16 +61,49 @@ var (
func init() { func init() {
base64BufPool = sync.Pool{ base64BufPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
buf := make([]byte, 8192) buf := make([]byte, base64BufferSize)
return &buf return &buf
}, },
} }
hmacSigners = []*SigningMethodHMAC{ hmacSigners = []*SigningMethodHMAC{
{"HS256", crypto.SHA256}, {Name: "HS256", Hash: crypto.SHA256},
{"HS384", crypto.SHA384}, {Name: "HS384", Hash: crypto.SHA384},
{"HS512", crypto.SHA512}, {Name: "HS512", Hash: crypto.SHA512},
} }
for i := range hmacSigners {
h := hmacSigners[i].Hash
hmacSigners[i].HasherPool.New = func() interface{} {
return h.New()
}
}
}
// HashBorrower allows borrowing hashes and will keep track of them.
func (s *SigningMethodHMAC) HashBorrower() HashBorrower {
return HashBorrower{pool: &s.HasherPool, borrowed: make([]hash.Hash, 0, 2)}
}
// HashBorrower keeps track of borrowed hashers and allows to return them all.
type HashBorrower struct {
pool *sync.Pool
borrowed []hash.Hash
}
// Borrow a single hasher.
func (h *HashBorrower) Borrow() hash.Hash {
hasher := h.pool.Get().(hash.Hash)
h.borrowed = append(h.borrowed, hasher)
hasher.Reset()
return hasher
}
// ReturnAll will return all borrowed hashes.
func (h *HashBorrower) ReturnAll() {
for _, hasher := range h.borrowed {
h.pool.Put(hasher)
}
h.borrowed = nil
} }
// StandardClaims are basically standard claims with "accessKey" // StandardClaims are basically standard claims with "accessKey"
@ -202,26 +239,35 @@ func ParseWithStandardClaims(tokenStr string, claims *StandardClaims, key []byte
bufp := base64BufPool.Get().(*[]byte) bufp := base64BufPool.Get().(*[]byte)
defer base64BufPool.Put(bufp) defer base64BufPool.Put(bufp)
signer, err := ParseUnverifiedStandardClaims(tokenStr, claims, *bufp) tokenBuf := base64BufPool.Get().(*[]byte)
defer base64BufPool.Put(tokenBuf)
token := *tokenBuf
// Copy token to buffer, truncate to length.
token = token[:copy(token[:base64BufferSize], tokenStr)]
signer, err := ParseUnverifiedStandardClaims(token, claims, *bufp)
if err != nil { if err != nil {
return err return err
} }
i := strings.LastIndex(tokenStr, ".") i := bytes.LastIndexByte(token, '.')
if i < 0 { if i < 0 {
return jwtgo.ErrSignatureInvalid return jwtgo.ErrSignatureInvalid
} }
n, err := base64Decode(tokenStr[i+1:], *bufp) n, err := base64DecodeBytes(token[i+1:], *bufp)
if err != nil { if err != nil {
return err return err
} }
borrow := signer.HashBorrower()
hasher := hmac.New(signer.Hash.New, key) hasher := hmac.New(borrow.Borrow, key)
hasher.Write([]byte(tokenStr[:i])) hasher.Write(token[:i])
if !hmac.Equal((*bufp)[:n], hasher.Sum(nil)) { if !hmac.Equal((*bufp)[:n], hasher.Sum(nil)) {
borrow.ReturnAll()
return jwtgo.ErrSignatureInvalid return jwtgo.ErrSignatureInvalid
} }
borrow.ReturnAll()
if claims.AccessKey == "" && claims.Subject == "" { if claims.AccessKey == "" && claims.Subject == "" {
return jwtgo.NewValidationError("accessKey/sub missing", return jwtgo.NewValidationError("accessKey/sub missing",
@ -245,15 +291,15 @@ type jwtHeader struct {
// ever useful in cases where you know the signature is valid (because it has // ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from // been checked previously in the stack) and you want to extract values from
// it. // it.
func ParseUnverifiedStandardClaims(tokenString string, claims *StandardClaims, buf []byte) (*SigningMethodHMAC, error) { func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []byte) (*SigningMethodHMAC, error) {
if strings.Count(tokenString, ".") != 2 { if bytes.Count(token, []byte(".")) != 2 {
return nil, jwtgo.ErrSignatureInvalid return nil, jwtgo.ErrSignatureInvalid
} }
i := strings.Index(tokenString, ".") i := bytes.IndexByte(token, '.')
j := strings.LastIndex(tokenString, ".") j := bytes.LastIndexByte(token, '.')
n, err := base64Decode(tokenString[:i], buf) n, err := base64DecodeBytes(token[:i], buf)
if err != nil { if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
} }
@ -264,7 +310,7 @@ func ParseUnverifiedStandardClaims(tokenString string, claims *StandardClaims, b
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
} }
n, err = base64Decode(tokenString[i+1:j], buf) n, err = base64DecodeBytes(token[i+1:j], buf)
if err != nil { if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
} }
@ -294,17 +340,24 @@ func ParseWithClaims(tokenStr string, claims *MapClaims, fn func(*MapClaims) ([]
bufp := base64BufPool.Get().(*[]byte) bufp := base64BufPool.Get().(*[]byte)
defer base64BufPool.Put(bufp) defer base64BufPool.Put(bufp)
signer, err := ParseUnverifiedMapClaims(tokenStr, claims, *bufp) tokenBuf := base64BufPool.Get().(*[]byte)
defer base64BufPool.Put(tokenBuf)
token := *tokenBuf
// Copy token to buffer, truncate to length.
token = token[:copy(token[:base64BufferSize], tokenStr)]
signer, err := ParseUnverifiedMapClaims(token, claims, *bufp)
if err != nil { if err != nil {
return err return err
} }
i := strings.LastIndex(tokenStr, ".") i := bytes.LastIndexByte(token, '.')
if i < 0 { if i < 0 {
return jwtgo.ErrSignatureInvalid return jwtgo.ErrSignatureInvalid
} }
n, err := base64Decode(tokenStr[i+1:], *bufp) n, err := base64DecodeBytes(token[i+1:], *bufp)
if err != nil { if err != nil {
return err return err
} }
@ -325,21 +378,23 @@ func ParseWithClaims(tokenStr string, claims *MapClaims, fn func(*MapClaims) ([]
if err != nil { if err != nil {
return err return err
} }
borrow := signer.HashBorrower()
hasher := hmac.New(signer.Hash.New, key) hasher := hmac.New(borrow.Borrow, key)
hasher.Write([]byte(tokenStr[:i])) hasher.Write([]byte(tokenStr[:i]))
if !hmac.Equal((*bufp)[:n], hasher.Sum(nil)) { if !hmac.Equal((*bufp)[:n], hasher.Sum(nil)) {
borrow.ReturnAll()
return jwtgo.ErrSignatureInvalid return jwtgo.ErrSignatureInvalid
} }
borrow.ReturnAll()
// Signature is valid, lets validate the claims for // Signature is valid, lets validate the claims for
// other fields such as expiry etc. // other fields such as expiry etc.
return claims.Valid() return claims.Valid()
} }
// base64Decode returns the bytes represented by the base64 string s. // base64DecodeBytes returns the bytes represented by the base64 string s.
func base64Decode(s string, buf []byte) (int, error) { func base64DecodeBytes(b []byte, buf []byte) (int, error) {
return base64.RawURLEncoding.Decode(buf, []byte(s)) return base64.RawURLEncoding.Decode(buf, b)
} }
// ParseUnverifiedMapClaims - WARNING: Don't use this method unless you know what you're doing // ParseUnverifiedMapClaims - WARNING: Don't use this method unless you know what you're doing
@ -348,15 +403,15 @@ func base64Decode(s string, buf []byte) (int, error) {
// ever useful in cases where you know the signature is valid (because it has // ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from // been checked previously in the stack) and you want to extract values from
// it. // it.
func ParseUnverifiedMapClaims(tokenString string, claims *MapClaims, buf []byte) (*SigningMethodHMAC, error) { func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*SigningMethodHMAC, error) {
if strings.Count(tokenString, ".") != 2 { if bytes.Count(token, []byte(".")) != 2 {
return nil, jwtgo.ErrSignatureInvalid return nil, jwtgo.ErrSignatureInvalid
} }
i := strings.Index(tokenString, ".") i := bytes.IndexByte(token, '.')
j := strings.LastIndex(tokenString, ".") j := bytes.LastIndexByte(token, '.')
n, err := base64Decode(tokenString[:i], buf) n, err := base64DecodeBytes(token[:i], buf)
if err != nil { if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
} }
@ -367,7 +422,7 @@ func ParseUnverifiedMapClaims(tokenString string, claims *MapClaims, buf []byte)
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
} }
n, err = base64Decode(tokenString[i+1:j], buf) n, err = base64DecodeBytes(token[i+1:j], buf)
if err != nil { if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
} }