support implicit flow in web-identity.go example (#12600)

when a client secret is not provided,
automatically assume implicit flow
for authentication and invoke
relevant code accordingly.
This commit is contained in:
Harshavardhana 2021-06-30 07:43:04 -07:00 committed by GitHub
parent 4575291f8a
commit 3137dc2eb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,6 +20,7 @@
package main package main
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
@ -108,9 +109,34 @@ func init() {
flag.IntVar(&port, "port", 8080, "Port") flag.IntVar(&port, "port", 8080, "Port")
} }
func implicitFlowURL(c oauth2.Config, state string) string {
var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL)
v := url.Values{
"response_type": {"id_token"},
"response_mode": {"form_post"},
"client_id": {c.ClientID},
}
if c.RedirectURL != "" {
v.Set("redirect_uri", c.RedirectURL)
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
v.Set("state", state)
v.Set("nonce", state)
if strings.Contains(c.Endpoint.AuthURL, "?") {
buf.WriteByte('&')
} else {
buf.WriteByte('?')
}
buf.WriteString(v.Encode())
return buf.String()
}
func main() { func main() {
flag.Parse() flag.Parse()
if clientID == "" || clientSec == "" { if clientID == "" {
flag.PrintDefaults() flag.PrintDefaults()
return return
} }
@ -148,29 +174,47 @@ func main() {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound) if clientSec != "" {
http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound)
} else {
http.Redirect(w, r, implicitFlowURL(config, state), http.StatusFound)
}
}) })
http.HandleFunc("/oauth2/callback", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/oauth2/callback", func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.RequestURI) log.Printf("%s %s", r.Method, r.RequestURI)
if r.URL.Query().Get("state") != state {
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Form.Get("state") != state {
http.Error(w, "state did not match", http.StatusBadRequest) http.Error(w, "state did not match", http.StatusBadRequest)
return return
} }
getWebTokenExpiry := func() (*credentials.WebIdentityToken, error) { var getWebTokenExpiry func() (*credentials.WebIdentityToken, error)
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code")) if clientSec == "" {
if err != nil { getWebTokenExpiry = func() (*credentials.WebIdentityToken, error) {
return nil, err return &credentials.WebIdentityToken{
} Token: r.Form.Get("id_token"),
if !oauth2Token.Valid() { }, nil
return nil, errors.New("invalid token")
} }
} else {
getWebTokenExpiry = func() (*credentials.WebIdentityToken, error) {
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
return nil, err
}
if !oauth2Token.Valid() {
return nil, errors.New("invalid token")
}
return &credentials.WebIdentityToken{ return &credentials.WebIdentityToken{
Token: oauth2Token.Extra("id_token").(string), Token: oauth2Token.Extra("id_token").(string),
Expiry: int(oauth2Token.Expiry.Sub(time.Now().UTC()).Seconds()), Expiry: int(oauth2Token.Expiry.Sub(time.Now().UTC()).Seconds()),
}, nil }, nil
}
} }
sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry) sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry)