2020-11-09 10:05:29 -03:00

175 lines
4.1 KiB
Go

package keycrypt
import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"unicode/utf16"
"github.com/btcsuite/btcutil/hdkeychain"
"github.com/muun/libwallet/aescbc"
"golang.org/x/crypto/scrypt"
)
const (
ivLength = 16
saltLength = 8
scryptIterations = 512
scryptBlockSize = 8
scryptParallelizationFactor = 1
scryptOutputLength = 32
separator = ":"
)
// Encrypt encrypts an HD priv key using a user-provided secret into a string
// We use SCrypt256 for key derivation and AES-CBC-PKCS7 for encryption.
// The returned string has information about version, derivation path, scrypt and AES parameters.
func Encrypt(key *hdkeychain.ExtendedKey, keyPath, passphrase string) (string, error) {
iv := randomBytes(ivLength)
salt := randomBytes(saltLength)
inputSecret, err := scrypt.Key(
encodeUTF16(passphrase),
salt,
scryptIterations,
scryptBlockSize,
scryptParallelizationFactor,
scryptOutputLength,
)
if err != nil {
return "", fmt.Errorf("failed to compute scrypt key: %w", err)
}
privateKeyBytes := []byte(key.String())
encrypted, err := aescbc.EncryptPkcs7(inputSecret, iv, privateKeyBytes)
if err != nil {
return "", fmt.Errorf("failed to encrypt: %w", err)
}
derivationPathBytes := []byte(keyPath)
elements := []string{
"v1",
strconv.Itoa(scryptIterations),
strconv.Itoa(scryptParallelizationFactor),
strconv.Itoa(scryptBlockSize),
hex.EncodeToString(salt),
hex.EncodeToString(iv),
hex.EncodeToString(encrypted),
hex.EncodeToString(derivationPathBytes),
}
return strings.Join(elements, separator), nil
}
// Decrypt decrypts a key encrypted with Encrypt
func Decrypt(value, passphrase string) (*hdkeychain.ExtendedKey, string, error) {
elements := strings.Split(value, separator)
if len(elements) != 8 {
return nil, "", errors.New("invalid format")
}
version := elements[0]
iterations, err := strconv.Atoi(elements[1])
if err != nil {
return nil, "", fmt.Errorf("invalid iterations: %w", err)
}
parallelizationFactor, err := strconv.Atoi(elements[2])
if err != nil {
return nil, "", fmt.Errorf("invalid p: %w", err)
}
blockSize, err := strconv.Atoi(elements[3])
if err != nil {
return nil, "", fmt.Errorf("invalid blocksize: %w", err)
}
salt, err := hex.DecodeString(elements[4])
if err != nil {
return nil, "", fmt.Errorf("invalid salt: %w", err)
}
iv, err := hex.DecodeString(elements[5])
if err != nil {
return nil, "", fmt.Errorf("invalid iv: %w", err)
}
payload, err := hex.DecodeString(elements[6])
if err != nil {
return nil, "", fmt.Errorf("invalid payload: %w", err)
}
pathBytes, err := hex.DecodeString(elements[7])
if err != nil {
return nil, "", fmt.Errorf("invalid path: %w", err)
}
if version != "v1" {
return nil, "", fmt.Errorf("invalid version %s", version)
}
inputSecret, err := scrypt.Key(
encodeUTF16(passphrase),
salt,
iterations,
blockSize,
parallelizationFactor,
scryptOutputLength,
)
if err != nil {
return nil, "", fmt.Errorf("failed to compute scrypt key: %w", err)
}
decryptedBytes, err := aescbc.DecryptPkcs7(inputSecret, iv, payload)
if err != nil {
return nil, "", fmt.Errorf("failed to decrypt: %w", err)
}
key, err := hdkeychain.NewKeyFromString(string(decryptedBytes[:]))
if err != nil {
return nil, "", fmt.Errorf("could not decode private key: %w", err)
}
if !key.IsPrivate() {
return nil, "", errors.New("expected extended key to be private, not public")
}
path := string(pathBytes[:])
return key, path, nil
}
func randomBytes(count int) []byte {
buf := make([]byte, count)
_, err := rand.Read(buf)
if err != nil {
panic("couldn't read random bytes")
}
return buf
}
func encodeUTF16(s string) []byte {
// You might wonder why this code exists....
// Turns out that the scrypt implementation used in android is hardwired
// to use strings as UTF16 (which is Java's native format). So we need to
// use the same exact byte array encoding.
var buf bytes.Buffer
for _, r := range utf16.Encode([]rune(s)) {
binary.Write(&buf, binary.BigEndian, r)
}
return buf.Bytes()
}