S3 Select: optimize output (#8238)

Queue output items and reuse them.
Remove the unneeded type system in sql and just use the Go type system.

In best case this is more than an order of magnitude speedup:

```
BenchmarkSelectAll_1M-12    	       1	1841049400 ns/op	274299728 B/op	 4198522 allocs/op
BenchmarkSelectAll_1M-12    	      14	  84833400 ns/op	169228346 B/op	 3146541 allocs/op
```
This commit is contained in:
Klaus Post 2019-09-16 17:26:27 -07:00 committed by kannappanr
parent 017456df63
commit c9b8bd8de2
13 changed files with 556 additions and 231 deletions

View File

@ -61,6 +61,33 @@ func (r *Record) Set(name string, value *sql.Value) error {
return nil return nil
} }
// Reset data in record.
func (r *Record) Reset() {
if len(r.columnNames) > 0 {
r.columnNames = r.columnNames[:0]
}
if len(r.csvRecord) > 0 {
r.csvRecord = r.csvRecord[:0]
}
for k := range r.nameIndexMap {
delete(r.nameIndexMap, k)
}
}
// CopyFrom will copy all records from the incoming and append them to the existing records.
// The source record must be of a similar type.
// Note that the lookup index is not copied.
func (r *Record) CopyFrom(record sql.Record) error {
other, ok := record.(*Record)
if !ok {
return fmt.Errorf("unexpected record type, expected %T, got %T", r, record)
}
//before := len(r.csvRecord)
r.columnNames = append(r.columnNames, other.columnNames...)
r.csvRecord = append(r.csvRecord, other.csvRecord...)
return nil
}
// WriteCSV - encodes to CSV data. // WriteCSV - encodes to CSV data.
func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error { func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error {
w := csv.NewWriter(writer) w := csv.NewWriter(writer)

View File

@ -51,6 +51,24 @@ func (r *Record) Get(name string) (*sql.Value, error) {
return nil, errors.New("not implemented here") return nil, errors.New("not implemented here")
} }
// Reset the record.
func (r *Record) Reset() {
if len(r.KVS) > 0 {
r.KVS = r.KVS[:0]
}
}
// CopyFrom will copy all records from the incoming and append them to the existing records.
// The source record must be of a similar type.
func (r *Record) CopyFrom(record sql.Record) error {
other, ok := record.(*Record)
if !ok {
return fmt.Errorf("unexpected record type, expected %T, got %T", r, record)
}
r.KVS = append(r.KVS, other.KVS...)
return nil
}
// Set - sets the value for a column name. // Set - sets the value for a column name.
func (r *Record) Set(name string, value *sql.Value) error { func (r *Record) Set(name string, value *sql.Value) error {
var v interface{} var v interface{}

View File

@ -417,7 +417,7 @@ func newMessageWriter(w http.ResponseWriter, getProgressFunc func() (bytesScanne
getProgressFunc: getProgressFunc, getProgressFunc: getProgressFunc,
payloadBuffer: make([]byte, bufLength), payloadBuffer: make([]byte, bufLength),
payloadCh: make(chan *bytes.Buffer), payloadCh: make(chan *bytes.Buffer, 1),
errCh: make(chan []byte), errCh: make(chan []byte),
doneCh: make(chan struct{}), doneCh: make(chan struct{}),

View File

@ -61,7 +61,8 @@ const (
var bufPool = sync.Pool{ var bufPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return new(bytes.Buffer) // make a buffer with a reasonable capacity.
return bytes.NewBuffer(make([]byte, 0, maxRecordSize))
}, },
} }
@ -341,7 +342,10 @@ func (s3Select *S3Select) marshal(buf *bytes.Buffer, record sql.Record) error {
if err != nil { if err != nil {
return err return err
} }
err = bufioWriter.Flush()
if err != nil {
return err
}
buf.Truncate(buf.Len() - 1) buf.Truncate(buf.Len() - 1)
buf.WriteString(s3Select.Output.CSVArgs.RecordDelimiter) buf.WriteString(s3Select.Output.CSVArgs.RecordDelimiter)
@ -370,25 +374,33 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
writer := newMessageWriter(w, getProgressFunc) writer := newMessageWriter(w, getProgressFunc)
var inputRecord sql.Record var inputRecord sql.Record
var outputRecord sql.Record var outputQueue []sql.Record
// Create queue based on the type.
if s3Select.statement.IsAggregated() {
outputQueue = make([]sql.Record, 0, 1)
} else {
outputQueue = make([]sql.Record, 0, 100)
}
var err error var err error
sendRecord := func() bool { sendRecord := func() bool {
if outputRecord == nil {
return true
}
buf := bufPool.Get().(*bytes.Buffer) buf := bufPool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
if err = s3Select.marshal(buf, outputRecord); err != nil { for _, outputRecord := range outputQueue {
bufPool.Put(buf) if outputRecord == nil {
return false continue
} }
before := buf.Len()
if buf.Len() > maxRecordSize { if err = s3Select.marshal(buf, outputRecord); err != nil {
writer.FinishWithError("OverMaxRecordSize", "The length of a record in the input or result is greater than maxCharsPerRecord of 1 MB.") bufPool.Put(buf)
bufPool.Put(buf) return false
return false }
if buf.Len()-before > maxRecordSize {
writer.FinishWithError("OverMaxRecordSize", "The length of a record in the input or result is greater than maxCharsPerRecord of 1 MB.")
bufPool.Put(buf)
return false
}
} }
if err = writer.SendRecord(buf); err != nil { if err = writer.SendRecord(buf); err != nil {
@ -397,7 +409,7 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
bufPool.Put(buf) bufPool.Put(buf)
return false return false
} }
outputQueue = outputQueue[:0]
return true return true
} }
@ -417,14 +429,15 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
} }
if s3Select.statement.IsAggregated() { if s3Select.statement.IsAggregated() {
outputRecord = s3Select.outputRecord() outputRecord := s3Select.outputRecord()
if err = s3Select.statement.AggregateResult(outputRecord); err != nil { if err = s3Select.statement.AggregateResult(outputRecord); err != nil {
break break
} }
outputQueue = append(outputQueue, outputRecord)
}
if !sendRecord() { if !sendRecord() {
break break
}
} }
if err = writer.Finish(s3Select.getProgress()); err != nil { if err = writer.Finish(s3Select.getProgress()); err != nil {
@ -443,10 +456,33 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
break break
} }
} else { } else {
outputRecord = s3Select.outputRecord() var outputRecord sql.Record
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil { // We will attempt to reuse the records in the table.
// The type of these should not change.
// The queue should always have at least one entry left for this to work.
outputQueue = outputQueue[:len(outputQueue)+1]
if t := outputQueue[len(outputQueue)-1]; t != nil {
// If the output record is already set, we reuse it.
outputRecord = t
outputRecord.Reset()
} else {
// Create new one
outputRecord = s3Select.outputRecord()
outputQueue[len(outputQueue)-1] = outputRecord
}
if err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
break break
} }
if outputRecord == nil {
// This should not be written.
// Remove it from the queue.
outputQueue = outputQueue[:len(outputQueue)-1]
continue
}
if len(outputQueue) < cap(outputQueue) {
continue
}
if !sendRecord() { if !sendRecord() {
break break

View File

@ -103,6 +103,7 @@ func benchmarkSelect(b *testing.B, count int, query string) {
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
b.SetBytes(int64(count))
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
@ -147,6 +148,30 @@ func BenchmarkSelectAll_10M(b *testing.B) {
benchmarkSelectAll(b, 10*humanize.MiByte) benchmarkSelectAll(b, 10*humanize.MiByte)
} }
func benchmarkSingleCol(b *testing.B, count int) {
benchmarkSelect(b, count, "select id from S3Object")
}
// BenchmarkSingleRow_100K - benchmark SELECT column function with 100k records.
func BenchmarkSingleCol_100K(b *testing.B) {
benchmarkSingleCol(b, 1e5)
}
// BenchmarkSelectAll_1M - benchmark * function with 1m records.
func BenchmarkSingleCol_1M(b *testing.B) {
benchmarkSingleCol(b, 1e6)
}
// BenchmarkSelectAll_2M - benchmark * function with 2m records.
func BenchmarkSingleCol_2M(b *testing.B) {
benchmarkSingleCol(b, 2e6)
}
// BenchmarkSelectAll_10M - benchmark * function with 10m records.
func BenchmarkSingleCol_10M(b *testing.B) {
benchmarkSingleCol(b, 1e7)
}
func benchmarkAggregateCount(b *testing.B, count int) { func benchmarkAggregateCount(b *testing.B, count int) {
benchmarkSelect(b, count, "select count(*) from S3Object") benchmarkSelect(b, count, "select count(*) from S3Object")
} }

View File

@ -238,12 +238,11 @@ func (e *In) evalInNode(r Record, arg *Value) (*Value, error) {
// FIXME: type inference? // FIXME: type inference?
// Types must match. // Types must match.
if arg.vType != eltVal.vType { if !arg.SameTypeAs(*eltVal) {
// match failed. // match failed.
continue continue
} }
if arg.Equals(*eltVal) {
if arg.value == eltVal.value {
result = true result = true
break break
} }
@ -318,9 +317,11 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
func (e *JSONPath) evalNode(r Record) (*Value, error) { func (e *JSONPath) evalNode(r Record) (*Value, error) {
// Strip the table name from the keypath. // Strip the table name from the keypath.
keypath := e.String() keypath := e.String()
ps := strings.SplitN(keypath, ".", 2) if strings.Contains(keypath, ".") {
if len(ps) == 2 { ps := strings.SplitN(keypath, ".", 2)
keypath = ps[1] if len(ps) == 2 {
keypath = ps[1]
}
} }
objFmt, rawVal := r.Raw() objFmt, rawVal := r.Raw()
switch objFmt { switch objFmt {

View File

@ -173,7 +173,7 @@ func nullif(v1, v2 *Value) (res *Value, err error) {
return v1, nil return v1, nil
} }
if v1.vType != v2.vType { if v1.SameTypeAs(*v2) {
return v1, nil return v1, nil
} }
@ -456,27 +456,24 @@ func intCast(v *Value) (int64, error) {
return 0, false return 0, false
} }
switch v.vType { switch x := v.value.(type) {
case typeFloat: case float64:
// Truncate fractional part // Truncate fractional part
return int64(v.value.(float64)), nil return int64(x), nil
case typeInt: case int64:
return v.value.(int64), nil return x, nil
case typeString: case string:
// Parse as number, truncate floating point if // Parse as number, truncate floating point if
// needed. // needed.
s, _ := v.ToString() res, ok := strToInt(x)
res, ok := strToInt(s)
if !ok { if !ok {
return 0, errCastFailure("could not parse as int") return 0, errCastFailure("could not parse as int")
} }
return res, nil return res, nil
case typeBytes: case []byte:
// Parse as number, truncate floating point if // Parse as number, truncate floating point if
// needed. // needed.
b, _ := v.ToBytes() res, ok := strToInt(string(x))
s := string(b)
res, ok := strToInt(s)
if !ok { if !ok {
return 0, errCastFailure("could not parse as int") return 0, errCastFailure("could not parse as int")
} }
@ -488,20 +485,19 @@ func intCast(v *Value) (int64, error) {
} }
func floatCast(v *Value) (float64, error) { func floatCast(v *Value) (float64, error) {
switch v.vType { switch x := v.value.(type) {
case typeFloat: case float64:
return v.value.(float64), nil return x, nil
case typeInt: case int:
return float64(v.value.(int64)), nil return float64(x), nil
case typeString: case string:
f, err := strconv.ParseFloat(v.value.(string), 64) f, err := strconv.ParseFloat(x, 64)
if err != nil { if err != nil {
return 0, errCastFailure("could not parse as float") return 0, errCastFailure("could not parse as float")
} }
return f, nil return f, nil
case typeBytes: case []byte:
b, _ := v.ToBytes() f, err := strconv.ParseFloat(string(x), 64)
f, err := strconv.ParseFloat(string(b), 64)
if err != nil { if err != nil {
return 0, errCastFailure("could not parse as float") return 0, errCastFailure("could not parse as float")
} }
@ -512,41 +508,33 @@ func floatCast(v *Value) (float64, error) {
} }
func stringCast(v *Value) (string, error) { func stringCast(v *Value) (string, error) {
switch v.vType { switch x := v.value.(type) {
case typeFloat: case float64:
f, _ := v.ToFloat() return fmt.Sprintf("%v", x), nil
return fmt.Sprintf("%v", f), nil case int64:
case typeInt: return fmt.Sprintf("%v", x), nil
i, _ := v.ToInt() case string:
return fmt.Sprintf("%v", i), nil return x, nil
case typeString: case []byte:
s, _ := v.ToString() return string(x), nil
return s, nil case bool:
case typeBytes: return fmt.Sprintf("%v", x), nil
b, _ := v.ToBytes() case nil:
return string(b), nil
case typeBool:
b, _ := v.ToBool()
return fmt.Sprintf("%v", b), nil
case typeNull:
// FIXME: verify this case is correct // FIXME: verify this case is correct
return fmt.Sprintf("NULL"), nil return "NULL", nil
} }
// This does not happen // This does not happen
return "", nil return "", errCastFailure(fmt.Sprintf("cannot cast %v to string type", v.GetTypeString()))
} }
func timestampCast(v *Value) (t time.Time, _ error) { func timestampCast(v *Value) (t time.Time, _ error) {
switch v.vType { switch x := v.value.(type) {
case typeString: case string:
s, _ := v.ToString() return parseSQLTimestamp(x)
return parseSQLTimestamp(s) case []byte:
case typeBytes: return parseSQLTimestamp(string(x))
b, _ := v.ToBytes() case time.Time:
return parseSQLTimestamp(string(b)) return x, nil
case typeTimestamp:
t, _ = v.ToTimestamp()
return t, nil
default: default:
return t, errCastFailure(fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString())) return t, errCastFailure(fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString()))
} }
@ -563,16 +551,13 @@ func boolCast(v *Value) (b bool, _ error) {
return false, errCastFailure("cannot cast to Bool") return false, errCastFailure("cannot cast to Bool")
} }
} }
switch v.vType { switch x := v.value.(type) {
case typeBool: case bool:
b, _ := v.ToBool() return x, nil
return b, nil case string:
case typeString: return sToB(strings.ToLower(x))
s, _ := v.ToString() case []byte:
return sToB(strings.ToLower(s)) return sToB(strings.ToLower(string(x)))
case typeBytes:
b, _ := v.ToBytes()
return sToB(strings.ToLower(string(b)))
default: default:
return false, errCastFailure("cannot cast %v to Bool") return false, errCastFailure("cannot cast %v to Bool")
} }

View File

@ -99,10 +99,14 @@ type JSONPathElement struct {
ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form
} }
// JSONPath represents a keypath // JSONPath represents a keypath.
// Instances should be treated idempotent and not change once created.
type JSONPath struct { type JSONPath struct {
BaseKey *Identifier `parser:" @@"` BaseKey *Identifier `parser:" @@"`
PathExpr []*JSONPathElement `parser:"(@@)*"` PathExpr []*JSONPathElement `parser:"(@@)*"`
// Cached values:
pathString string
} }
// AliasedExpression is an expression that can be optionally named // AliasedExpression is an expression that can be optionally named

View File

@ -43,6 +43,11 @@ type Record interface {
WriteCSV(writer io.Writer, fieldDelimiter rune) error WriteCSV(writer io.Writer, fieldDelimiter rune) error
WriteJSON(writer io.Writer) error WriteJSON(writer io.Writer) error
// CopyFrom will copy all records from the incoming and append them to the existing records.
// The source record must be of a similar type as destination.
CopyFrom(src Record) error
Reset()
// Returns underlying representation // Returns underlying representation
Raw() (SelectObjectFormat, interface{}) Raw() (SelectObjectFormat, interface{})

View File

@ -219,11 +219,11 @@ func (e *SelectStatement) AggregateRow(input Record) error {
// Eval - evaluates the Select statement for the given record. It // Eval - evaluates the Select statement for the given record. It
// applies only to non-aggregation queries. // applies only to non-aggregation queries.
func (e *SelectStatement) Eval(input, output Record) (Record, error) { func (e *SelectStatement) Eval(input, output Record) error {
ok, err := e.isPassingWhereClause(input) ok, err := e.isPassingWhereClause(input)
if err != nil || !ok { if err != nil || !ok {
// Either error or row did not pass where clause // Either error or row did not pass where clause
return nil, err return err
} }
if e.selectAST.Expression.All { if e.selectAST.Expression.All {
@ -234,14 +234,13 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) {
if e.limitValue > -1 { if e.limitValue > -1 {
e.outputCount++ e.outputCount++
} }
return output.CopyFrom(input)
return input, nil
} }
for i, expr := range e.selectAST.Expression.Expressions { for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input) v, err := expr.evalNode(input)
if err != nil { if err != nil {
return nil, err return err
} }
// Pick output column names // Pick output column names
@ -259,7 +258,7 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) {
e.outputCount++ e.outputCount++
} }
return output, nil return nil
} }
// LimitReached - returns true if the number of records output has // LimitReached - returns true if the number of records output has

View File

@ -25,12 +25,15 @@ import (
// String - returns the JSONPath representation // String - returns the JSONPath representation
func (e *JSONPath) String() string { func (e *JSONPath) String() string {
parts := make([]string, len(e.PathExpr)+1) if len(e.pathString) == 0 {
parts[0] = e.BaseKey.String() parts := make([]string, len(e.PathExpr)+1)
for i, pe := range e.PathExpr { parts[0] = e.BaseKey.String()
parts[i+1] = pe.String() for i, pe := range e.PathExpr {
parts[i+1] = pe.String()
}
e.pathString = strings.Join(parts, "")
} }
return strings.Join(parts, "") return e.pathString
} }
func (e *JSONPathElement) String() string { func (e *JSONPathElement) String() string {
@ -94,9 +97,12 @@ func getLastKeypathComponent(e *Expression) (string, bool) {
if n > 0 && jpath.PathExpr[n-1].Key == nil { if n > 0 && jpath.PathExpr[n-1].Key == nil {
return "", false return "", false
} }
ps := jpath.String()
ps := strings.Split(jpath.String(), ".") if idx := strings.LastIndex(ps, "."); idx >= 0 {
return ps[len(ps)-1], true // Get last part of path string.
ps = ps[idx+1:]
}
return ps, true
} }
// HasKeypath returns if the from clause has a key path - // HasKeypath returns if the from clause has a key path -

View File

@ -20,6 +20,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"math" "math"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -34,28 +35,6 @@ var (
errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments") errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments")
) )
// vType represents the concrete type of a `Value`
type vType int
// Valid values for Type
const (
typeNull vType = iota + 1
typeBool
typeString
// 64-bit signed integer
typeInt
// 64-bit floating point
typeFloat
// timestamp type
typeTimestamp
// This type refers to untyped values, e.g. as read from CSV
typeBytes
)
// Value represents a value of restricted type reduced from an // Value represents a value of restricted type reduced from an
// expression represented by an ASTNode. Only one of the fields is // expression represented by an ASTNode. Only one of the fields is
// non-nil. // non-nil.
@ -65,43 +44,42 @@ const (
// used. // used.
type Value struct { type Value struct {
value interface{} value interface{}
vType vType
} }
// GetTypeString returns a string representation for vType // GetTypeString returns a string representation for vType
func (v *Value) GetTypeString() string { func (v Value) GetTypeString() string {
switch v.vType { switch v.value.(type) {
case typeNull: case nil:
return "NULL" return "NULL"
case typeBool: case bool:
return "BOOL" return "BOOL"
case typeString: case string:
return "STRING" return "STRING"
case typeInt: case int64:
return "INT" return "INT"
case typeFloat: case float64:
return "FLOAT" return "FLOAT"
case typeTimestamp: case time.Time:
return "TIMESTAMP" return "TIMESTAMP"
case typeBytes: case []byte:
return "BYTES" return "BYTES"
} }
return "--" return "--"
} }
// Repr returns a string representation of value. // Repr returns a string representation of value.
func (v *Value) Repr() string { func (v Value) Repr() string {
switch v.vType { switch x := v.value.(type) {
case typeNull: case nil:
return ":NULL" return ":NULL"
case typeBool, typeInt, typeFloat: case bool, int64, float64:
return fmt.Sprintf("%v:%s", v.value, v.GetTypeString()) return fmt.Sprintf("%v:%s", v.value, v.GetTypeString())
case typeTimestamp: case time.Time:
return fmt.Sprintf("%s:TIMESTAMP", v.value.(time.Time)) return fmt.Sprintf("%s:TIMESTAMP", x)
case typeString: case string:
return fmt.Sprintf("\"%s\":%s", v.value.(string), v.GetTypeString()) return fmt.Sprintf("\"%s\":%s", x, v.GetTypeString())
case typeBytes: case []byte:
return fmt.Sprintf("\"%s\":BYTES", string(v.value.([]byte))) return fmt.Sprintf("\"%s\":BYTES", string(x))
default: default:
return fmt.Sprintf("%v:INVALID", v.value) return fmt.Sprintf("%v:INVALID", v.value)
} }
@ -109,154 +87,174 @@ func (v *Value) Repr() string {
// FromFloat creates a Value from a number // FromFloat creates a Value from a number
func FromFloat(f float64) *Value { func FromFloat(f float64) *Value {
return &Value{value: f, vType: typeFloat} return &Value{value: f}
} }
// FromInt creates a Value from an int // FromInt creates a Value from an int
func FromInt(f int64) *Value { func FromInt(f int64) *Value {
return &Value{value: f, vType: typeInt} return &Value{value: f}
} }
// FromString creates a Value from a string // FromString creates a Value from a string
func FromString(str string) *Value { func FromString(str string) *Value {
return &Value{value: str, vType: typeString} return &Value{value: str}
} }
// FromBool creates a Value from a bool // FromBool creates a Value from a bool
func FromBool(b bool) *Value { func FromBool(b bool) *Value {
return &Value{value: b, vType: typeBool} return &Value{value: b}
} }
// FromTimestamp creates a Value from a timestamp // FromTimestamp creates a Value from a timestamp
func FromTimestamp(t time.Time) *Value { func FromTimestamp(t time.Time) *Value {
return &Value{value: t, vType: typeTimestamp} return &Value{value: t}
} }
// FromNull creates a Value with Null value // FromNull creates a Value with Null value
func FromNull() *Value { func FromNull() *Value {
return &Value{vType: typeNull} return &Value{value: nil}
} }
// FromBytes creates a Value from a []byte // FromBytes creates a Value from a []byte
func FromBytes(b []byte) *Value { func FromBytes(b []byte) *Value {
return &Value{value: b, vType: typeBytes} return &Value{value: b}
} }
// ToFloat works for int and float values // ToFloat works for int and float values
func (v *Value) ToFloat() (val float64, ok bool) { func (v Value) ToFloat() (val float64, ok bool) {
switch v.vType { switch x := v.value.(type) {
case typeFloat: case float64:
val, ok = v.value.(float64) return x, true
case typeInt: case int64:
var i int64 return float64(x), true
i, ok = v.value.(int64)
val = float64(i)
default:
} }
return return 0, false
} }
// ToInt converts value to int. // ToInt converts value to int.
func (v *Value) ToInt() (val int64, ok bool) { func (v Value) ToInt() (val int64, ok bool) {
switch v.vType { val, ok = v.value.(int64)
case typeInt:
val, ok = v.value.(int64)
default:
}
return return
} }
// ToString converts value to string. // ToString converts value to string.
func (v *Value) ToString() (val string, ok bool) { func (v Value) ToString() (val string, ok bool) {
switch v.vType { val, ok = v.value.(string)
case typeString:
val, ok = v.value.(string)
default:
}
return return
} }
// Equals returns whether the values strictly match.
// Both type and value must match.
func (v Value) Equals(b Value) (ok bool) {
if !v.SameTypeAs(b) {
return false
}
return reflect.DeepEqual(v.value, b.value)
}
// SameTypeAs return whether the two types are strictly the same.
func (v Value) SameTypeAs(b Value) (ok bool) {
switch v.value.(type) {
case bool:
_, ok = b.value.(bool)
case string:
_, ok = b.value.(string)
case int64:
_, ok = b.value.(int64)
case float64:
_, ok = b.value.(float64)
case time.Time:
_, ok = b.value.(time.Time)
case []byte:
_, ok = b.value.([]byte)
default:
ok = reflect.TypeOf(v.value) == reflect.TypeOf(b.value)
}
return ok
}
// ToBool returns the bool value; second return value refers to if the bool // ToBool returns the bool value; second return value refers to if the bool
// conversion succeeded. // conversion succeeded.
func (v *Value) ToBool() (val bool, ok bool) { func (v Value) ToBool() (val bool, ok bool) {
switch v.vType { val, ok = v.value.(bool)
case typeBool: return
return v.value.(bool), true
}
return false, false
} }
// ToTimestamp returns the timestamp value if present. // ToTimestamp returns the timestamp value if present.
func (v *Value) ToTimestamp() (t time.Time, ok bool) { func (v Value) ToTimestamp() (t time.Time, ok bool) {
switch v.vType { t, ok = v.value.(time.Time)
case typeTimestamp: return
return v.value.(time.Time), true
}
return t, false
} }
// ToBytes converts Value to byte-slice. // ToBytes converts Value to byte-slice.
func (v *Value) ToBytes() ([]byte, bool) { func (v Value) ToBytes() (val []byte, ok bool) {
switch v.vType { val, ok = v.value.([]byte)
case typeBytes: return
return v.value.([]byte), true
}
return nil, false
} }
// IsNull - checks if value is missing. // IsNull - checks if value is missing.
func (v *Value) IsNull() bool { func (v Value) IsNull() bool {
return v.vType == typeNull switch v.value.(type) {
case nil:
return true
}
return false
} }
func (v *Value) isNumeric() bool { func (v Value) isNumeric() bool {
return v.vType == typeInt || v.vType == typeFloat switch v.value.(type) {
case int64, float64:
return true
}
return false
} }
// setters used internally to mutate values // setters used internally to mutate values
func (v *Value) setInt(i int64) { func (v *Value) setInt(i int64) {
v.vType = typeInt
v.value = i v.value = i
} }
func (v *Value) setFloat(f float64) { func (v *Value) setFloat(f float64) {
v.vType = typeFloat
v.value = f v.value = f
} }
func (v *Value) setString(s string) { func (v *Value) setString(s string) {
v.vType = typeString
v.value = s v.value = s
} }
func (v *Value) setBool(b bool) { func (v *Value) setBool(b bool) {
v.vType = typeBool
v.value = b v.value = b
} }
func (v *Value) setTimestamp(t time.Time) { func (v *Value) setTimestamp(t time.Time) {
v.vType = typeTimestamp
v.value = t v.value = t
} }
func (v Value) String() string {
return fmt.Sprintf("%#v", v.value)
}
// CSVString - convert to string for CSV serialization // CSVString - convert to string for CSV serialization
func (v *Value) CSVString() string { func (v Value) CSVString() string {
switch v.vType { switch x := v.value.(type) {
case typeNull: case nil:
return "" return ""
case typeBool: case bool:
return fmt.Sprintf("%v", v.value.(bool)) if x {
case typeString: return "true"
return v.value.(string) }
case typeInt: return "false"
return fmt.Sprintf("%v", v.value.(int64)) case string:
case typeFloat: return x
return fmt.Sprintf("%v", v.value.(float64)) case int64:
case typeTimestamp: return strconv.FormatInt(x, 10)
return FormatSQLTimestamp(v.value.(time.Time)) case float64:
case typeBytes: return strconv.FormatFloat(x, 'g', -1, 64)
return fmt.Sprintf("%v", string(v.value.([]byte))) case time.Time:
return FormatSQLTimestamp(x)
case []byte:
return string(x)
default: default:
return "CSV serialization not implemented for this type" return "CSV serialization not implemented for this type"
} }
@ -273,11 +271,11 @@ func floatToValue(f float64) *Value {
// negate negates a numeric value // negate negates a numeric value
func (v *Value) negate() { func (v *Value) negate() {
switch v.vType { switch x := v.value.(type) {
case typeFloat: case float64:
v.value = -(v.value.(float64)) v.value = -x
case typeInt: case int64:
v.value = -(v.value.(int64)) v.value = -x
} }
} }
@ -411,25 +409,25 @@ func inferTypesForCmp(a *Value, b *Value) error {
case okA && !okB: case okA && !okB:
// Here a has `a` is untyped, but `b` has a fixed // Here a has `a` is untyped, but `b` has a fixed
// type. // type.
switch b.vType { switch b.value.(type) {
case typeString: case string:
s := a.bytesToString() s := a.bytesToString()
a.setString(s) a.setString(s)
case typeInt, typeFloat: case int64, float64:
if iA, ok := a.bytesToInt(); ok { if iA, ok := a.bytesToInt(); ok {
a.setInt(iA) a.setInt(iA)
} else if fA, ok := a.bytesToFloat(); ok { } else if fA, ok := a.bytesToFloat(); ok {
a.setFloat(fA) a.setFloat(fA)
} else { } else {
return fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte))) return fmt.Errorf("Could not convert %s to a number", a.String())
} }
case typeBool: case bool:
if bA, ok := a.bytesToBool(); ok { if bA, ok := a.bytesToBool(); ok {
a.setBool(bA) a.setBool(bA)
} else { } else {
return fmt.Errorf("Could not convert %s to a boolean", string(a.value.([]byte))) return fmt.Errorf("Could not convert %s to a boolean", a.String())
} }
default: default:

View File

@ -0,0 +1,221 @@
/*
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"math"
"testing"
"time"
)
// valueBuilders contains one constructor for each value type.
// Values should match if type is the same.
var valueBuilders = []func() *Value{
func() *Value {
return FromNull()
},
func() *Value {
return FromBool(true)
},
func() *Value {
return FromBytes([]byte("byte contents"))
},
func() *Value {
return FromFloat(math.Pi)
},
func() *Value {
return FromInt(0x1337)
},
func() *Value {
t, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
if err != nil {
panic(err)
}
return FromTimestamp(t)
},
func() *Value {
return FromString("string contents")
},
}
// altValueBuilders contains one constructor for each value type.
// Values are zero values and should NOT match the values in valueBuilders, except Null type.
var altValueBuilders = []func() *Value{
func() *Value {
return FromNull()
},
func() *Value {
return FromBool(false)
},
func() *Value {
return FromBytes(nil)
},
func() *Value {
return FromFloat(0)
},
func() *Value {
return FromInt(0)
},
func() *Value {
return FromTimestamp(time.Time{})
},
func() *Value {
return FromString("")
},
}
func TestValue_SameTypeAs(t *testing.T) {
type fields struct {
a, b Value
}
type test struct {
name string
fields fields
wantOk bool
}
var tests []test
for i := range valueBuilders {
a := valueBuilders[i]()
for j := range valueBuilders {
b := valueBuilders[j]()
tests = append(tests, test{
name: fmt.Sprint(a.GetTypeString(), "==", b.GetTypeString()),
fields: fields{
a: *a, b: *b,
},
wantOk: i == j,
})
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotOk := tt.fields.a.SameTypeAs(tt.fields.b); gotOk != tt.wantOk {
t.Errorf("SameTypeAs() = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestValue_Equals(t *testing.T) {
type fields struct {
a, b Value
}
type test struct {
name string
fields fields
wantOk bool
}
var tests []test
for i := range valueBuilders {
a := valueBuilders[i]()
for j := range valueBuilders {
b := valueBuilders[j]()
tests = append(tests, test{
name: fmt.Sprint(a.GetTypeString(), "==", b.GetTypeString()),
fields: fields{
a: *a, b: *b,
},
wantOk: i == j,
})
}
}
for i := range valueBuilders {
a := valueBuilders[i]()
for j := range altValueBuilders {
b := altValueBuilders[j]()
tests = append(tests, test{
name: fmt.Sprint(a.GetTypeString(), "!=", b.GetTypeString()),
fields: fields{
a: *a, b: *b,
},
// Only Null == Null
wantOk: a.IsNull() && b.IsNull() && i == 0 && j == 0,
})
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotOk := tt.fields.a.Equals(tt.fields.b); gotOk != tt.wantOk {
t.Errorf("Equals() = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestValue_CSVString(t *testing.T) {
type fields struct {
value interface{}
}
type test struct {
name string
want string
wantAlt string
}
tests := []test{
{
name: valueBuilders[0]().String(),
want: "",
wantAlt: "",
},
{
name: valueBuilders[1]().String(),
want: "true",
wantAlt: "false",
},
{
name: valueBuilders[2]().String(),
want: "byte contents",
wantAlt: "",
},
{
name: valueBuilders[3]().String(),
want: "3.141592653589793",
wantAlt: "0",
},
{
name: valueBuilders[4]().String(),
want: "4919",
wantAlt: "0",
},
{
name: valueBuilders[5]().String(),
want: "2006-01-02T15:04:05Z",
wantAlt: "0001T",
},
{
name: valueBuilders[6]().String(),
want: "string contents",
wantAlt: "",
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := valueBuilders[i]()
vAlt := altValueBuilders[i]()
if got := v.CSVString(); got != tt.want {
t.Errorf("CSVString() = %v, want %v", got, tt.want)
}
if got := vAlt.CSVString(); got != tt.wantAlt {
t.Errorf("CSVString() = %v, want %v", got, tt.wantAlt)
}
})
}
}