// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package crypto

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	jsoniter "github.com/json-iterator/go"
	xhttp "github.com/minio/minio/cmd/http"
	"github.com/minio/minio/pkg/kms"
	xnet "github.com/minio/minio/pkg/net"
)

// json is jsoniter's standard-library-compatible configuration. The code
// in this file uses it for json.Marshal and json.NewDecoder.
var json = jsoniter.ConfigCompatibleWithStandardLibrary

// ErrKESKeyExists is the error returned by a KES server
// when the master key already exists.
var ErrKESKeyExists = NewKESError(http.StatusBadRequest, "key does already exist")

// KesConfig contains the configuration required
// to initialize and connect to a kes server.
type KesConfig struct {
	// Enabled is not read by any code in this file.
	// NOTE(review): presumably set by the caller to indicate
	// whether KES is configured - confirm against callers.
	Enabled bool

	// The KES server endpoints. Verify requires at
	// least one. The API path of a request is appended
	// directly to an endpoint to form the request URL.
	Endpoint []string

	// The path to the TLS private key used
	// by MinIO to authenticate to the kes
	// server during the TLS handshake (mTLS).
	KeyFile string

	// The path to the TLS certificate used
	// by MinIO to authenticate to the kes
	// server during the TLS handshake (mTLS).
	//
	// The kes server will also allow or deny
	// access based on this certificate.
	// In particular, the kes server will
	// lookup the policy that corresponds to
	// the identity in this certificate.
	CertFile string

	// Path to a file or directory containing
	// the CA certificate(s) that issued / will
	// issue certificates for the kes server.
	//
	// This is required if the TLS certificate
	// of the kes server has not been issued
	// (e.g. b/c it's self-signed) by a CA that
	// MinIO trusts.
	CAPath string

	// The default key ID. GenerateKey uses it when
	// no key ID is given and Stat reports it as the
	// default key.
	DefaultKeyID string

	// The HTTP transport configuration for
	// the KES client. It must not be nil.
	//
	// NewKes configures its TLSClientConfig (root
	// CAs, client certificate and ALPN protocols)
	// in place.
	Transport *http.Transport
}

// Verify checks that the mandatory kes settings are present.
//
// The settings are checked in the order endpoint, certificate file,
// key file and default key ID. The error describes the first one
// that is missing. Verify returns nil if none of them is missing.
func (k KesConfig) Verify() (err error) {
	if len(k.Endpoint) == 0 {
		return Errorf("crypto: missing kes endpoint")
	}
	if k.CertFile == "" {
		return Errorf("crypto: missing cert file")
	}
	if k.KeyFile == "" {
		return Errorf("crypto: missing key file")
	}
	if k.DefaultKeyID == "" {
		return Errorf("crypto: missing default key id")
	}
	return nil
}

// kesService is the KMS implementation returned by NewKes.
// It forwards key operations to a kesClient.
type kesService struct {
	// client performs the HTTP requests to the KES server(s).
	client *kesClient

	// endpoints and defaultKeyID are reported by Stat. defaultKeyID
	// is also used by GenerateKey when no key ID is given.
	endpoints    []string
	defaultKeyID string
}

// NewKes returns a new kes KMS client. The returned KMS
// uses the X.509 certificate / private key pair referenced
// by cfg.CertFile and cfg.KeyFile to authenticate itself to
// the kes server(s) listed in cfg.Endpoint.
//
// The CA certificate(s) found at cfg.CAPath are added to the
// root CAs already present in the TLS configuration of
// cfg.Transport. If it has none, they are added to the system
// root CAs instead.
//
// cfg.Transport must not be nil. Its TLSClientConfig is modified
// in place (root CAs, client certificate and ALPN protocols).
//
// cfg.DefaultKeyID is the key ID used when GenerateKey is called
// without a key ID.
//
// It returns an error if the certificate / private key pair
// cannot be loaded, if cfg.Transport is nil or if the CA
// certificate(s) cannot be loaded.
func NewKes(cfg KesConfig) (KMS, error) {
	cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
	if err != nil {
		return nil, err
	}
	if cfg.Transport == nil {
		// Fail gracefully instead of dereferencing a nil transport below.
		return nil, errors.New("crypto: missing kes HTTP transport")
	}

	// Work on a local TLS config and only install it on the
	// transport once loading the CA certificates succeeded.
	tlsConfig := cfg.Transport.TLSClientConfig
	if tlsConfig == nil {
		tlsConfig = &tls.Config{}
	}

	rootCAs := tlsConfig.RootCAs
	if rootCAs == nil {
		// The error is ignored on purpose: a nil pool is handled below.
		rootCAs, _ = x509.SystemCertPool()
		if rootCAs == nil {
			// In some systems (like Windows) system cert pool is
			// not supported or no certificates are present on the
			// system - so we create a new cert pool.
			rootCAs = x509.NewCertPool()
		}
	}
	if err = loadCACertificates(cfg.CAPath, rootCAs); err != nil {
		return nil, err
	}

	tlsConfig.RootCAs = rootCAs
	tlsConfig.Certificates = []tls.Certificate{cert}
	tlsConfig.NextProtos = []string{"h2"}
	cfg.Transport.TLSClientConfig = tlsConfig

	return &kesService{
		client: &kesClient{
			endpoints: cfg.Endpoint,
			httpClient: http.Client{
				Transport: cfg.Transport,
			},
		},
		endpoints:    cfg.Endpoint,
		defaultKeyID: cfg.DefaultKeyID,
	}, nil
}

// Stat returns the name, the endpoints and the default key ID
// of this KMS. It only reports the configured values and does
// not send a request to the kes server, so the returned error
// is always nil.
func (kes *kesService) Stat() (kms.Status, error) {
	status := kms.Status{Name: "KES"}
	status.Endpoints = kes.endpoints
	status.DefaultKey = kes.defaultKeyID
	return status, nil
}

// CreateKey asks the kes server to create a new master key
// named keyID and returns the error of that request, if any.
func (kes *kesService) CreateKey(keyID string) error {
	return kes.client.CreateKey(keyID)
}

// GenerateKey requests a new data encryption key from the KMS.
// The returned DEK holds the plaintext key and its ciphertext,
// i.e. the plaintext key encrypted with the named key referenced
// by keyID. The generated key is cryptographically bound to the
// provided context.
//
// If keyID is empty, the default key ID is used. The returned
// DEK carries the key ID that has actually been used.
func (kes *kesService) GenerateKey(keyID string, ctx Context) (kms.DEK, error) {
	if len(keyID) == 0 {
		keyID = kes.defaultKeyID
	}

	associatedData, err := ctx.MarshalText()
	if err != nil {
		return kms.DEK{}, err
	}
	plainKey, sealedKey, err := kes.client.GenerateDataKey(keyID, associatedData)
	if err != nil {
		return kms.DEK{}, err
	}

	var dek kms.DEK
	dek.KeyID = keyID
	dek.Plaintext = plainKey
	dek.Ciphertext = sealedKey
	return dek, nil
}

// DecryptKey returns the plaintext key for the given ciphertext.
// Therefore, it sends the ciphertext to the KMS which decrypts
// it using the named key referenced by keyID and responds with
// the plaintext key. The keyID is forwarded as-is.
//
// The context must be the same context as the one provided while
// generating the plaintext key / ciphertext.
func (kes *kesService) DecryptKey(keyID string, ciphertext []byte, ctx Context) ([]byte, error) {
	associatedData, err := ctx.MarshalText()
	if err != nil {
		return nil, err
	}
	plaintext, err := kes.client.DecryptDataKey(keyID, ciphertext, associatedData)
	return plaintext, err
}

// kesClient implements the bare minimum functionality needed for
// MinIO to talk to a KES server. In particular, it implements
//   • CreateKey       (API: /v1/key/create/)
//   • GenerateDataKey (API: /v1/key/generate/)
//   • DecryptDataKey  (API: /v1/key/decrypt/)
type kesClient struct {
	// endpoints are the KES server addresses. The API path of a
	// request is appended directly to one of them. When a request
	// is retried, the next endpoint is used, wrapping around at
	// the end of the list.
	endpoints []string
	// httpClient is used to send all requests.
	httpClient http.Client
}

// CreateKey asks the KES server to create a new
// cryptographic key with the specified name.
//
// The key is generated by the server. The client
// application never gets access to the cryptographic
// key itself.
func (c *kesClient) CreateKey(name string) error {
	endpointPath := fmt.Sprintf("/v1/key/create/%s", url.PathEscape(name))

	// The request has no body and the response content is not needed.
	_, err := c.postRetry(endpointPath, nil, 0)
	return err
}

