package db

import (
	"context"
	"encoding"
	"fmt"
	"reflect"

	"gorm.io/gorm/schema"
)

// textUnmarshalerType is the reflect.Type of the encoding.TextUnmarshaler
// interface, built once at package initialisation so every Implements check
// can reuse it.
//
// Technique taken from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

// isTextUnmarshaler reports whether the type held by rv implements
// encoding.TextUnmarshaler.
//
// The check is made against rv's own type, so a type whose UnmarshalText has
// a pointer receiver only matches when rv holds the pointer. The zero
// reflect.Value (nothing to inspect) is reported as false.
func isTextUnmarshaler(rv reflect.Value) bool {
	// Value.Type panics on the zero Value, so rule that case out first.
	if !rv.IsValid() {
		return false
	}
	return rv.Type().Implements(textUnmarshalerType)
}

// maybeInstantiatePtr replaces a nil pointer held in rv with a pointer to a
// freshly allocated zero value of its element type. Values that are not
// pointers, or pointers that are already non-nil, are left untouched.
//
// rv has to be settable whenever it holds a nil pointer, because the new
// pointer is stored through rv.Set.
func maybeInstantiatePtr(rv reflect.Value) {
	if rv.Kind() != reflect.Ptr || !rv.IsNil() {
		return
	}
	elemType := rv.Type().Elem()
	rv.Set(reflect.New(elemType))
}

// decodingError wraps err with the name of the target that could not be
// decoded. The cause is attached with %w, so it stays reachable through
// errors.Is, errors.As and errors.Unwrap.
func decodingError(name string, err error) error {
	const format = "error decoding to %s: %w"
	return fmt.Errorf(format, name, err)
}

// TextSerialiser implements the Serialiser interface for fields whose values
// are stored as text. Scan requires that a pointer to the field's type
// implements encoding.TextUnmarshaler; Value requires that the field value
// implements encoding.TextMarshaler.
//
// The type has no fields, so its zero value can be used directly.
type TextSerialiser struct{}

// Scan decodes a text database value into the field described by field on
// dst.
//
// dbValue may be nil, a string or a []byte. A nil dbValue is a no-op: nil is
// returned and dst is not modified. Any other dynamic type is rejected with
// an error. The text is handed to the UnmarshalText method of a freshly
// allocated value of the field's type, and the decoded value is then stored
// in the destination field, as a pointer when the field itself is a pointer
// and as a plain value otherwise.
//
// An error is returned when dbValue has an unsupported type, when the field's
// type does not implement encoding.TextUnmarshaler, or when UnmarshalText
// fails (the latter is wrapped together with the field name).
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
	// reflect.New yields a pointer to the field type. For a pointer field
	// that would be a pointer to a pointer, so step down one level and work
	// with the (settable, initially nil) inner pointer instead.
	target := reflect.New(field.FieldType)
	if target.Elem().Kind() == reflect.Ptr {
		target = target.Elem()
	}

	if dbValue == nil {
		return nil
	}

	var raw []byte
	switch data := dbValue.(type) {
	case string:
		raw = []byte(data)
	case []byte:
		raw = data
	default:
		return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
	}

	if !isTextUnmarshaler(target) {
		return fmt.Errorf("unsupported type: %T", target.Interface())
	}

	// Pointer fields start out nil; allocate the pointee before decoding.
	maybeInstantiatePtr(target)

	decoder := target.Interface().(encoding.TextUnmarshaler)
	if decodeErr := decoder.UnmarshalText(raw); decodeErr != nil {
		return decodingError(field.Name, decodeErr)
	}

	// target is always a pointer here: store it as is for pointer fields,
	// and store the value it points to for non-pointer fields.
	out := field.ReflectValueOf(ctx, dst)
	if out.Kind() != reflect.Ptr {
		out.Set(target.Elem())
		return nil
	}
	out.Set(target)
	return nil
}

// Value converts fieldValue into the representation handed to the database.
//
// fieldValue must implement encoding.TextMarshaler. The result of its
// MarshalText method is returned as a string; an error from MarshalText is
// returned unchanged. A nil pointer yields (nil, nil) without MarshalText
// being called. Every other input, including an untyped nil, produces an
// error that names the offending Go type.
//
// ctx, field and dst are not used.
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
	marshaler, ok := fieldValue.(encoding.TextMarshaler)
	if !ok {
		// %T prints the dynamic type of the rejected value.
		return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %T", fieldValue)
	}

	// An interface holding a typed nil pointer is itself non-nil, so a plain
	// comparison with nil cannot detect it; reflection is needed:
	// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
	if rv := reflect.ValueOf(marshaler); rv.Kind() == reflect.Ptr && rv.IsNil() {
		return nil, nil
	}

	text, err := marshaler.MarshalText()
	if err != nil {
		return nil, err
	}
	return string(text), nil
}