diff --git a/docs/sts/web-identity.go b/docs/sts/web-identity.go index efb921517..23c16e800 100644 --- a/docs/sts/web-identity.go +++ b/docs/sts/web-identity.go @@ -20,6 +20,7 @@ package main import ( + "bytes" "context" "crypto/rand" "encoding/base64" @@ -108,9 +109,34 @@ func init() { flag.IntVar(&port, "port", 8080, "Port") } +func implicitFlowURL(c oauth2.Config, state string) string { + var buf bytes.Buffer + buf.WriteString(c.Endpoint.AuthURL) + v := url.Values{ + "response_type": {"id_token"}, + "response_mode": {"form_post"}, + "client_id": {c.ClientID}, + } + if c.RedirectURL != "" { + v.Set("redirect_uri", c.RedirectURL) + } + if len(c.Scopes) > 0 { + v.Set("scope", strings.Join(c.Scopes, " ")) + } + v.Set("state", state) + v.Set("nonce", state) + if strings.Contains(c.Endpoint.AuthURL, "?") { + buf.WriteByte('&') + } else { + buf.WriteByte('?') + } + buf.WriteString(v.Encode()) + return buf.String() +} + func main() { flag.Parse() - if clientID == "" || clientSec == "" { + if clientID == "" { flag.PrintDefaults() return } @@ -148,29 +174,47 @@ func main() { http.NotFound(w, r) return } - http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound) + if clientSec != "" { + http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound) + } else { + http.Redirect(w, r, implicitFlowURL(config, state), http.StatusFound) + } }) http.HandleFunc("/oauth2/callback", func(w http.ResponseWriter, r *http.Request) { log.Printf("%s %s", r.Method, r.RequestURI) - if r.URL.Query().Get("state") != state { + + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if r.Form.Get("state") != state { http.Error(w, "state did not match", http.StatusBadRequest) return } - getWebTokenExpiry := func() (*credentials.WebIdentityToken, error) { - oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code")) - if err != nil { - return nil, err - } - if !oauth2Token.Valid() { - return nil, errors.New("invalid token") + var getWebTokenExpiry func() (*credentials.WebIdentityToken, error) + if clientSec == "" { + getWebTokenExpiry = func() (*credentials.WebIdentityToken, error) { + return &credentials.WebIdentityToken{ + Token: r.Form.Get("id_token"), + }, nil } + } else { + getWebTokenExpiry = func() (*credentials.WebIdentityToken, error) { + oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code")) + if err != nil { + return nil, err + } + if !oauth2Token.Valid() { + return nil, errors.New("invalid token") + } - return &credentials.WebIdentityToken{ - Token: oauth2Token.Extra("id_token").(string), - Expiry: int(oauth2Token.Expiry.Sub(time.Now().UTC()).Seconds()), - }, nil + return &credentials.WebIdentityToken{ + Token: oauth2Token.Extra("id_token").(string), + Expiry: int(oauth2Token.Expiry.Sub(time.Now().UTC()).Seconds()), + }, nil + } } sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry)