simplify JWKS decoding in OpenID and more tests (#10119)

add tests for non-compliant Azure AD behavior
with "nonce" to fail properly and treat it as
expected behavior for non-standard JWT tokens.
This commit is contained in:
Harshavardhana
2020-07-25 08:42:41 -07:00
committed by GitHub
parent 416ec316bd
commit abbf6ce6cc
4 changed files with 94 additions and 52 deletions

View File

@@ -22,11 +22,9 @@ import (
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"math/big"
"strings"
)
// JWKS - https://tools.ietf.org/html/rfc7517
@@ -47,15 +45,6 @@ type JWKS struct {
K string `json:"k,omitempty"`
}
func safeDecode(str string) ([]byte, error) {
lenMod4 := len(str) % 4
if lenMod4 > 0 {
str = str + strings.Repeat("=", 4-lenMod4)
}
return base64.URLEncoding.DecodeString(str)
}
var (
errMalformedJWKRSAKey = errors.New("malformed JWK RSA key")
errMalformedJWKECKey = errors.New("malformed JWK EC key")
@@ -70,29 +59,24 @@ func (key *JWKS) DecodePublicKey() (crypto.PublicKey, error) {
}
// decode exponent
data, err := safeDecode(key.E)
ebuf, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
return nil, errMalformedJWKRSAKey
}
if len(data) < 4 {
ndata := make([]byte, 4)
copy(ndata[4-len(data):], data)
data = ndata
}
pubKey := &rsa.PublicKey{
N: &big.Int{},
E: int(binary.BigEndian.Uint32(data[:])),
}
data, err = safeDecode(key.N)
nbuf, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
return nil, errMalformedJWKRSAKey
}
pubKey.N.SetBytes(data)
return pubKey, nil
var n, e big.Int
n.SetBytes(nbuf)
e.SetBytes(ebuf)
return &rsa.PublicKey{
E: int(e.Int64()),
N: &n,
}, nil
case "EC":
if key.Crv == "" || key.X == "" || key.Y == "" {
return nil, errMalformedJWKECKey
@@ -112,25 +96,25 @@ func (key *JWKS) DecodePublicKey() (crypto.PublicKey, error) {
return nil, fmt.Errorf("Unknown curve type: %s", key.Crv)
}
pubKey := &ecdsa.PublicKey{
xbuf, err := base64.RawURLEncoding.DecodeString(key.X)
if err != nil {
return nil, errMalformedJWKECKey
}
ybuf, err := base64.RawURLEncoding.DecodeString(key.Y)
if err != nil {
return nil, errMalformedJWKECKey
}
var x, y big.Int
x.SetBytes(xbuf)
y.SetBytes(ybuf)
return &ecdsa.PublicKey{
Curve: curve,
X: &big.Int{},
Y: &big.Int{},
}
data, err := safeDecode(key.X)
if err != nil {
return nil, errMalformedJWKECKey
}
pubKey.X.SetBytes(data)
data, err = safeDecode(key.Y)
if err != nil {
return nil, errMalformedJWKECKey
}
pubKey.Y.SetBytes(data)
return pubKey, nil
X: &x,
Y: &y,
}, nil
default:
return nil, fmt.Errorf("Unknown JWK key type %s", key.Kty)
}