// MinIO Cloud Storage, (C) 2019 MinIO, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package crypto

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
)

// KesConfig contains the configuration required
// to initialize and connect to a kes server.
//
// Verify checks that Endpoint, CertFile, KeyFile
// and DefaultKeyID are set. NewKes consumes the
// configuration to build the KMS client.
type KesConfig struct {
	// NOTE(review): Enabled is not read anywhere in
	// this file. Presumably callers use it to signal
	// that a kes server has been configured - confirm
	// against the code that populates KesConfig.
	Enabled bool

	// The kes server endpoint.
	// It is used as prefix of the request URLs and
	// is reported as KMS endpoint by Info.
	Endpoint string

	// The path to the TLS private key used
	// by MinIO to authenticate to the kes
	// server during the TLS handshake (mTLS).
	KeyFile string

	// The path to the TLS certificate used
	// by MinIO to authenticate to the kes
	// server during the TLS handshake (mTLS).
	//
	// The kes server will also allow or deny
	// access based on this certificate.
	// In particular, the kes server will
	// lookup the policy that corresponds to
	// the identity in this certificate.
	CertFile string

	// Path to a file or directory containing
	// the CA certificate(s) that issued / will
	// issue certificates for the kes server.
	//
	// This is required if the TLS certificate
	// of the kes server has not been issued
	// (e.g. b/c it's self-signed) by a CA that
	// MinIO trusts.
	//
	// An empty CAPath means that only the system
	// root CA certificates are used.
	CAPath string

	// The default key ID returned by KMS.KeyID().
	DefaultKeyID string

	// The HTTP transport configuration for
	// the KES client.
	//
	// NewKes overwrites the TLSClientConfig of
	// this transport with the mTLS configuration
	// built from CertFile, KeyFile and CAPath.
	Transport *http.Transport
}

// Verify checks that all mandatory settings of the kes
// configuration are present. The settings are inspected
// in the order Endpoint, CertFile, KeyFile, DefaultKeyID
// and the error names the first one that is empty.
//
// Verify returns nil if none of these settings is empty.
// It does not check whether the referenced files exist
// or whether the endpoint is reachable.
func (k KesConfig) Verify() (err error) {
	if k.Endpoint == "" {
		return Errorf("crypto: missing kes endpoint")
	}
	if k.CertFile == "" {
		return Errorf("crypto: missing cert file")
	}
	if k.KeyFile == "" {
		return Errorf("crypto: missing key file")
	}
	if k.DefaultKeyID == "" {
		return Errorf("crypto: missing default key id")
	}
	return nil
}

// kesService is the KMS implementation returned by NewKes.
// It forwards key generation and decryption requests to a
// kes server through its kesClient.
type kesService struct {
	// client sends the HTTP requests to the kes server.
	client *kesClient

	// endpoint is the kes server endpoint reported by Info.
	// defaultKeyID is the key ID returned by KeyID.
	endpoint     string
	defaultKeyID string
}

// NewKes returns a new kes KMS client. The returned KMS
// uses the X.509 certificate / private key pair referenced
// by cfg.CertFile and cfg.KeyFile to authenticate itself to
// the kes server available at cfg.Endpoint.
//
// Any CA certificate found at cfg.CAPath is trusted - in
// addition to the system root CAs - when verifying the
// certificate presented by the kes server.
//
// cfg.DefaultKeyID is the key ID returned when calling
// KMS.KeyID().
//
// NewKes installs the TLS client configuration on
// cfg.Transport, replacing any TLSClientConfig that has
// been set before. If cfg.Transport is nil, NewKes uses a
// new http.Transport instead.
//
// It returns an error if the certificate / private key
// pair or the CA certificate(s) cannot be loaded.
func NewKes(cfg KesConfig) (KMS, error) {
	cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
	if err != nil {
		return nil, err
	}
	certPool, err := loadCACertificates(cfg.CAPath)
	if err != nil {
		return nil, err
	}

	transport := cfg.Transport
	if transport == nil {
		// Setting TLSClientConfig through a nil *http.Transport
		// would panic. Fall back to a fresh transport such that
		// a config without explicit transport is still usable.
		transport = &http.Transport{}
	}
	transport.TLSClientConfig = &tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      certPool,
	}
	return &kesService{
		client: &kesClient{
			addr: cfg.Endpoint,
			httpClient: http.Client{
				Transport: transport,
			},
		},
		endpoint:     cfg.Endpoint,
		defaultKeyID: cfg.DefaultKeyID,
	}, nil
}

// KeyID returns the ID of the key that is used by default,
// i.e. the default key ID this KMS has been created with.
func (kes *kesService) KeyID() string { return kes.defaultKeyID }

