jwt: Parse standard claims faster (#13821)

* Use structless/allocationless decoding for header (note "typ" isn't used)
* Create custom unmarshal code using jsonparser for StandardClaims.

Before/After:

```
BenchmarkParseJWTStandardClaims-32    	 4270724	       294.0 ns/op	     706 B/op	      16 allocs/op
BenchmarkParseJWTStandardClaims-32    	 5634847	       214.7 ns/op	     544 B/op	       9 allocs/op

BenchmarkParseJWTMapClaims-32    	 2763045	       428.6 ns/op	    1251 B/op	      29 allocs/op
BenchmarkParseJWTMapClaims-32    	 2839455	       410.9 ns/op	    1219 B/op	      26 allocs/op
```
This commit is contained in:
Klaus Post 2021-12-03 13:19:38 -08:00 committed by GitHub
parent 4f35054d29
commit f56cac6381
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 17 deletions

1
go.mod
View File

@ -104,6 +104,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/bits-and-blooms/bitset v1.2.0 // indirect
github.com/briandowns/spinner v1.16.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/coreos/go-semver v0.3.0 // indirect
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect

2
go.sum
View File

@ -221,6 +221,8 @@ github.com/bombsimon/wsl/v3 v3.0.0/go.mod h1:st10JtZYLE4D5sC7b8xV4zTKZwAQjCH/Hy2
github.com/bombsimon/wsl/v3 v3.1.0/go.mod h1:st10JtZYLE4D5sC7b8xV4zTKZwAQjCH/Hy2Pm1FNZIc=
github.com/briandowns/spinner v1.16.0 h1:DFmp6hEaIx2QXXuqSJmtfSBSAjRmpGiKG6ip2Wm/yOs=
github.com/briandowns/spinner v1.16.0/go.mod h1:QOuQk7x+EaDASo80FEXwlwiA+j/PPIcX3FScO+3/ZPQ=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/caarlos0/ctrlc v1.0.0/go.mod h1:CdXpj4rmq0q/1Eb44M9zi2nKB0QraNKuRGYGrrHhcQw=
github.com/campoy/unique v0.0.0-20180121183637-88950e537e7e/go.mod h1:9IOqJGCPMSc6E5ydlp5NIonxObaeu/Iub/X03EKPVYo=
github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ=

View File

@ -27,11 +27,13 @@ import (
"crypto"
"crypto/hmac"
"encoding/base64"
"errors"
"fmt"
"hash"
"sync"
"time"
"github.com/buger/jsonparser"
jwtgo "github.com/golang-jwt/jwt/v4"
jsoniter "github.com/json-iterator/go"
)
@ -112,6 +114,74 @@ type StandardClaims struct {
jwtgo.StandardClaims
}
// UnmarshalJSON provides custom JSON unmarshal.
// This is mainly implemented for speed.
func (c *StandardClaims) UnmarshalJSON(b []byte) (err error) {
return jsonparser.ObjectEach(b, func(key []byte, value []byte, dataType jsonparser.ValueType, _ int) error {
if len(key) == 0 {
return nil
}
switch key[0] {
case 'a':
if string(key) == "accessKey" {
if dataType != jsonparser.String {
return errors.New("accessKey: Expected string")
}
c.AccessKey, err = jsonparser.ParseString(value)
return err
}
if string(key) == "aud" {
if dataType != jsonparser.String {
return errors.New("aud: Expected string")
}
c.Audience, err = jsonparser.ParseString(value)
return err
}
case 'e':
if string(key) == "exp" {
if dataType != jsonparser.Number {
return errors.New("exp: Expected number")
}
c.ExpiresAt, err = jsonparser.ParseInt(value)
return err
}
case 'i':
if string(key) == "iat" {
if dataType != jsonparser.Number {
return errors.New("exp: Expected number")
}
c.IssuedAt, err = jsonparser.ParseInt(value)
return err
}
if string(key) == "iss" {
if dataType != jsonparser.String {
return errors.New("iss: Expected string")
}
c.Issuer, err = jsonparser.ParseString(value)
return err
}
case 'n':
if string(key) == "nbf" {
if dataType != jsonparser.Number {
return errors.New("nbf: Expected number")
}
c.NotBefore, err = jsonparser.ParseInt(value)
return err
}
case 's':
if string(key) == "sub" {
if dataType != jsonparser.String {
return errors.New("sub: Expected string")
}
c.Subject, err = jsonparser.ParseString(value)
return err
}
}
// Ignore unknown fields
return nil
})
}
// MapClaims - implements custom unmarshaller
type MapClaims struct {
AccessKey string `json:"accessKey,omitempty"`
@ -279,12 +349,6 @@ func ParseWithStandardClaims(tokenStr string, claims *StandardClaims, key []byte
return claims.Valid()
}
// https://tools.ietf.org/html/rfc7519#page-11
type jwtHeader struct {
Algorithm string `json:"alg"`
Type string `json:"typ"`
}
// ParseUnverifiedStandardClaims - WARNING: Don't use this method unless you know what you're doing
//
// This method parses the token but doesn't validate the signature. It's only
@ -303,10 +367,11 @@ func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []b
if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
headerDec := buf[:n]
buf = buf[n:]
var header = jwtHeader{}
var json = jsoniter.ConfigCompatibleWithStandardLibrary
if err = json.Unmarshal(buf[:n], &header); err != nil {
alg, _, _, err := jsonparser.Get(headerDec, "alg")
if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
@ -315,17 +380,17 @@ func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []b
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
if err = json.Unmarshal(buf[:n], claims); err != nil {
if err = claims.UnmarshalJSON(buf[:n]); err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
for _, signer := range hmacSigners {
if header.Algorithm == signer.Name {
if string(alg) == signer.Name {
return signer, nil
}
}
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", header.Algorithm),
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", string(alg)),
jwtgo.ValidationErrorUnverifiable)
}
@ -416,9 +481,10 @@ func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*Sig
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
var header = jwtHeader{}
var json = jsoniter.ConfigCompatibleWithStandardLibrary
if err = json.Unmarshal(buf[:n], &header); err != nil {
headerDec := buf[:n]
buf = buf[n:]
alg, _, _, err := jsonparser.Get(headerDec, "alg")
if err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
@ -427,16 +493,17 @@ func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*Sig
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
var json = jsoniter.ConfigCompatibleWithStandardLibrary
if err = json.Unmarshal(buf[:n], &claims.MapClaims); err != nil {
return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed}
}
for _, signer := range hmacSigners {
if header.Algorithm == signer.Name {
if string(alg) == signer.Name {
return signer, nil
}
}
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", header.Algorithm),
return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", string(alg)),
jwtgo.ValidationErrorUnverifiable)
}