jwt: Parse standard claims faster (#13821)

* Use structless/allocationless decoding for header (note "typ" isn't used)
* Create custom unmarshal code using jsonparser for StandardClaims.

Before/After:

```
BenchmarkParseJWTStandardClaims-32    	 4270724	       294.0 ns/op	     706 B/op	      16 allocs/op
BenchmarkParseJWTStandardClaims-32    	 5634847	       214.7 ns/op	     544 B/op	       9 allocs/op

BenchmarkParseJWTMapClaims-32    	 2763045	       428.6 ns/op	    1251 B/op	      29 allocs/op
BenchmarkParseJWTMapClaims-32    	 2839455	       410.9 ns/op	    1219 B/op	      26 allocs/op
```
This commit is contained in:
Klaus Post
2021-12-03 13:19:38 -08:00
committed by GitHub
parent 4f35054d29
commit f56cac6381
3 changed files with 87 additions and 17 deletions

View File

@@ -27,11 +27,13 @@ import (
"crypto"
"crypto/hmac"
"encoding/base64"
"errors"
"fmt"
"hash"
"sync"
"time"
"github.com/buger/jsonparser"
jwtgo "github.com/golang-jwt/jwt/v4"
jsoniter "github.com/json-iterator/go"
)
@@ -112,6 +114,74 @@ type StandardClaims struct {
jwtgo.StandardClaims
}
// UnmarshalJSON provides custom JSON unmarshal.
// This is mainly implemented for speed.
func (c *StandardClaims) UnmarshalJSON(b []byte) (err error) {
return jsonparser.ObjectEach(b, func(key []byte, value []byte, dataType jsonparser.ValueType, _ int) error {
if len(key) == 0 {
return nil
}
switch key[0] {
case 'a':
if string(key) == "accessKey" {
if dataType != jsonparser.String {
return errors.New("accessKey: Expected string")
}
c.AccessKey, err = jsonparser.ParseString(value)
return err
}
if string(key) == "aud" {
if dataType != jsonparser.String {
return errors.New("aud: Expected string")
}
c.Audience, err = jsonparser.ParseString(value)
return err
}
case 'e':
if string(key) == "exp" {
if dataType != jsonparser.Number {
return errors.New("exp: Expected number")
}
c.ExpiresAt, err = jsonparser.ParseInt(value)
return err
}
case 'i':
if string(key) == "iat" {
if dataType != jsonparser.Number {
return errors.New("exp: Expected number")
}
c.IssuedAt, err = jsonparser.ParseInt(value)
return err
}
if string(key) == "iss" {
if dataType != jsonparser.String {
return errors.New("iss: Expected string")
}
c.Issuer, err = jsonparser.ParseString(value)
return err
}
case 'n':
if string(key) == "nbf" {
if dataType != jsonparser.Number {
return errors.New("nbf: Expected number")
}
c.NotBefore, err = jsonparser.ParseInt(value)
return err
}
case 's':
if string(key) == "sub" {
if dataType != jsonparser.String {
return errors.New("sub: Expected string")
}
c.Subject, err = jsonparser.ParseString(value)
return err
}
}
// Ignore unknown fields
return nil
})
}
// MapClaims - implements custom unmarshaller
type MapClaims struct {
AccessKey string `json:"accessKey,omitempty"`
@@ -279,12 +349,6 @@ func ParseWithStandardClaims(tokenStr string, claims *StandardClaims, key []byte
return claims.Valid()
}
// https://tools.ietf.org/html/rfc7519#page-11
type jwtHeader struct {
Algorithm string `json:"alg"`
Type string `json:"typ"`
}
// ParseUnverifiedStandardClaims - WARNING: Don't use this method unless you know what you're doing
//
// This method parses the token but doesn't validate the signature. It's only
@@ -303,10 +367,11 @@ func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []b
if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
headerDec := buf[:n]
buf = buf[n:]
var header = jwtHeader{}
var json = jsoniter.ConfigCompatibleWithStandardLibrary
if err = json.Unmarshal(buf[:n], &header); err != nil {
alg, _, _, err := jsonparser.Get(headerDec, "alg")
if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
@@ -315,17 +380,17 @@ func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []b
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
if err = json.Unmarshal(buf[:n], claims); err != nil {
if err = claims.UnmarshalJSON(buf[:n]); err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
for _, signer := range hmacSigners {
if header.Algorithm == signer.Name {
if string(alg) == signer.Name {
return signer, nil
}
}
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", header.Algorithm),
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", string(alg)),
jwtgo.ValidationErrorUnverifiable)
}
@@ -416,9 +481,10 @@ func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*Sig
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
var header = jwtHeader{}
var json = jsoniter.ConfigCompatibleWithStandardLibrary
if err = json.Unmarshal(buf[:n], &header); err != nil {
headerDec := buf[:n]
buf = buf[n:]
alg, _, _, err := jsonparser.Get(headerDec, "alg")
if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
@@ -427,16 +493,17 @@ func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*Sig
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
var json = jsoniter.ConfigCompatibleWithStandardLibrary
if err = json.Unmarshal(buf[:n], &claims.MapClaims); err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
for _, signer := range hmacSigners {
if header.Algorithm == signer.Name {
if string(alg) == signer.Name {
return signer, nil
}
}
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", header.Algorithm),
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", string(alg)),
jwtgo.ValidationErrorUnverifiable)
}