// Info returns status information about the KMS: the
// endpoint of the kes server, the default key ID as name
// and "TLS" as authentication type.
func (kes *kesService) Info() KMSInfo {
	var info KMSInfo
	info.Endpoint = kes.endpoint
	info.Name = kes.KeyID()
	info.AuthType = "TLS"
	return info
}

// GenerateKey asks the KMS for a new data key. It returns the
// plaintext key and the sealed key - the plaintext key encrypted
// by the KMS with the named key referenced by keyID.
//
// The serialized ctx is sent along with the request such that the
// KMS can bind the generated key cryptographically to it.
//
// GenerateKey returns an error if the request fails or if the
// plaintext key sent by the KMS is not exactly 32 bytes long. In
// both cases the returned key is all-zero and the sealed key is nil.
func (kes *kesService) GenerateKey(keyID string, ctx Context) (key [32]byte, sealedKey []byte, err error) {
	// Serialize the context. The result of WriteTo is not inspected.
	var ctxBuf bytes.Buffer
	ctx.WriteTo(&ctxBuf)

	plaintext, ciphertext, err := kes.client.GenerateDataKey(keyID, ctxBuf.Bytes())
	switch {
	case err != nil:
		return key, nil, err
	case len(plaintext) != len(key):
		return key, nil, Errorf("crypto: received invalid plaintext key size from KMS")
	}

	copy(key[:], plaintext)
	return key, ciphertext, nil
}

// UnsealKey returns the plaintext key that corresponds to the
// sealedKey. It sends the sealedKey together with the serialized
// ctx to the KMS, which decrypts it using the named key referenced
// by keyID and responds with the plaintext key.
//
// The context must be the same context as the one provided while
// generating the plaintext key / sealedKey.
//
// UnsealKey returns an error if the request fails or if the
// plaintext key sent by the KMS is not exactly 32 bytes long. In
// both cases the returned key is all-zero.
func (kes *kesService) UnsealKey(keyID string, sealedKey []byte, ctx Context) (key [32]byte, err error) {
	// Serialize the context. The result of WriteTo is not inspected.
	var ctxBuf bytes.Buffer
	ctx.WriteTo(&ctxBuf)

	plaintext, err := kes.client.DecryptDataKey(keyID, sealedKey, ctxBuf.Bytes())
	if err != nil {
		return key, err
	}
	if len(plaintext) == len(key) {
		copy(key[:], plaintext)
		return key, nil
	}
	return key, Errorf("crypto: received invalid plaintext key size from KMS")
}

// UpdateKey is meant to re-wrap the sealedKey once the master key
// referenced by keyID has been rotated by the KMS operator.
//
// A kes server currently does not support rotation of the same
// key. Therefore, UpdateKey only checks that the sealedKey can be
// unsealed with keyID and ctx and then returns the sealedKey as is.
//
// The context must be the same context as the one provided while
// generating the plaintext key / sealedKey.
func (kes *kesService) UpdateKey(keyID string, sealedKey []byte, ctx Context) ([]byte, error) {
	// Only the outcome of unsealing matters here.
	// The plaintext key itself is discarded.
	if _, err := kes.UnsealKey(keyID, sealedKey, ctx); err != nil {
		return nil, err
	}
	return sealedKey, nil
}

// kesClient implements the bare minimum functionality needed for
// MinIO to talk to a KES server. In particular, it implements
// GenerateDataKey (API: /v1/key/generate/) and
// DecryptDataKey  (API: /v1/key/decrypt/).
type kesClient struct {
	// addr is the KES server endpoint. It is used as-is
	// as prefix of every request URL.
	addr string
	// httpClient sends the requests to the KES server.
	httpClient http.Client
}

// GenerateDataKey requests a new data key from the KES server
// by sending a POST request to /v1/key/generate/<name>.
// On success, it returns the plaintext key followed by the
// ciphertext key - the plaintext key encrypted with the key
// specified by name.
//
// The optional context is sent to the KES server, which binds
// it cryptographically to the generated data key. Hence, the
// same context has to be provided when decrypting the data key.
//
// If the KES server does not respond with 200 OK, the returned
// error is the one produced by parseErrorResponse.
func (c *kesClient) GenerateDataKey(name string, context []byte) ([]byte, []byte, error) {
	// A plaintext/ciphertext key pair will never be larger than 1 MB
	const maxResponseSize = 1 << 20

	payload, err := json.Marshal(struct {
		Context []byte `json:"context"`
	}{
		Context: context,
	})
	if err != nil {
		return nil, nil, err
	}

	endpoint := fmt.Sprintf("%s/v1/key/generate/%s", c.addr, url.PathEscape(name))
	resp, err := c.httpClient.Post(endpoint, "application/json", bytes.NewReader(payload))
	if err != nil {
		return nil, nil, err
	}
	if resp.StatusCode != http.StatusOK {
		// parseErrorResponse takes care of closing the response body.
		return nil, nil, c.parseErrorResponse(resp)
	}
	defer resp.Body.Close()

	var keyPair struct {
		Plaintext  []byte `json:"plaintext"`
		Ciphertext []byte `json:"ciphertext"`
	}
	decoder := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize))
	if err = decoder.Decode(&keyPair); err != nil {
		return nil, nil, err
	}
	return keyPair.Plaintext, keyPair.Ciphertext, nil
}

