// MinIO Cloud Storage, (C) 2019-2020 MinIO, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package crypto

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	jsoniter "github.com/json-iterator/go"
	xhttp "github.com/minio/minio/cmd/http"
	xnet "github.com/minio/minio/pkg/net"
)

// json is the json-iterator configuration that is compatible with the
// standard library's encoding/json API. All JSON encoding and decoding
// in this file goes through it.
var json = jsoniter.ConfigCompatibleWithStandardLibrary

// ErrKESKeyExists is the error returned by a KES server
// when a master key with the requested name already exists.
//
// It is created with NewKESError, so it compares equal (==) to
// any KES error that carries the same status code and message.
var ErrKESKeyExists = NewKESError(http.StatusBadRequest, "key does already exist")

// KesConfig contains the configuration required
// to initialize and connect to a kes server.
//
// Verify reports whether the mandatory fields (Endpoint,
// CertFile, KeyFile and DefaultKeyID) are set.
type KesConfig struct {
	// Enabled is not read anywhere in this file.
	// NOTE(review): presumably it tells callers whether
	// a kes server has been configured at all - confirm
	// against the code that builds this struct.
	Enabled bool

	// The KES server endpoints.
	//
	// The client builds a request URL by appending the API
	// path (e.g. /v1/key/create/<name>) verbatim to one of
	// these entries. If a request fails with a retryable
	// error, the next endpoint in the list is tried.
	Endpoint []string

	// The path to the TLS private key used
	// by MinIO to authenticate to the kes
	// server during the TLS handshake (mTLS).
	KeyFile string

	// The path to the TLS certificate used
	// by MinIO to authenticate to the kes
	// server during the TLS handshake (mTLS).
	//
	// The kes server will also allow or deny
	// access based on this certificate.
	// In particular, the kes server will
	// lookup the policy that corresponds to
	// the identity in this certificate.
	CertFile string

	// Path to a file or directory containing
	// the CA certificate(s) that issued / will
	// issue certificates for the kes server.
	//
	// This is required if the TLS certificate
	// of the kes server has not been issued
	// (e.g. b/c it's self-signed) by a CA that
	// MinIO trusts.
	CAPath string

	// The default key ID. It is the value returned by
	// the DefaultKeyID method of the KMS that NewKes
	// creates from this configuration.
	DefaultKeyID string

	// The HTTP transport configuration for
	// the KES client.
	//
	// NewKes adjusts the TLS client configuration of
	// this transport in place (root CAs, client
	// certificate and ALPN protocols). Therefore, it
	// should not be shared with HTTP clients that
	// require different TLS settings.
	Transport *http.Transport
}

// Verify checks that all mandatory kes settings are present.
//
// The fields are checked in the order endpoint, certificate
// file, key file and default key ID. The error describes the
// first one that is missing. Verify returns nil if none of
// them is empty.
func (k KesConfig) Verify() (err error) {
	if len(k.Endpoint) == 0 {
		return Errorf("crypto: missing kes endpoint")
	}
	if k.CertFile == "" {
		return Errorf("crypto: missing cert file")
	}
	if k.KeyFile == "" {
		return Errorf("crypto: missing key file")
	}
	if k.DefaultKeyID == "" {
		return Errorf("crypto: missing default key id")
	}
	return nil
}

// kesService is the value NewKes returns as KMS. It forwards
// all key operations to one or more KES servers via client.
type kesService struct {
	// client performs the HTTP requests against the KES server(s).
	client *kesClient

	// endpoints holds the configured KES server endpoints (reported
	// by Info) and defaultKeyID is the key ID returned by DefaultKeyID.
	endpoints    []string
	defaultKeyID string
}

