diff --git a/go.mod b/go.mod index 24f3917b4..9f4c75350 100644 --- a/go.mod +++ b/go.mod @@ -104,6 +104,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/briandowns/spinner v1.16.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 4f727d6d2..9f0c09350 100644 --- a/go.sum +++ b/go.sum @@ -221,6 +221,8 @@ github.com/bombsimon/wsl/v3 v3.0.0/go.mod h1:st10JtZYLE4D5sC7b8xV4zTKZwAQjCH/Hy2 github.com/bombsimon/wsl/v3 v3.1.0/go.mod h1:st10JtZYLE4D5sC7b8xV4zTKZwAQjCH/Hy2Pm1FNZIc= github.com/briandowns/spinner v1.16.0 h1:DFmp6hEaIx2QXXuqSJmtfSBSAjRmpGiKG6ip2Wm/yOs= github.com/briandowns/spinner v1.16.0/go.mod h1:QOuQk7x+EaDASo80FEXwlwiA+j/PPIcX3FScO+3/ZPQ= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/caarlos0/ctrlc v1.0.0/go.mod h1:CdXpj4rmq0q/1Eb44M9zi2nKB0QraNKuRGYGrrHhcQw= github.com/campoy/unique v0.0.0-20180121183637-88950e537e7e/go.mod h1:9IOqJGCPMSc6E5ydlp5NIonxObaeu/Iub/X03EKPVYo= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= diff --git a/internal/jwt/parser.go b/internal/jwt/parser.go index 17248eef7..30b416755 100644 --- a/internal/jwt/parser.go +++ b/internal/jwt/parser.go @@ -27,11 +27,13 @@ import ( "crypto" "crypto/hmac" "encoding/base64" + "errors" "fmt" "hash" "sync" "time" + "github.com/buger/jsonparser" jwtgo "github.com/golang-jwt/jwt/v4" jsoniter "github.com/json-iterator/go" ) @@ -112,6 +114,74 @@ type StandardClaims struct { jwtgo.StandardClaims } +// UnmarshalJSON provides custom JSON unmarshal. +// This is mainly implemented for speed. +func (c *StandardClaims) UnmarshalJSON(b []byte) (err error) { + return jsonparser.ObjectEach(b, func(key []byte, value []byte, dataType jsonparser.ValueType, _ int) error { + if len(key) == 0 { + return nil + } + switch key[0] { + case 'a': + if string(key) == "accessKey" { + if dataType != jsonparser.String { + return errors.New("accessKey: Expected string") + } + c.AccessKey, err = jsonparser.ParseString(value) + return err + } + if string(key) == "aud" { + if dataType != jsonparser.String { + return errors.New("aud: Expected string") + } + c.Audience, err = jsonparser.ParseString(value) + return err + } + case 'e': + if string(key) == "exp" { + if dataType != jsonparser.Number { + return errors.New("exp: Expected number") + } + c.ExpiresAt, err = jsonparser.ParseInt(value) + return err + } + case 'i': + if string(key) == "iat" { + if dataType != jsonparser.Number { + return errors.New("exp: Expected number") + } + c.IssuedAt, err = jsonparser.ParseInt(value) + return err + } + if string(key) == "iss" { + if dataType != jsonparser.String { + return errors.New("iss: Expected string") + } + c.Issuer, err = jsonparser.ParseString(value) + return err + } + case 'n': + if string(key) == "nbf" { + if dataType != jsonparser.Number { + return errors.New("nbf: Expected number") + } + c.NotBefore, err = jsonparser.ParseInt(value) + return err + } + case 's': + if string(key) == "sub" { + if dataType != jsonparser.String { + return errors.New("sub: Expected string") + } + c.Subject, err = jsonparser.ParseString(value) + return err + } + } + // Ignore unknown fields + return nil + }) +} + // MapClaims - implements custom unmarshaller type MapClaims struct { AccessKey string `json:"accessKey,omitempty"` @@ -279,12 +349,6 @@ func ParseWithStandardClaims(tokenStr string, claims *StandardClaims, key []byte return claims.Valid() } -// https://tools.ietf.org/html/rfc7519#page-11 -type jwtHeader struct { - Algorithm string `json:"alg"` - Type string `json:"typ"` -} - // ParseUnverifiedStandardClaims - WARNING: Don't use this method unless you know what you're doing // // This method parses the token but doesn't validate the signature. It's only @@ -303,10 +367,11 @@ func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []b if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } + headerDec := buf[:n] + buf = buf[n:] - var header = jwtHeader{} - var json = jsoniter.ConfigCompatibleWithStandardLibrary - if err = json.Unmarshal(buf[:n], &header); err != nil { + alg, _, _, err := jsonparser.Get(headerDec, "alg") + if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } @@ -315,17 +380,17 @@ func ParseUnverifiedStandardClaims(token []byte, claims *StandardClaims, buf []b return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } - if err = json.Unmarshal(buf[:n], claims); err != nil { + if err = claims.UnmarshalJSON(buf[:n]); err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } for _, signer := range hmacSigners { - if header.Algorithm == signer.Name { + if string(alg) == signer.Name { return signer, nil } } - return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", header.Algorithm), + return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", string(alg)), jwtgo.ValidationErrorUnverifiable) } @@ -416,9 +481,10 @@ func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*Sig return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } - var header = jwtHeader{} - var json = jsoniter.ConfigCompatibleWithStandardLibrary - if err = json.Unmarshal(buf[:n], &header); err != nil { + headerDec := buf[:n] + buf = buf[n:] + alg, _, _, err := jsonparser.Get(headerDec, "alg") + if err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } @@ -427,16 +493,17 @@ func ParseUnverifiedMapClaims(token []byte, claims *MapClaims, buf []byte) (*Sig return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } + var json = jsoniter.ConfigCompatibleWithStandardLibrary if err = json.Unmarshal(buf[:n], &claims.MapClaims); err != nil { return nil, &jwtgo.ValidationError{Inner: err, Errors: jwtgo.ValidationErrorMalformed} } for _, signer := range hmacSigners { - if header.Algorithm == signer.Name { + if string(alg) == signer.Name { return signer, nil } } - return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", header.Algorithm), + return nil, jwtgo.NewValidationError(fmt.Sprintf("signing method (%s) is unavailable.", string(alg)), jwtgo.ValidationErrorUnverifiable) }