// GenerateDataKey requests a new data key from the KES server.
// On success, it returns the plaintext key followed by the
// ciphertext key, i.e. the plaintext key encrypted with the
// key specified by name.
//
// The optional context is cryptographically bound to the
// generated data key such that the same context has to be
// provided when decrypting the data key.
func (c *kesClient) GenerateDataKey(name string, context []byte) ([]byte, []byte, error) {
	type generateRequest struct {
		Context []byte `json:"context"`
	}
	type generateResponse struct {
		Plaintext  []byte `json:"plaintext"`
		Ciphertext []byte `json:"ciphertext"`
	}

	// A plaintext/ciphertext key pair will never be larger than 1 MiB.
	const maxResponseSize = 1 << 20

	payload, err := json.Marshal(generateRequest{Context: context})
	if err != nil {
		return nil, nil, err
	}

	endpointPath := fmt.Sprintf("/v1/key/generate/%s", url.PathEscape(name))
	reader, err := c.postRetry(endpointPath, bytes.NewReader(payload), maxResponseSize)
	if err != nil {
		return nil, nil, err
	}

	var dataKey generateResponse
	if err = json.NewDecoder(reader).Decode(&dataKey); err != nil {
		return nil, nil, err
	}
	return dataKey.Plaintext, dataKey.Ciphertext, nil
}

// DecryptDataKey decrypts an encrypted data key with the key
// specified by name by talking to the KES server.
// On success, it returns the plaintext key.
//
// The optional context must match the value that has been
// provided when generating the data key. An empty context
// is omitted from the request.
func (c *kesClient) DecryptDataKey(name string, ciphertext, context []byte) ([]byte, error) {
	type decryptRequest struct {
		Ciphertext []byte `json:"ciphertext"`
		Context    []byte `json:"context,omitempty"`
	}
	type decryptResponse struct {
		Plaintext []byte `json:"plaintext"`
	}

	// A data key will never be larger than 1 MiB.
	const maxResponseSize = 1 << 20

	payload, err := json.Marshal(decryptRequest{
		Ciphertext: ciphertext,
		Context:    context,
	})
	if err != nil {
		return nil, err
	}

	endpointPath := fmt.Sprintf("/v1/key/decrypt/%s", url.PathEscape(name))
	reader, err := c.postRetry(endpointPath, bytes.NewReader(payload), maxResponseSize)
	if err != nil {
		return nil, err
	}

	var dataKey decryptResponse
	if err = json.NewDecoder(reader).Decode(&dataKey); err != nil {
		return nil, err
	}
	return dataKey.Plaintext, nil
}

// NewKESError returns a new KES API error with the given
// HTTP status code and error message.
//
// The returned error is a comparable value. Two errors with
// the same status code and error message are equal:
//   e1 == e2 // true.
func NewKESError(code int, text string) error {
	apiErr := kesError{message: text, code: code}
	return apiErr
}

// kesError is a KES API error. It consists of an HTTP
// status code and an error message. It is used as a value,
// not as a pointer, such that errors can be compared with ==.
type kesError struct {
	code    int
	message string
}

// Status returns the HTTP status code of the error.
func (e kesError) Status() int {
	return e.code
}

// Error returns the error message of the error.
func (e kesError) Error() string {
	return e.message
}

// parseErrorResponse converts an HTTP error response into an error.
//
// It returns nil if resp is nil or if its status code is below 400.
// Otherwise, it closes the response body and returns a KES error
// with the status code of the response. The error message is the
// "message" field of the body if the content type starts with
// "application/json" and the raw body otherwise. It is empty if the
// response has no body.
//
// At most 1 MiB of the body is read. If reading or decoding the
// body fails, that error is returned instead of a KES error.
func parseErrorResponse(resp *http.Response) error {
	switch {
	case resp == nil, resp.StatusCode < 400:
		return nil
	case resp.Body == nil:
		return NewKESError(resp.StatusCode, "")
	}
	defer resp.Body.Close()

	// Read as much as the server announced but never more than 1 MiB.
	const maxBodySize = 1 << 20
	limit := resp.ContentLength
	if limit < 0 || limit > maxBodySize {
		limit = maxBodySize
	}
	body := io.LimitReader(resp.Body, limit)

	mediaType := strings.TrimSpace(resp.Header.Get("Content-Type"))
	if !strings.HasPrefix(mediaType, "application/json") {
		var text strings.Builder
		if _, err := io.Copy(&text, body); err != nil {
			return err
		}
		return NewKESError(resp.StatusCode, text.String())
	}

	type errorBody struct {
		Message string `json:"message"`
	}
	var apiErr errorBody
	if err := json.NewDecoder(body).Decode(&apiErr); err != nil {
		return err
	}
	return NewKESError(resp.StatusCode, apiErr.Message)
}