// NewKes returns a new kes KMS client. The returned KMS
// uses the X.509 certificate / private key pair referenced
// by cfg.CertFile and cfg.KeyFile to authenticate itself
// (mTLS) to the kes server(s) listed in cfg.Endpoint.
//
// cfg.DefaultKeyID is the key ID returned when calling
// DefaultKeyID() on the returned KMS.
//
// NewKes modifies the TLS client configuration of
// cfg.Transport in place:
//   - If the transport has no TLS configuration, a new one
//     is created that trusts the system root CAs plus the
//     certificates found at cfg.CAPath.
//   - If the transport has a TLS configuration with a root
//     CA pool, the certificates found at cfg.CAPath are
//     added to that pool.
//   - If the transport has a TLS configuration without a
//     root CA pool, a pool is only created (system roots
//     plus cfg.CAPath) when cfg.CAPath is set. Otherwise,
//     the pool stays unset.
//
// In all cases the client certificate is installed and
// "h2" is set as the only ALPN protocol.
//
// It returns an error if the certificate / private key pair
// cannot be loaded, if cfg.Transport is nil or if loading
// the CA certificates fails.
func NewKes(cfg KesConfig) (KMS, error) {
	cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
	if err != nil {
		return nil, err
	}
	if cfg.Transport == nil {
		// Without this check the code below would dereference a nil pointer.
		return nil, errors.New("crypto: missing kes HTTP transport")
	}

	newRootCAs := func() *x509.CertPool {
		// The error is ignored on purpose: rootCAs is nil whenever
		// the system cert pool cannot be loaded.
		rootCAs, _ := x509.SystemCertPool()
		if rootCAs == nil {
			// In some systems (like Windows) system cert pool is
			// not supported or no certificates are present on the
			// system - so we create a new cert pool.
			rootCAs = x509.NewCertPool()
		}
		return rootCAs
	}

	var (
		tlsConfig = cfg.Transport.TLSClientConfig
		rootCAs   *x509.CertPool
	)
	switch {
	case tlsConfig == nil:
		tlsConfig, rootCAs = &tls.Config{}, newRootCAs()
	case tlsConfig.RootCAs != nil:
		rootCAs = tlsConfig.RootCAs
	case cfg.CAPath != "":
		// The TLS config has no root CA pool but additional CAs
		// should be trusted. A nil pool cannot hold certificates,
		// so we have to create one.
		rootCAs = newRootCAs()
	}
	if rootCAs != nil {
		if err = loadCACertificates(cfg.CAPath, rootCAs); err != nil {
			return nil, err
		}
		tlsConfig.RootCAs = rootCAs
	}
	tlsConfig.Certificates = []tls.Certificate{cert}
	tlsConfig.NextProtos = []string{"h2"}
	cfg.Transport.TLSClientConfig = tlsConfig

	return &kesService{
		client: &kesClient{
			endpoints: cfg.Endpoint,
			httpClient: http.Client{
				Transport: cfg.Transport,
			},
		},
		endpoints:    cfg.Endpoint,
		defaultKeyID: cfg.DefaultKeyID,
	}, nil
}

// DefaultKeyID returns the key ID this KMS has been configured
// with as default. It is meant for SSE-S3 or SSE-KMS requests
// where the S3 client does not name a key ID explicitly.
func (kes *kesService) DefaultKeyID() string { return kes.defaultKeyID }

// Info describes this KES-backed KMS: the configured
// endpoints, the default key ID as name and "TLS" as
// authentication method.
func (kes *kesService) Info() KMSInfo {
	var info KMSInfo
	info.Endpoints = kes.endpoints
	info.Name = kes.defaultKeyID
	info.AuthType = "TLS"
	return info
}

// CreateKey asks the KES server to create a new master key
// named keyID. It returns whatever error the client reports.
func (kes *kesService) CreateKey(keyID string) error {
	return kes.client.CreateKey(keyID)
}

// GenerateKey requests a new data key from the KMS. It returns
// the 32 byte plaintext key together with its sealed version,
// i.e. the plaintext key encrypted by the KMS with the master
// key referenced by keyID.
//
// The serialized ctx is sent along with the request such that
// the KMS can bind the generated key to it. The same context
// has to be provided again to unseal the key.
//
// It returns an error if the request fails or if the KMS
// responds with a plaintext key that is not exactly 32 bytes
// long. In both cases the returned key is all-zero and the
// sealed key is nil.
func (kes *kesService) GenerateKey(keyID string, ctx Context) (key [32]byte, sealedKey []byte, err error) {
	var boundTo bytes.Buffer
	ctx.WriteTo(&boundTo)

	plaintext, ciphertext, err := kes.client.GenerateDataKey(keyID, boundTo.Bytes())
	switch {
	case err != nil:
		return key, nil, err
	case len(plaintext) != len(key):
		return key, nil, Errorf("crypto: received invalid plaintext key size from KMS")
	}
	copy(key[:], plaintext)
	return key, ciphertext, nil
}

