mirror of
https://github.com/muun/recovery.git
synced 2025-11-11 14:30:19 -05:00
Update project structure and build process
This commit is contained in:
180
libwallet/walletdb/walletdb.go
Normal file
180
libwallet/walletdb/walletdb.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package walletdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
gormigrate "gopkg.in/gormigrate.v1"
|
||||
)
|
||||
|
||||
type InvoiceState string
|
||||
|
||||
const (
|
||||
InvoiceStateRegistered InvoiceState = "registered"
|
||||
InvoiceStateUsed InvoiceState = "used"
|
||||
)
|
||||
|
||||
// TODO: probably rename to InvoiceSecrets or similar
|
||||
type Invoice struct {
|
||||
gorm.Model
|
||||
Preimage []byte
|
||||
PaymentHash []byte
|
||||
PaymentSecret []byte
|
||||
KeyPath string
|
||||
ShortChanId uint64
|
||||
AmountSat int64
|
||||
State InvoiceState
|
||||
Metadata string
|
||||
UsedAt *time.Time
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func Open(path string) (*DB, error) {
|
||||
db, err := gorm.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = migrate(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DB{db}, nil
|
||||
}
|
||||
|
||||
func migrate(db *gorm.DB) error {
|
||||
opts := gormigrate.Options{
|
||||
UseTransaction: true,
|
||||
}
|
||||
m := gormigrate.New(db, &opts, []*gormigrate.Migration{
|
||||
{
|
||||
ID: "initial",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
type Invoice struct {
|
||||
gorm.Model
|
||||
Preimage []byte
|
||||
PaymentHash []byte
|
||||
PaymentSecret []byte
|
||||
KeyPath string
|
||||
ShortChanId uint64
|
||||
State string
|
||||
UsedAt *time.Time
|
||||
}
|
||||
// This guard exists because at some point migrations were run outside a
|
||||
// transactional context and a user experimented problems with an invoices
|
||||
// table that was already created but whose migration had not been properly
|
||||
// recorded.
|
||||
if !tx.HasTable(&Invoice{}) {
|
||||
return tx.CreateTable(&Invoice{}).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.DropTable("invoices").Error
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "add amount to invoices table",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
type Invoice struct {
|
||||
gorm.Model
|
||||
Preimage []byte
|
||||
PaymentHash []byte
|
||||
PaymentSecret []byte
|
||||
KeyPath string
|
||||
ShortChanId uint64
|
||||
AmountSat int64
|
||||
State string
|
||||
UsedAt *time.Time
|
||||
}
|
||||
return tx.AutoMigrate(&Invoice{}).Error
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Table("invoices").DropColumn(gorm.ToColumnName("AmountSat")).Error
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "add metadata to invoices table",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
type Invoice struct {
|
||||
gorm.Model
|
||||
Preimage []byte
|
||||
PaymentHash []byte
|
||||
PaymentSecret []byte
|
||||
KeyPath string
|
||||
ShortChanId uint64
|
||||
AmountSat int64
|
||||
State InvoiceState
|
||||
Metadata string
|
||||
UsedAt *time.Time
|
||||
}
|
||||
return tx.AutoMigrate(&Invoice{}).Error
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
return tx.Table("invoices").DropColumn(gorm.ToColumnName("Metadata")).Error
|
||||
},
|
||||
},
|
||||
})
|
||||
return m.Migrate()
|
||||
}
|
||||
|
||||
func (d *DB) CreateInvoice(invoice *Invoice) error {
|
||||
// uint64 values with high bit set are not supported, we will
|
||||
// have to convert back and forth
|
||||
invoice.ShortChanId = invoice.ShortChanId & 0x7FFFFFFFFFFFFFFF
|
||||
res := d.db.Create(invoice)
|
||||
invoice.ShortChanId = invoice.ShortChanId | (1 << 63)
|
||||
return res.Error
|
||||
}
|
||||
|
||||
func (d *DB) SaveInvoice(invoice *Invoice) error {
|
||||
// uint64 values with high bit set are not supported, we will
|
||||
// have to convert back and forth
|
||||
invoice.ShortChanId = invoice.ShortChanId & 0x7FFFFFFFFFFFFFFF
|
||||
res := d.db.Save(invoice)
|
||||
invoice.ShortChanId = invoice.ShortChanId | (1 << 63)
|
||||
return res.Error
|
||||
}
|
||||
|
||||
func (d *DB) FindFirstUnusedInvoice() (*Invoice, error) {
|
||||
var invoice Invoice
|
||||
if res := d.db.Where(&Invoice{State: InvoiceStateRegistered}).First(&invoice); res.Error != nil {
|
||||
|
||||
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, res.Error
|
||||
}
|
||||
invoice.ShortChanId = invoice.ShortChanId | (1 << 63)
|
||||
return &invoice, nil
|
||||
}
|
||||
|
||||
func (d *DB) CountUnusedInvoices() (int, error) {
|
||||
var count int
|
||||
if res := d.db.Model(&Invoice{}).Where(&Invoice{State: InvoiceStateRegistered}).Count(&count); res.Error != nil {
|
||||
return 0, res.Error
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (d *DB) FindByPaymentHash(hash []byte) (*Invoice, error) {
|
||||
var invoice Invoice
|
||||
if res := d.db.Where(&Invoice{PaymentHash: hash}).First(&invoice); res.Error != nil {
|
||||
return nil, res.Error
|
||||
}
|
||||
invoice.ShortChanId = invoice.ShortChanId | (1 << 63)
|
||||
return &invoice, nil
|
||||
}
|
||||
|
||||
func (d *DB) Close() {
|
||||
err := d.db.Close()
|
||||
if err != nil {
|
||||
log.Printf("error closing the db: %v", err)
|
||||
}
|
||||
}
|
||||
95
libwallet/walletdb/walletdb_test.go
Normal file
95
libwallet/walletdb/walletdb_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package walletdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "libwallet")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db, err := Open(path.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
}
|
||||
|
||||
func TestInvoices(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "libwallet")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db, err := Open(path.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
shortChanId := uint64((math.MaxInt64 - 5) | (1 << 63))
|
||||
paymentHash := randomBytes(32)
|
||||
|
||||
err = db.CreateInvoice(&Invoice{
|
||||
Preimage: randomBytes(32),
|
||||
PaymentHash: paymentHash,
|
||||
PaymentSecret: randomBytes(32),
|
||||
KeyPath: "34/56",
|
||||
ShortChanId: shortChanId,
|
||||
State: InvoiceStateRegistered,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err := db.CountUnusedInvoices()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("expected to find 1 unused invoice, got %d", count)
|
||||
}
|
||||
|
||||
inv, err := db.FindByPaymentHash(paymentHash)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(inv.PaymentHash, paymentHash) {
|
||||
t.Fatal("expected invoice payment hash does not match")
|
||||
}
|
||||
|
||||
inv, err = db.FindFirstUnusedInvoice()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(inv.PaymentHash, paymentHash) {
|
||||
t.Fatal("expected invoice payment hash does not match")
|
||||
}
|
||||
if inv.ShortChanId != shortChanId {
|
||||
t.Fatal("expected invoice short channel id does not match")
|
||||
}
|
||||
|
||||
err = db.SaveInvoice(inv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if inv.ShortChanId != shortChanId {
|
||||
t.Fatal("expected invoice short channel id does not match")
|
||||
}
|
||||
}
|
||||
|
||||
func randomBytes(count int) []byte {
|
||||
buf := make([]byte, count)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
Reference in New Issue
Block a user