diff --git a/cmd/crypto/kes.go b/cmd/crypto/kes.go index 9ac01b3bf..0d6323173 100644 --- a/cmd/crypto/kes.go +++ b/cmd/crypto/kes.go @@ -16,6 +16,7 @@ package crypto import ( "bytes" + "context" "crypto/tls" "crypto/x509" "errors" @@ -115,15 +116,29 @@ func NewKes(cfg KesConfig) (KMS, error) { if err != nil { return nil, err } - certPool, err := loadCACertificates(cfg.CAPath) - if err != nil { - return nil, err + if cfg.Transport.TLSClientConfig != nil { + if err = loadCACertificates(cfg.CAPath, + cfg.Transport.TLSClientConfig.RootCAs); err != nil { + return nil, err + } + } else { + rootCAs, _ := x509.SystemCertPool() + if rootCAs == nil { + // In some systems (like Windows) system cert pool is + // not supported or no certificates are present on the + // system - so we create a new cert pool. + rootCAs = x509.NewCertPool() + } + if err = loadCACertificates(cfg.CAPath, rootCAs); err != nil { + return nil, err + } + cfg.Transport.TLSClientConfig = &tls.Config{ + RootCAs: rootCAs, + } } - cfg.Transport.TLSClientConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: certPool, - } - cfg.Transport.ForceAttemptHTTP2 = true + cfg.Transport.TLSClientConfig.Certificates = []tls.Certificate{cert} + cfg.Transport.TLSClientConfig.NextProtos = []string{"h2"} + return &kesService{ client: &kesClient{ addr: cfg.Endpoint, @@ -359,7 +374,16 @@ func parseErrorResponse(resp *http.Response) error { } func (c *kesClient) post(url string, body io.Reader, limit int64) (io.Reader, error) { - resp, err := c.httpClient.Post(url, "application/json", body) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) if err != nil { return nil, err } @@ -388,7 +412,10 @@ func (c *kesClient) postRetry(url string, body io.ReadSeeker, limit int64) (io.R return response, nil } - if !xnet.IsNetworkOrHostDown(err) && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + if !xnet.IsNetworkOrHostDown(err) && + !errors.Is(err, io.EOF) && + !errors.Is(err, io.ErrUnexpectedEOF) && + !errors.Is(err, context.DeadlineExceeded) { return nil, err } @@ -415,24 +442,17 @@ func (c *kesClient) postRetry(url string, body io.ReadSeeker, limit int64) (io.R // file as PEM-encoded certificate and add it to // the CertPool. If a file is not a PEM certificate // it will be ignored. -func loadCACertificates(path string) (*x509.CertPool, error) { - rootCAs, _ := x509.SystemCertPool() - if rootCAs == nil { - // In some systems (like Windows) system cert pool is - // not supported or no certificates are present on the - // system - so we create a new cert pool. - rootCAs = x509.NewCertPool() - } +func loadCACertificates(path string, rootCAs *x509.CertPool) error { if path == "" { - return rootCAs, nil + return nil } stat, err := os.Stat(path) if err != nil { if os.IsNotExist(err) || os.IsPermission(err) { - return rootCAs, nil + return nil } - return nil, Errorf("crypto: cannot open '%s': %v", path, err) + return Errorf("crypto: cannot open '%s': %v", path, err) } // If path is a file, parse as PEM-encoded certifcate @@ -441,12 +461,12 @@ func loadCACertificates(path string) (*x509.CertPool, error) { if !stat.IsDir() { cert, err := ioutil.ReadFile(path) if err != nil { - return nil, err + return err } if !rootCAs.AppendCertsFromPEM(cert) { - return nil, Errorf("crypto: '%s' is not a valid PEM-encoded certificate", path) + return Errorf("crypto: '%s' is not a valid PEM-encoded certificate", path) } - return rootCAs, nil + return nil } // If path is a directory then try @@ -456,7 +476,7 @@ func loadCACertificates(path string) (*x509.CertPool, error) { // we ignore it. files, err := ioutil.ReadDir(path) if err != nil { - return nil, err + return err } for _, file := range files { cert, err := ioutil.ReadFile(filepath.Join(path, file.Name())) @@ -465,6 +485,6 @@ func loadCACertificates(path string) (*x509.CertPool, error) { } rootCAs.AppendCertsFromPEM(cert) // ignore files which are not PEM certtificates } - return rootCAs, nil + return nil }