// UnsealKey sends the sealedKey to the KMS, which decrypts it
// with the master key referenced by keyID, and returns the
// resulting 32 byte plaintext key.
//
// The context must be the same context as the one provided
// while generating the plaintext key / sealedKey.
//
// It returns an error if the request fails or if the KMS
// responds with a plaintext key that is not exactly 32 bytes
// long. In both cases the returned key is all-zero.
func (kes *kesService) UnsealKey(keyID string, sealedKey []byte, ctx Context) (key [32]byte, err error) {
	var boundTo bytes.Buffer
	ctx.WriteTo(&boundTo)

	plaintext, err := kes.client.DecryptDataKey(keyID, sealedKey, boundTo.Bytes())
	switch {
	case err != nil:
		return key, err
	case len(plaintext) != len(key):
		return key, Errorf("crypto: received invalid plaintext key size from KMS")
	}
	copy(key[:], plaintext)
	return key, nil
}

// kesClient implements the bare minimum functionality needed for
// MinIO to talk to a KES server. In particular, it implements
//   • CreateKey       (API: /v1/key/create/)
//   • GenerateDataKey (API: /v1/key/generate/)
//   • DecryptDataKey  (API: /v1/key/decrypt/)
//
// All three are HTTP POST requests. A request that fails with
// a retryable error is sent again, cycling through the endpoints.
type kesClient struct {
	// endpoints holds the KES server endpoints; the API path is
	// appended to one of them to form the request URL.
	// httpClient is the HTTP client used to send the requests.
	endpoints  []string
	httpClient http.Client
}

// CreateKey asks the KES server to create a new cryptographic
// key with the specified name.
//
// The key is generated by the server. The client application
// does not have the cryptographic key at any point in time.
func (c *kesClient) CreateKey(name string) error {
	// Neither a request body is sent nor a response body is expected.
	_, err := c.postRetry("/v1/key/create/"+url.PathEscape(name), nil, 0)
	return err
}

// GenerateDataKey asks the KES server for a new data key.
// On success, it returns the plaintext key followed by the
// ciphertext key - the plaintext key encrypted with the
// key specified by name.
//
// The optional context is sent to the server so that it can be
// cryptographically bound to the generated data key. The same
// context has to be provided when decrypting the data key.
func (c *kesClient) GenerateDataKey(name string, context []byte) ([]byte, []byte, error) {
	type Request struct {
		Context []byte `json:"context"`
	}
	type Response struct {
		Plaintext  []byte `json:"plaintext"`
		Ciphertext []byte `json:"ciphertext"`
	}

	// A plaintext/ciphertext key pair will never be larger than 1 MB
	const maxResponseSize = 1 << 20

	payload, err := json.Marshal(Request{Context: context})
	if err != nil {
		return nil, nil, err
	}

	endpointPath := "/v1/key/generate/" + url.PathEscape(name)
	respBody, err := c.postRetry(endpointPath, bytes.NewReader(payload), maxResponseSize)
	if err != nil {
		return nil, nil, err
	}

	var dataKey Response
	if err = json.NewDecoder(respBody).Decode(&dataKey); err != nil {
		return nil, nil, err
	}
	return dataKey.Plaintext, dataKey.Ciphertext, nil
}

// DecryptDataKey asks the KES server to decrypt the given
// ciphertext - an encrypted data key - with the key specified
// by name. On success, it returns the plaintext data key.
//
// The optional context must match the value provided when
// the data key has been generated. An empty context is
// omitted from the request.
func (c *kesClient) DecryptDataKey(name string, ciphertext, context []byte) ([]byte, error) {
	type Request struct {
		Ciphertext []byte `json:"ciphertext"`
		Context    []byte `json:"context,omitempty"`
	}
	type Response struct {
		Plaintext []byte `json:"plaintext"`
	}

	// A data key will never be larger than 1 MiB
	const maxResponseSize = 1 << 20

	payload, err := json.Marshal(Request{Ciphertext: ciphertext, Context: context})
	if err != nil {
		return nil, err
	}

	endpointPath := "/v1/key/decrypt/" + url.PathEscape(name)
	respBody, err := c.postRetry(endpointPath, bytes.NewReader(payload), maxResponseSize)
	if err != nil {
		return nil, err
	}

	var dataKey Response
	if err = json.NewDecoder(respBody).Decode(&dataKey); err != nil {
		return nil, err
	}
	return dataKey.Plaintext, nil
}