// post sends a single HTTP POST request with the given body and
// the content type "application/json" to url. The request is
// aborted after 3 seconds.
//
// On success, i.e. if the server responds with 200 OK, it returns
// a reader over a copy of the response body. At most limit bytes
// of the response body are copied.
//
// If the server responds with any other status code, post returns
// an error and a nil reader. For status codes >= 400 this is the
// error produced by parseErrorResponse. The returned reader is
// never nil when the error is nil.
func (c *kesClient) post(url string, body io.Reader, limit int64) (io.Reader, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Drain the entire body to make sure we re-use connections.
	defer xhttp.DrainBody(resp.Body)

	if resp.StatusCode != http.StatusOK {
		if err = parseErrorResponse(resp); err != nil {
			return nil, err
		}
		// parseErrorResponse only reports status codes >= 400. Any other
		// non-200 status must still be an error. Otherwise, the caller
		// would receive a nil reader together with a nil error.
		return nil, fmt.Errorf("crypto: unexpected response status from kes server: %s", resp.Status)
	}

	// We have to copy the response body due to draining.
	var respBody bytes.Buffer
	if _, err = io.Copy(&respBody, io.LimitReader(resp.Body, limit)); err != nil {
		return nil, err
	}
	return &respBody, nil
}

// postRetry sends a POST request with the given body to path
// on one of the client's endpoints and retries the request if
// it fails with a temporary error.
//
// The i-th attempt is sent to the endpoint i modulo the number
// of endpoints. An attempt is retried only if it failed with a
// network / host-down error, io.EOF, io.ErrUnexpectedEOF or
// context.DeadlineExceeded. Any other error is returned right
// away. Retries stop once the attempt index reaches one plus
// the number of endpoints.
//
// As long as there are endpoints that have not been tried yet,
// the next attempt happens immediately. Afterwards, postRetry
// waits for a backoff duration between attempts.
//
// The body, if not nil, is rewound before each attempt. At most
// limit bytes of the response body are returned. postRetry
// returns a nil reader whenever it returns an error.
func (c *kesClient) postRetry(path string, body io.ReadSeeker, limit int64) (io.Reader, error) {
	if len(c.endpoints) == 0 {
		// Avoid a division by zero when selecting the endpoint below.
		return nil, errors.New("crypto: no kes endpoint configured")
	}

	retryMax := 1 + len(c.endpoints)
	for i := 0; ; i++ {
		if body != nil {
			// Seek to the beginning of the body such that each
			// attempt sends the entire body.
			if _, err := body.Seek(0, io.SeekStart); err != nil {
				return nil, err
			}
		}

		response, err := c.post(c.endpoints[i%len(c.endpoints)]+path, body, limit)
		if err == nil {
			return response, nil
		}

		// If the error is not temp. / retryable => fail the request immediately.
		if !xnet.IsNetworkOrHostDown(err, false) &&
			!errors.Is(err, io.EOF) &&
			!errors.Is(err, io.ErrUnexpectedEOF) &&
			!errors.Is(err, context.DeadlineExceeded) {
			return nil, err
		}
		if remain := retryMax - i; remain <= 0 { // Fail if we exceeded our retry limit.
			return nil, err
		}

		// If there are more KES instances then skip waiting and
		// try the next endpoint directly.
		if i < len(c.endpoints) {
			continue
		}
		time.Sleep(LinearJitterBackoff(retryWaitMin, retryWaitMax, i))
	}
}

// loadCACertificates adds the PEM-encoded certificate(s)
// found at path to the given CertPool.
//
// If path is empty, does not exist or is not accessible
// due to missing permissions, loadCACertificates does
// nothing and returns nil.
//
// If path is a file, loadCACertificates will
// try to parse it as PEM-encoded certificate.
// If this fails, it returns an error.
//
// If path is a directory, it tries to parse each
// file as PEM-encoded certificate and adds it to
// the CertPool. Files that cannot be read or are
// not PEM certificates are ignored.
func loadCACertificates(path string, rootCAs *x509.CertPool) error {
	if len(path) == 0 {
		return nil
	}

	info, err := os.Stat(path)
	switch {
	case err == nil:
		// The path exists - inspect it below.
	case os.IsNotExist(err), os.IsPermission(err):
		return nil
	default:
		return Errorf("crypto: cannot open '%s': %v", path, err)
	}

	// A directory: try to add every file in it to
	// the CertPool and skip whatever cannot be read
	// or is not a PEM-encoded certificate.
	if info.IsDir() {
		entries, err := ioutil.ReadDir(path)
		if err != nil {
			return err
		}
		for _, entry := range entries {
			pemData, readErr := ioutil.ReadFile(filepath.Join(path, entry.Name()))
			if readErr == nil {
				rootCAs.AppendCertsFromPEM(pemData)
			}
		}
		return nil
	}

	// A single file: it must contain a PEM-encoded
	// certificate. Otherwise, return an error.
	pemData, err := ioutil.ReadFile(path)
	if err != nil {
		return err
	}
	if ok := rootCAs.AppendCertsFromPEM(pemData); !ok {
		return Errorf("crypto: '%s' is not a valid PEM-encoded certificate", path)
	}
	return nil
}