// DecryptDataKey decrypts an encrypted data key with the key
// specified by name by sending a POST request to
// /v1/key/decrypt/<name> of the KES server.
// On success, it returns the plaintext key sent by the server.
//
// The optional context must match the value that has been
// provided when generating the data key.
//
// If the KES server does not respond with 200 OK, the returned
// error is the one produced by parseErrorResponse.
func (c *kesClient) DecryptDataKey(name string, ciphertext, context []byte) ([]byte, error) {
	// A data key will never be larger than 32 KB
	const maxResponseSize = 32 * 1024

	payload, err := json.Marshal(struct {
		Ciphertext []byte `json:"ciphertext"`
		Context    []byte `json:"context"`
	}{
		Ciphertext: ciphertext,
		Context:    context,
	})
	if err != nil {
		return nil, err
	}

	endpoint := fmt.Sprintf("%s/v1/key/decrypt/%s", c.addr, url.PathEscape(name))
	resp, err := c.httpClient.Post(endpoint, "application/json", bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		// parseErrorResponse takes care of closing the response body.
		return nil, c.parseErrorResponse(resp)
	}
	defer resp.Body.Close()

	var dataKey struct {
		Plaintext []byte `json:"plaintext"`
	}
	decoder := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize))
	if err = decoder.Decode(&dataKey); err != nil {
		return nil, err
	}
	return dataKey.Plaintext, nil
}

// parseErrorResponse converts a non-successful HTTP response of
// the KES server into an error. The error consists of the status
// text of the HTTP status code followed by the response body, of
// which at most 32 KB are read. The response body is closed before
// parseErrorResponse returns.
//
// A response without a body yields an error that contains just the
// status text. It must not yield nil since the callers return the
// result as their error value - a nil would turn a failed request
// into a success without any key material.
//
// If reading the response body fails, the read error is returned
// instead.
func (c *kesClient) parseErrorResponse(resp *http.Response) error {
	if resp.Body == nil {
		return Errorf("%s", http.StatusText(resp.StatusCode))
	}
	defer resp.Body.Close()

	const limit = 32 * 1024 // A (valid) error response will not be greater than 32 KB
	var errMsg strings.Builder
	if _, err := io.Copy(&errMsg, io.LimitReader(resp.Body, limit)); err != nil {
		return err
	}
	return Errorf("%s: %s", http.StatusText(resp.StatusCode), errMsg.String())
}

// loadCACertificates returns a CertPool that contains
// the system root CA certificates - if available - plus
// any PEM-encoded certificate(s) found at path.
//
// An empty path, a path that does not exist and a path
// that cannot be accessed due to missing permissions are
// not treated as errors: the pool is returned without
// additional certificates.
//
// If path is a file, it must contain at least one
// PEM-encoded certificate. Otherwise, loadCACertificates
// returns an error.
//
// If path is a directory, each entry is read and parsed as
// PEM-encoded certificate(s). Entries that cannot be read
// or do not contain a PEM-encoded certificate are skipped.
func loadCACertificates(path string) (*x509.CertPool, error) {
	// In some systems (like Windows) the system cert pool is
	// not supported or no certificates are present on the
	// system - so we start with an empty cert pool instead.
	pool, _ := x509.SystemCertPool()
	if pool == nil {
		pool = x509.NewCertPool()
	}
	if path == "" {
		return pool, nil
	}

	info, err := os.Stat(path)
	switch {
	case err == nil:
		// path exists - continue below.
	case os.IsNotExist(err), os.IsPermission(err):
		return pool, nil
	default:
		return nil, Errorf("crypto: cannot open '%s': %v", path, err)
	}

	if info.IsDir() {
		// Try to parse each directory entry as PEM-encoded
		// certificate and add it to the pool.
		entries, err := ioutil.ReadDir(path)
		if err != nil {
			return nil, err
		}
		for _, entry := range entries {
			data, err := ioutil.ReadFile(filepath.Join(path, entry.Name()))
			if err == nil {
				// Entries that are not PEM certificates are
				// ignored - just like unreadable entries.
				pool.AppendCertsFromPEM(data)
			}
		}
		return pool, nil
	}

	// path is a single file: it has to be a PEM-encoded
	// certificate. Anything else is an error.
	data, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}
	if !pool.AppendCertsFromPEM(data) {
		return nil, Errorf("crypto: '%s' is not a valid PEM-encoded certificate", path)
	}
	return pool, nil
}