// NewKESError creates a KES API error from the given
// HTTP status code and error message.
//
// The returned error is a comparable value. Hence, two
// errors with the same status code and error message
// are equal:
//   e1 == e2 // true.
func NewKESError(code int, text string) error {
	var kesErr kesError
	kesErr.code = code
	kesErr.message = text
	return kesErr
}

// kesError is an error reported by the KES API. It consists
// of the HTTP status code of the response and an error message.
// Since it is a plain value type two kesErrors with equal
// fields compare equal.
type kesError struct {
	code    int
	message string
}

// Status returns the HTTP status code of the error.
func (k kesError) Status() int {
	return k.code
}

// Error returns the error message of the error.
func (k kesError) Error() string {
	return k.message
}

// parseErrorResponse converts an HTTP error response into a
// KES error that carries the response status code.
//
// It returns nil if resp is nil or if its status code is
// below 400. If the response has no body the error message
// is empty. Otherwise, the body is closed once parsed and at
// most 1 MiB of it is read: when the Content-Type starts with
// "application/json" the error message is taken from the
// "message" field of the JSON body, else the raw body is used
// as error message.
//
// A failure to decode or read the body is returned as is -
// i.e. not as KES error.
func parseErrorResponse(resp *http.Response) error {
	if resp == nil || resp.StatusCode < 400 {
		return nil
	}
	status := resp.StatusCode
	if resp.Body == nil {
		return NewKESError(status, "")
	}
	defer resp.Body.Close()

	// Trust the announced content length only if it is known
	// and does not exceed the upper bound.
	const maxBodySize = 1 << 20
	readLimit := resp.ContentLength
	if readLimit < 0 || readLimit > maxBodySize {
		readLimit = maxBodySize
	}
	body := io.LimitReader(resp.Body, readLimit)

	if mediaType := strings.TrimSpace(resp.Header.Get("Content-Type")); strings.HasPrefix(mediaType, "application/json") {
		type Response struct {
			Message string `json:"message"`
		}
		var apiErr Response
		if err := json.NewDecoder(body).Decode(&apiErr); err != nil {
			return err
		}
		return NewKESError(status, apiErr.Message)
	}

	var text strings.Builder
	if _, err := io.Copy(&text, body); err != nil {
		return err
	}
	return NewKESError(status, text.String())
}

// post sends a single HTTP POST request with the given JSON
// body to url and returns at most limit bytes of the response
// body. The request is aborted after 3 seconds.
//
// A response with a status code >= 400 is turned into an
// error by parseErrorResponse. For any other status code
// post reads the response body and returns it. In particular,
// post never returns a nil reader together with a nil error.
// A status code below 400 that is not 200 OK is therefore
// not reported as error, but the caller gets whatever (maybe
// empty) body the server has sent.
func (c *kesClient) post(url string, body io.Reader, limit int64) (io.Reader, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Drain the entire body to make sure we have re-use connections
	defer xhttp.DrainBody(resp.Body)

	if resp.StatusCode != http.StatusOK {
		// parseErrorResponse returns nil for status codes below 400.
		// In this case we must not return early since the caller
		// would receive a nil reader without any error.
		if err = parseErrorResponse(resp); err != nil {
			return nil, err
		}
	}

	// We have to copy the response body due to draining.
	var respBody bytes.Buffer
	if _, err = io.Copy(&respBody, io.LimitReader(resp.Body, limit)); err != nil {
		return nil, err
	}
	return &respBody, nil
}

