minio/pkg/s3select/sql/funceval.go
Klaus Post c9b8bd8de2 S3 Select: optimize output (#8238)
Queue output items and reuse them.
Remove the unneeded type system in sql and just use the Go type system.

In best case this is more than an order of magnitude speedup:

```
BenchmarkSelectAll_1M-12    	       1	1841049400 ns/op	274299728 B/op	 4198522 allocs/op
BenchmarkSelectAll_1M-12    	      14	  84833400 ns/op	169228346 B/op	 3146541 allocs/op
```
2019-09-17 05:56:27 +05:30

565 lines
12 KiB
Go

/*
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
)
// FuncName identifies a SQL function by its name.
type FuncName string

// Names of the non-aggregate SQL functions known to the evaluator, grouped
// by category. All names are spelled in upper case.
const (
	// Conditional functions.
	sqlFnCoalesce = FuncName("COALESCE")
	sqlFnNullIf   = FuncName("NULLIF")

	// Conversion functions.
	sqlFnCast = FuncName("CAST")

	// Date and time functions.
	sqlFnDateAdd     = FuncName("DATE_ADD")
	sqlFnDateDiff    = FuncName("DATE_DIFF")
	sqlFnExtract     = FuncName("EXTRACT")
	sqlFnToString    = FuncName("TO_STRING")
	sqlFnToTimestamp = FuncName("TO_TIMESTAMP")
	sqlFnUTCNow      = FuncName("UTCNOW")

	// String functions.
	sqlFnCharLength      = FuncName("CHAR_LENGTH")
	sqlFnCharacterLength = FuncName("CHARACTER_LENGTH")
	sqlFnLower           = FuncName("LOWER")
	sqlFnSubstring       = FuncName("SUBSTRING")
	sqlFnTrim            = FuncName("TRIM")
	sqlFnUpper           = FuncName("UPPER")
)
// Sentinel errors returned by the function evaluators in this file.
var (
	// errNonTimestampArg is returned by EXTRACT when its operand is not a
	// timestamp.
	errNonTimestampArg = errors.New("Expected a timestamp argument")

	// errNonStringTrimArg is returned by TRIM when either operand is not a
	// string.
	errNonStringTrimArg = errors.New("TRIM() received a non-string argument")

	// errUnimplementedCast is returned by CAST for target types that have no
	// conversion yet.
	errUnimplementedCast = errors.New("This cast not yet implemented")
)
// getFunctionName reports which SQL function this node represents, based on
// the first populated field of the expression. Simple functions report their
// own name upper-cased; a node with no populated field yields "".
func (e *FuncExpr) getFunctionName() FuncName {
	if e.SFunc != nil {
		return FuncName(strings.ToUpper(e.SFunc.FunctionName))
	}
	if e.Count != nil {
		return FuncName(aggFnCount)
	}
	if e.Cast != nil {
		return sqlFnCast
	}
	if e.Substring != nil {
		return sqlFnSubstring
	}
	if e.Extract != nil {
		return sqlFnExtract
	}
	if e.Trim != nil {
		return sqlFnTrim
	}
	if e.DateAdd != nil {
		return sqlFnDateAdd
	}
	if e.DateDiff != nil {
		return sqlFnDateDiff
	}
	return ""
}
// evalSQLFnNode evaluates a non-aggregate SQL function call against record
// r. It assumes that the FuncExpr is not an aggregation function.
//
// Functions that have phrase arguments (CAST, SUBSTRING, EXTRACT, TRIM,
// DATE_ADD, DATE_DIFF) are dispatched to their dedicated handlers. All other
// functions are "simple": their arguments are evaluated first and then
// passed to the implementation.
//
// A node without a simple-function payload yields errNotImplemented rather
// than a nil-pointer panic, and a simple function called with too few
// arguments yields an error rather than an index-out-of-range panic.
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
	fnName := e.getFunctionName()

	// Handle functions that have phrase arguments.
	switch fnName {
	case sqlFnCast:
		return e.Cast.Expr.castTo(r, strings.ToUpper(e.Cast.CastType))
	case sqlFnSubstring:
		return handleSQLSubstring(r, e.Substring)
	case sqlFnExtract:
		return handleSQLExtract(r, e.Extract)
	case sqlFnTrim:
		return handleSQLTrim(r, e.Trim)
	case sqlFnDateAdd:
		return handleDateAdd(r, e.DateAdd)
	case sqlFnDateDiff:
		return handleDateDiff(r, e.DateDiff)
	}

	// Everything below reads e.SFunc; without it there is nothing this
	// function knows how to evaluate.
	if e.SFunc == nil {
		return nil, errNotImplemented
	}

	// For all simple argument functions, we evaluate the arguments here.
	argVals := make([]*Value, len(e.SFunc.ArgsList))
	for i, arg := range e.SFunc.ArgsList {
		argVals[i], err = arg.evalNode(r)
		if err != nil {
			return nil, err
		}
	}

	// requireArgs guards the positional accesses below. Extra arguments
	// are tolerated, as before.
	requireArgs := func(n int) error {
		if len(argVals) < n {
			return fmt.Errorf("%s() expects at least %d argument(s), got %d", fnName, n, len(argVals))
		}
		return nil
	}

	switch fnName {
	case sqlFnCoalesce:
		return coalesce(argVals)
	case sqlFnNullIf:
		if err = requireArgs(2); err != nil {
			return nil, err
		}
		return nullif(argVals[0], argVals[1])
	case sqlFnCharLength, sqlFnCharacterLength:
		if err = requireArgs(1); err != nil {
			return nil, err
		}
		return charlen(argVals[0])
	case sqlFnLower:
		if err = requireArgs(1); err != nil {
			return nil, err
		}
		return lowerCase(argVals[0])
	case sqlFnUpper:
		if err = requireArgs(1); err != nil {
			return nil, err
		}
		return upperCase(argVals[0])
	case sqlFnUTCNow:
		return handleUTCNow()
	case sqlFnToString, sqlFnToTimestamp:
		// TODO: implement
		return nil, errNotImplemented
	default:
		return nil, errNotImplemented
	}
}
// coalesce returns the first argument that is not NULL. When every argument
// is NULL, or there are no arguments, it returns a NULL value.
func coalesce(args []*Value) (res *Value, err error) {
	for i := range args {
		if candidate := args[i]; !candidate.IsNull() {
			return candidate, nil
		}
	}
	return FromNull(), nil
}
// nullif implements NULLIF(v1, v2): it returns NULL when v1 and v2 compare
// equal and v1 otherwise.
//
// The previous implementation guarded the comparison with
// `atleastOneNumeric || !bothNumeric`, which is true for every input, so the
// function always returned v1 and could never produce NULL. The guards below
// only short-circuit when the operands cannot be equal.
func nullif(v1, v2 *Value) (res *Value, err error) {
	// A NULL operand never equals anything, so v1 passes through.
	if v1.IsNull() || v2.IsNull() {
		return v1, nil
	}

	if err = inferTypesForCmp(v1, v2); err != nil {
		return nil, err
	}

	num1, num2 := v1.isNumeric(), v2.isNumeric()
	switch {
	case num1 != num2:
		// Exactly one operand is numeric: the values cannot be equal.
		return v1, nil
	case !num1 && !v1.SameTypeAs(*v2):
		// Two non-numeric values of different types cannot be equal.
		return v1, nil
	}

	// Either both operands are numeric or both share one non-numeric
	// type.
	// NOTE(review): assumes compareOp accepts a mixed int/float pair —
	// confirm against its implementation.
	equal, err := v1.compareOp(opEq, v2)
	if err != nil {
		return nil, err
	}
	if equal {
		return FromNull(), nil
	}
	return v1, nil
}
// charlen implements CHAR_LENGTH / CHARACTER_LENGTH: it returns the number
// of Unicode code points in the string argument, or an argument-type error
// when v cannot be read as a string.
func charlen(v *Value) (*Value, error) {
	inferTypeAsString(v)
	str, isString := v.ToString()
	if isString {
		// Ranging over a string steps one code point at a time.
		var runes int64
		for range str {
			runes++
		}
		return FromInt(runes), nil
	}
	return nil, errIncorrectSQLFunctionArgumentType(
		fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength),
	)
}
// lowerCase implements LOWER: it returns the string argument converted to
// lower case, or an argument-type error when v cannot be read as a string.
func lowerCase(v *Value) (*Value, error) {
	inferTypeAsString(v)
	if str, isString := v.ToString(); isString {
		return FromString(strings.ToLower(str)), nil
	}
	cause := fmt.Errorf("%s expects a string argument", sqlFnLower)
	return nil, errIncorrectSQLFunctionArgumentType(cause)
}
// upperCase implements UPPER: it returns the string argument converted to
// upper case, or an argument-type error when v cannot be read as a string.
func upperCase(v *Value) (*Value, error) {
	inferTypeAsString(v)
	if str, isString := v.ToString(); isString {
		return FromString(strings.ToUpper(str)), nil
	}
	cause := fmt.Errorf("%s expects a string argument", sqlFnUpper)
	return nil, errIncorrectSQLFunctionArgumentType(cause)
}
// handleDateAdd evaluates DATE_ADD(part, quantity, timestamp) for record r.
// The quantity must evaluate to a number and the timestamp operand to a
// timestamp; the date part is upper-cased before being handed to dateAdd.
func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
	// Quantity operand.
	qtyVal, err := d.Quantity.evalNode(r)
	if err != nil {
		return nil, err
	}
	inferTypeForArithOp(qtyVal)
	amount, isNumber := qtyVal.ToFloat()
	if !isNumber {
		return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
	}

	// Timestamp operand.
	tsVal, err := d.Timestamp.evalNode(r)
	if err != nil {
		return nil, err
	}
	err = inferTypeAsTimestamp(tsVal)
	if err != nil {
		return nil, err
	}
	base, isTime := tsVal.ToTimestamp()
	if !isTime {
		return nil, fmt.Errorf("%s() expects a timestamp argument", sqlFnDateAdd)
	}

	part := strings.ToUpper(d.DatePart)
	return dateAdd(part, amount, base)
}
// handleDateDiff evaluates DATE_DIFF(part, timestamp1, timestamp2) for
// record r. Both operands must evaluate to timestamps; they are resolved in
// order, and the date part is upper-cased before being handed to dateDiff.
func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) {
	// First timestamp operand.
	firstVal, err := d.Timestamp1.evalNode(r)
	if err != nil {
		return nil, err
	}
	err = inferTypeAsTimestamp(firstVal)
	if err != nil {
		return nil, err
	}
	first, isTime := firstVal.ToTimestamp()
	if !isTime {
		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
	}

	// Second timestamp operand.
	secondVal, err := d.Timestamp2.evalNode(r)
	if err != nil {
		return nil, err
	}
	err = inferTypeAsTimestamp(secondVal)
	if err != nil {
		return nil, err
	}
	second, isTime := secondVal.ToTimestamp()
	if !isTime {
		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
	}

	part := strings.ToUpper(d.DatePart)
	return dateDiff(part, first, second)
}
// handleUTCNow implements UTCNOW: it returns the current wall-clock time,
// expressed in UTC, as a timestamp value. It never fails.
func handleUTCNow() (*Value, error) {
	now := time.Now().UTC()
	return FromTimestamp(now), nil
}
// handleSQLSubstring evaluates SUBSTRING for record r. Both spellings are
// accepted: `SUBSTRING('abc' FROM 2 FOR 1)` and `SUBSTRING('abc', 2, 1)`.
//
// The first operand must be a string and the start index an integer. The
// length is optional; when present it must be a non-negative integer, and
// when absent -1 is passed to evalSQLSubstring.
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
	// String operand.
	strVal, err := e.Expr.evalNode(r)
	if err != nil {
		return nil, err
	}
	inferTypeAsString(strVal)
	str, isString := strVal.ToString()
	if !isString {
		cause := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
		return nil, errIncorrectSQLFunctionArgumentType(cause)
	}

	// Choose the operand pair: FROM/FOR when FROM is present, otherwise
	// the comma-separated arguments.
	startExpr, lenExpr := e.From, e.For
	if e.From == nil {
		startExpr, lenExpr = e.Arg2, e.Arg3
	}

	// Start index.
	startVal, err := startExpr.evalNode(r)
	if err != nil {
		return nil, err
	}
	inferTypeForArithOp(startVal)
	start, isInt := startVal.ToInt()
	if !isInt {
		cause := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
		return nil, errIncorrectSQLFunctionArgumentType(cause)
	}

	// Optional length; -1 stands for "not supplied".
	count := -1
	if lenExpr != nil {
		lenVal, evalErr := lenExpr.evalNode(r)
		if evalErr != nil {
			return nil, evalErr
		}
		inferTypeForArithOp(lenVal)
		n, isInt := lenVal.ToInt()
		if !isInt {
			cause := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
			return nil, errIncorrectSQLFunctionArgumentType(cause)
		}
		if count = int(n); count < 0 {
			cause := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
			return nil, errIncorrectSQLFunctionArgumentType(cause)
		}
	}

	sub, err := evalSQLSubstring(str, int(start), count)
	return FromString(sub), err
}
// handleSQLTrim evaluates TRIM for record r. The set of characters to strip
// is optional and defaults to the empty string; both it and the string being
// trimmed must evaluate to strings, otherwise errNonStringTrimArg is
// returned.
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
	// Optional trim characters.
	var cutset string
	if e.TrimChars != nil {
		cutsetVal, evalErr := e.TrimChars.evalNode(r)
		if evalErr != nil {
			return nil, evalErr
		}
		inferTypeAsString(cutsetVal)
		s, isString := cutsetVal.ToString()
		if !isString {
			return nil, errNonStringTrimArg
		}
		cutset = s
	}

	// String to trim.
	srcVal, evalErr := e.TrimFrom.evalNode(r)
	if evalErr != nil {
		return nil, evalErr
	}
	inferTypeAsString(srcVal)
	src, isString := srcVal.ToString()
	if !isString {
		return nil, errNonStringTrimArg
	}

	trimmed, trimErr := evalSQLTrim(e.TrimWhere, cutset, src)
	if trimErr != nil {
		return nil, trimErr
	}
	return FromString(trimmed), nil
}
// handleSQLExtract evaluates EXTRACT(timeword FROM expr) for record r. The
// operand must evaluate to a timestamp, otherwise errNonTimestampArg is
// returned; the time word is upper-cased before being handed to extract.
func handleSQLExtract(r Record, e *ExtractFunc) (res *Value, err error) {
	operand, evalErr := e.From.evalNode(r)
	if evalErr != nil {
		return nil, evalErr
	}

	if inferErr := inferTypeAsTimestamp(operand); inferErr != nil {
		return nil, inferErr
	}

	when, isTime := operand.ToTimestamp()
	if !isTime {
		return nil, errNonTimestampArg
	}

	word := strings.ToUpper(e.Timeword)
	return extract(word, when)
}
// errUnsupportedCast builds the error reported when there is no conversion
// from fromType to toType.
func errUnsupportedCast(fromType, toType string) error {
	err := fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
	return err
}
// errCastFailure builds the error reported when a value of a supported type
// could not be converted; msg describes what went wrong.
func errCastFailure(msg string) error {
	err := fmt.Errorf("Error casting: %s", msg)
	return err
}
// Target type names accepted by CAST, spelled in upper case.
const (
	// Boolean target.
	castBool = "BOOL"

	// Integer targets; both spellings are handled the same way by castTo.
	castInt     = "INT"
	castInteger = "INTEGER"

	// String target.
	castString = "STRING"

	// Floating-point target.
	castFloat = "FLOAT"

	// Recognised names for which castTo reports errUnimplementedCast.
	castDecimal = "DECIMAL"
	castNumeric = "NUMERIC"

	// Timestamp target.
	castTimestamp = "TIMESTAMP"
)
// castTo evaluates the expression against record r and converts the result
// to castType, which must be one of the upper-case cast type names.
// DECIMAL, NUMERIC and any unrecognised type yield errUnimplementedCast.
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
	src, err := e.evalNode(r)
	if err != nil {
		return nil, err
	}

	switch castType {
	case castInt, castInteger:
		n, castErr := intCast(src)
		return FromInt(n), castErr

	case castFloat:
		f, castErr := floatCast(src)
		return FromFloat(f), castErr

	case castString:
		s, castErr := stringCast(src)
		return FromString(s), castErr

	case castTimestamp:
		ts, castErr := timestampCast(src)
		return FromTimestamp(ts), castErr

	case castBool:
		b, castErr := boolCast(src)
		return FromBool(b), castErr
	}

	// castDecimal, castNumeric and everything else.
	return nil, errUnimplementedCast
}
// intCast converts v to an int64.
//
// Integers pass through unchanged. Floats are truncated towards zero.
// Strings and byte slices are parsed first as a base-10 integer and, failing
// that, as a float that is then truncated. Any other type yields an
// unsupported-cast error.
//
// A float that is NaN, infinite or outside the int64 range yields a cast
// failure. Converting such a value with int64(f) is implementation-dependent
// in Go, so the previous code could silently return a platform-specific
// value.
func intCast(v *Value) (int64, error) {
	// truncate converts f to int64 only when the result is representable.
	// NaN fails both comparisons and is therefore rejected too.
	truncate := func(f float64) (int64, bool) {
		const limit = float64(1 << 63) // 2^63, exactly representable
		if f >= -limit && f < limit {
			return int64(f), true
		}
		return 0, false
	}

	var text string
	switch x := v.value.(type) {
	case int64:
		return x, nil
	case float64:
		n, ok := truncate(x)
		if !ok {
			return 0, errCastFailure("float value out of range for int")
		}
		return n, nil
	case string:
		text = x
	case []byte:
		text = string(x)
	default:
		return 0, errUnsupportedCast(v.GetTypeString(), castInt)
	}

	// Parse as number, truncating the fractional part if needed.
	if n, err := strconv.ParseInt(text, 10, 64); err == nil {
		return n, nil
	}
	if f, err := strconv.ParseFloat(text, 64); err == nil {
		if n, ok := truncate(f); ok {
			return n, nil
		}
	}
	return 0, errCastFailure("could not parse as int")
}
// floatCast converts v to a float64.
//
// Floats pass through unchanged, integers are widened, and strings and byte
// slices are parsed with strconv.ParseFloat. Any other type yields an
// unsupported-cast error.
//
// The int64 case is required because the other cast helpers in this file
// read integer values as int64; matching only `int` made
// CAST(<integer> AS FLOAT) fail with an unsupported-cast error. The `int`
// case is kept for backward compatibility.
func floatCast(v *Value) (float64, error) {
	var text string
	switch x := v.value.(type) {
	case float64:
		return x, nil
	case int64:
		return float64(x), nil
	case int:
		return float64(x), nil
	case string:
		text = x
	case []byte:
		text = string(x)
	default:
		return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
	}

	f, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return 0, errCastFailure("could not parse as float")
	}
	return f, nil
}
// stringCast converts v to its string form. Strings and byte slices are
// returned as text, numbers and booleans use their default formatting, and a
// nil value becomes the literal "NULL". Any other type yields a cast
// failure.
func stringCast(v *Value) (string, error) {
	switch x := v.value.(type) {
	case string:
		return x, nil
	case []byte:
		return string(x), nil
	case float64, int64, bool:
		return fmt.Sprintf("%v", x), nil
	case nil:
		// FIXME: verify this case is correct
		return "NULL", nil
	default:
		// Not expected for the value types handled above.
		return "", errCastFailure(fmt.Sprintf("cannot cast %v to string type", v.GetTypeString()))
	}
}
// timestampCast converts v to a time.Time. Timestamps pass through
// unchanged, while strings and byte slices are handed to parseSQLTimestamp.
// Any other type yields a cast failure together with the zero time.
func timestampCast(v *Value) (t time.Time, _ error) {
	switch x := v.value.(type) {
	case time.Time:
		return x, nil
	case []byte:
		return parseSQLTimestamp(string(x))
	case string:
		return parseSQLTimestamp(x)
	}
	msg := fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString())
	return t, errCastFailure(msg)
}
// boolCast converts v to a bool.
//
// Booleans pass through unchanged. Strings and byte slices must spell
// "true" or "false", compared case-insensitively; anything else is a cast
// failure. Other types yield a cast failure that names the offending type.
// Previously that message contained a literal, unformatted "%v".
func boolCast(v *Value) (b bool, _ error) {
	parse := func(s string) (bool, error) {
		switch strings.ToLower(s) {
		case "true":
			return true, nil
		case "false":
			return false, nil
		default:
			return false, errCastFailure("cannot cast to Bool")
		}
	}

	switch x := v.value.(type) {
	case bool:
		return x, nil
	case string:
		return parse(x)
	case []byte:
		return parse(string(x))
	default:
		return false, errCastFailure(fmt.Sprintf("cannot cast %v to Bool", v.GetTypeString()))
	}
}