// postRetry sends a POST request for the given API path to
// one of the KES endpoints and returns the response body -
// limited to limit bytes. The request body may be nil.
//
// If a request fails with a retryable error - a network or
// host-down error, an (unexpected) EOF or an exceeded deadline -
// postRetry tries again with the next endpoint, wrapping around
// at the end of the list. It does not wait as long as there are
// endpoints that have not been tried yet. After that it backs
// off before each further attempt. In total at most
// len(endpoints)+2 requests are sent. Any other error is
// returned immediately.
//
// It returns an error if no endpoint has been configured or
// if the request body cannot be rewound for another attempt.
func (c *kesClient) postRetry(path string, body io.ReadSeeker, limit int64) (io.Reader, error) {
	if len(c.endpoints) == 0 {
		// Without this check selecting an endpoint below would
		// divide by zero.
		return nil, errors.New("crypto: no kes endpoint configured")
	}

	retryMax := 1 + len(c.endpoints)
	for i := 0; ; i++ {
		if body != nil {
			// A previous attempt may have consumed (parts of) the
			// body. So, seek to the beginning of the body.
			if _, err := body.Seek(0, io.SeekStart); err != nil {
				return nil, err
			}
		}

		response, err := c.post(c.endpoints[i%len(c.endpoints)]+path, body, limit)
		if err == nil {
			return response, nil
		}

		// If the error is not temp. / retryable => fail the request immediately.
		if !xnet.IsNetworkOrHostDown(err) &&
			!errors.Is(err, io.EOF) &&
			!errors.Is(err, io.ErrUnexpectedEOF) &&
			!errors.Is(err, context.DeadlineExceeded) {
			return nil, err
		}
		if remain := retryMax - i; remain <= 0 { // Fail if we exceeded our retry limit.
			return nil, err
		}

		// If there are more KES instances then skip waiting and
		// try the next endpoint directly.
		if i < len(c.endpoints) {
			continue
		}
		time.Sleep(LinearJitterBackoff(retryWaitMin, retryWaitMax, i))
	}
}

// loadCACertificates adds the PEM-encoded certificate(s)
// found at path to rootCAs. It does not create a CertPool
// itself - the caller decides which certificates (e.g. the
// system root CAs) the pool contains initially.
//
// An empty path, a path that does not exist and a path that
// cannot be accessed due to missing permissions are not
// treated as errors: rootCAs is left untouched and nil is
// returned.
//
// If path is a file, loadCACertificates will
// try to parse it as PEM-encoded certificate.
// If this fails, it returns an error.
//
// If path is a directory it tries to parse each
// file as PEM-encoded certificate and add it to
// the CertPool. If a file is not readable or not
// a PEM certificate it will be ignored. Nested
// directories are not traversed.
//
// It returns an error if path exists but rootCAs is nil.
func loadCACertificates(path string, rootCAs *x509.CertPool) error {
	if path == "" {
		return nil
	}

	stat, err := os.Stat(path)
	if err != nil {
		if os.IsNotExist(err) || os.IsPermission(err) {
			return nil
		}
		return Errorf("crypto: cannot open '%s': %v", path, err)
	}
	if rootCAs == nil {
		// Adding a certificate to a nil CertPool would dereference
		// a nil pointer. Report the programming error instead.
		return Errorf("crypto: cannot add the CA certificates at '%s' to a nil certificate pool", path)
	}

	// If path is a file, parse as PEM-encoded certificate
	// and try to add it to the CertPool. If this fails
	// return an error.
	if !stat.IsDir() {
		cert, err := ioutil.ReadFile(path)
		if err != nil {
			return err
		}
		if !rootCAs.AppendCertsFromPEM(cert) {
			return Errorf("crypto: '%s' is not a valid PEM-encoded certificate", path)
		}
		return nil
	}

	// If path is a directory then try
	// to parse each file as PEM-encoded
	// certificate and add it to the CertPool.
	// If a file is not a PEM-encoded certificate
	// we ignore it.
	files, err := ioutil.ReadDir(path)
	if err != nil {
		return err
	}
	for _, file := range files {
		cert, err := ioutil.ReadFile(filepath.Join(path, file.Name()))
		if err != nil {
			continue // ignore files which are not readable
		}
		rootCAs.AppendCertsFromPEM(cert) // ignore files which are not PEM certificates
	}
	return nil
}