Add new SQL parser to support S3 Select syntax (#7102)

- New parser written from scratch, allows easier and complete parsing
  of the full S3 Select SQL syntax. Parser definition is directly
  provided by the AST defined for the SQL grammar.

- Bring support to parse and interpret SQL involving JSON path
  expressions; evaluation of JSON path expressions will be
  subsequently added.

- Bring automatic type inference and conversion for untyped
  values (e.g. CSV data).
This commit is contained in:
Aditya Manthramurthy
2019-01-28 17:59:48 -08:00
committed by Harshavardhana
parent 0a28c28a8c
commit 2786055df4
65 changed files with 6405 additions and 18231 deletions

View File

@@ -32,7 +32,11 @@ type Record struct {
nameIndexMap map[string]int64
}
// Get - gets the value for a column name.
// Get - gets the value for a column name. CSV fields do not have any
// defined type (other than the default string). So this function
// always returns fields using sql.FromBytes so that the type
// specified/implied by the query can be used, or can be automatically
// converted based on the query.
func (r *Record) Get(name string) (*sql.Value, error) {
index, found := r.nameIndexMap[name]
if !found {
@@ -40,11 +44,12 @@ func (r *Record) Get(name string) (*sql.Value, error) {
}
if index >= int64(len(r.csvRecord)) {
// No value found for column 'name', hence return empty string for compatibility.
return sql.NewString(""), nil
// No value found for column 'name', hence return null
// value
return sql.FromNull(), nil
}
return sql.NewString(r.csvRecord[index]), nil
return sql.FromBytes([]byte(r.csvRecord[index])), nil
}
// Set - sets the value for a column name.

View File

@@ -37,15 +37,15 @@ func (r *Record) Get(name string) (*sql.Value, error) {
result := gjson.GetBytes(r.data, name)
switch result.Type {
case gjson.Null:
return sql.NewNull(), nil
return sql.FromNull(), nil
case gjson.False:
return sql.NewBool(false), nil
return sql.FromBool(false), nil
case gjson.Number:
return sql.NewFloat(result.Float()), nil
return sql.FromFloat(result.Float()), nil
case gjson.String:
return sql.NewString(result.String()), nil
return sql.FromString(result.String()), nil
case gjson.True:
return sql.NewBool(true), nil
return sql.FromBool(true), nil
}
return nil, fmt.Errorf("unsupported gjson value %v; %v", result, result.Type)
@@ -54,19 +54,20 @@ func (r *Record) Get(name string) (*sql.Value, error) {
// Set - sets the value for a column name.
func (r *Record) Set(name string, value *sql.Value) (err error) {
var v interface{}
switch value.Type() {
case sql.Null:
v = value.NullValue()
case sql.Bool:
v = value.BoolValue()
case sql.Int:
v = value.IntValue()
case sql.Float:
v = value.FloatValue()
case sql.String:
v = value.StringValue()
default:
return fmt.Errorf("unsupported sql value %v and type %v", value, value.Type())
if b, ok := value.ToBool(); ok {
v = b
} else if f, ok := value.ToFloat(); ok {
v = f
} else if i, ok := value.ToInt(); ok {
v = i
} else if s, ok := value.ToString(); ok {
v = s
} else if value.IsNull() {
v = nil
} else if b, ok := value.ToBytes(); ok {
v = string(b)
} else {
return fmt.Errorf("unsupported sql value %v and type %v", value, value.GetTypeString())
}
name = strings.Replace(name, "*", "__ALL__", -1)

View File

@@ -32,7 +32,7 @@ type Reader struct {
}
// Read - reads single record.
func (r *Reader) Read() (sql.Record, error) {
func (r *Reader) Read() (rec sql.Record, rerr error) {
parquetRecord, err := r.file.Read()
if err != nil {
if err != io.EOF {
@@ -43,39 +43,41 @@ func (r *Reader) Read() (sql.Record, error) {
}
record := json.NewRecord()
for name, v := range parquetRecord {
f := func(name string, v parquetgo.Value) bool {
if v.Value == nil {
if err = record.Set(name, sql.NewNull()); err != nil {
return nil, errParquetParsingError(err)
if err := record.Set(name, sql.FromNull()); err != nil {
rerr = errParquetParsingError(err)
}
continue
return rerr == nil
}
var value *sql.Value
switch v.Type {
case parquetgen.Type_BOOLEAN:
value = sql.NewBool(v.Value.(bool))
value = sql.FromBool(v.Value.(bool))
case parquetgen.Type_INT32:
value = sql.NewInt(int64(v.Value.(int32)))
value = sql.FromInt(int64(v.Value.(int32)))
case parquetgen.Type_INT64:
value = sql.NewInt(v.Value.(int64))
value = sql.FromInt(int64(v.Value.(int64)))
case parquetgen.Type_FLOAT:
value = sql.NewFloat(float64(v.Value.(float32)))
value = sql.FromFloat(float64(v.Value.(float32)))
case parquetgen.Type_DOUBLE:
value = sql.NewFloat(v.Value.(float64))
value = sql.FromFloat(v.Value.(float64))
case parquetgen.Type_INT96, parquetgen.Type_BYTE_ARRAY, parquetgen.Type_FIXED_LEN_BYTE_ARRAY:
value = sql.NewString(string(v.Value.([]byte)))
value = sql.FromString(string(v.Value.([]byte)))
default:
return nil, errParquetParsingError(nil)
rerr = errParquetParsingError(nil)
return false
}
if err = record.Set(name, value); err != nil {
return nil, errParquetParsingError(err)
rerr = errParquetParsingError(err)
}
return rerr == nil
}
return record, nil
parquetRecord.Range(f)
return record, rerr
}
// Close - closes underlaying readers.

View File

@@ -105,7 +105,7 @@ func (input *InputSerialization) UnmarshalXML(d *xml.Decoder, start xml.StartEle
found++
}
if !parsedInput.ParquetArgs.IsEmpty() {
if parsedInput.CompressionType != noneType {
if parsedInput.CompressionType != "" && parsedInput.CompressionType != noneType {
return errInvalidRequestParameter(fmt.Errorf("CompressionType must be NONE for Parquet format"))
}
@@ -178,7 +178,7 @@ type S3Select struct {
Output OutputSerialization `xml:"OutputSerialization"`
Progress RequestProgress `xml:"RequestProgress"`
statement *sql.Select
statement *sql.SelectStatement
progressReader *progressReader
recordReader recordReader
}
@@ -209,12 +209,12 @@ func (s3Select *S3Select) UnmarshalXML(d *xml.Decoder, start xml.StartElement) e
return errMissingRequiredParameter(fmt.Errorf("OutputSerialization must be provided"))
}
statement, err := sql.NewSelect(parsedS3Select.Expression)
statement, err := sql.ParseSelectStatement(parsedS3Select.Expression)
if err != nil {
return err
}
parsedS3Select.statement = statement
parsedS3Select.statement = &statement
*s3Select = S3Select(parsedS3Select)
return nil
@@ -334,6 +334,14 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
}
for {
if s3Select.statement.LimitReached() {
if err = writer.SendStats(s3Select.getProgress()); err != nil {
// FIXME: log this error.
err = nil
}
break
}
if inputRecord, err = s3Select.recordReader.Read(); err != nil {
if err != io.EOF {
break
@@ -358,19 +366,25 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
break
}
outputRecord = s3Select.outputRecord()
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
break
}
if s3Select.statement.IsAggregated() {
if err = s3Select.statement.AggregateRow(inputRecord); err != nil {
break
}
} else {
outputRecord = s3Select.outputRecord()
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
break
}
if !s3Select.statement.IsAggregated() {
if !sendRecord() {
break
}
}
}
if err != nil {
fmt.Printf("SQL Err: %#v\n", err)
if serr := writer.SendError("InternalError", err.Error()); serr != nil {
// FIXME: log errors.
}

View File

@@ -0,0 +1,318 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
)
// Aggregation Function name constants
const (
aggFnAvg FuncName = "AVG"
aggFnCount FuncName = "COUNT"
aggFnMax FuncName = "MAX"
aggFnMin FuncName = "MIN"
aggFnSum FuncName = "SUM"
)
var (
errNonNumericArg = func(fnStr FuncName) error {
return fmt.Errorf("%s() requires a numeric argument", fnStr)
}
errInvalidAggregation = errors.New("Invalid aggregation seen")
)
type aggVal struct {
runningSum *Value
runningCount int64
runningMax, runningMin *Value
// Stores if at least one record has been seen
seen bool
}
func newAggVal(fn FuncName) *aggVal {
switch fn {
case aggFnAvg, aggFnSum:
return &aggVal{runningSum: FromInt(0)}
case aggFnMin:
return &aggVal{runningMin: FromInt(0)}
case aggFnMax:
return &aggVal{runningMax: FromInt(0)}
default:
return &aggVal{}
}
}
// evalAggregationNode - performs partial computation using the
// current row and stores the result.
//
// On success, it returns (nil, nil).
func (e *FuncExpr) evalAggregationNode(r Record) error {
// It is assumed that this function is called only when
// `e` is an aggregation function.
var val *Value
var err error
funcName := e.getFunctionName()
if aggFnCount == funcName {
if e.Count.StarArg {
// Handle COUNT(*)
e.aggregate.runningCount++
return nil
}
val, err = e.Count.ExprArg.evalNode(r)
if err != nil {
return err
}
} else {
// Evaluate the (only) argument
val, err = e.SFunc.ArgsList[0].evalNode(r)
if err != nil {
return err
}
}
if val.IsNull() {
// E.g. the column or field does not exist in the
// record - in all such cases the aggregation is not
// updated.
return nil
}
argVal := val
if funcName != aggFnCount {
// All aggregation functions, except COUNT require a
// numeric argument.
// Here, we diverge from Amazon S3 behavior by
// inferring untyped values are numbers.
if i, ok := argVal.bytesToInt(); ok {
argVal.setInt(i)
} else if f, ok := argVal.bytesToFloat(); ok {
argVal.setFloat(f)
} else {
return errNonNumericArg(funcName)
}
}
// Mark that we have seen one non-null value.
isFirstRow := false
if !e.aggregate.seen {
e.aggregate.seen = true
isFirstRow = true
}
switch funcName {
case aggFnCount:
// For all non-null values, the count is incremented.
e.aggregate.runningCount++
case aggFnAvg:
e.aggregate.runningCount++
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
case aggFnMin:
err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
case aggFnMax:
err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
case aggFnSum:
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
default:
err = errInvalidAggregation
}
return err
}
func (e *AliasedExpression) aggregateRow(r Record) error {
return e.Expression.aggregateRow(r)
}
func (e *Expression) aggregateRow(r Record) error {
for _, ex := range e.And {
err := ex.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *AndCondition) aggregateRow(r Record) error {
for _, ex := range e.Condition {
err := ex.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *Condition) aggregateRow(r Record) error {
if e.Operand != nil {
return e.Operand.aggregateRow(r)
}
return e.Not.aggregateRow(r)
}
func (e *ConditionOperand) aggregateRow(r Record) error {
err := e.Operand.aggregateRow(r)
if err != nil {
return err
}
if e.ConditionRHS == nil {
return nil
}
switch {
case e.ConditionRHS.Compare != nil:
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
case e.ConditionRHS.Between != nil:
err = e.ConditionRHS.Between.Start.aggregateRow(r)
if err != nil {
return err
}
return e.ConditionRHS.Between.End.aggregateRow(r)
case e.ConditionRHS.In != nil:
for _, elt := range e.ConditionRHS.In.Expressions {
err = elt.aggregateRow(r)
if err != nil {
return err
}
}
return nil
case e.ConditionRHS.Like != nil:
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
if err != nil {
return err
}
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
default:
return errInvalidASTNode
}
}
func (e *Operand) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *MultOp) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *UnaryTerm) aggregateRow(r Record) error {
if e.Negated != nil {
return e.Negated.Term.aggregateRow(r)
}
return e.Primary.aggregateRow(r)
}
func (e *PrimaryTerm) aggregateRow(r Record) error {
switch {
case e.SubExpression != nil:
return e.SubExpression.aggregateRow(r)
case e.FuncCall != nil:
return e.FuncCall.aggregateRow(r)
}
return nil
}
func (e *FuncExpr) aggregateRow(r Record) error {
switch e.getFunctionName() {
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
return e.evalAggregationNode(r)
default:
// TODO: traverse arguments and call aggregateRow on
// them if they could be an ancestor of an
// aggregation.
}
return nil
}
// getAggregate() implementation for each AST node follows. This is
// called after calling aggregateRow() on each input row, to calculate
// the final aggregate result.
func (e *Expression) getAggregate() (*Value, error) {
return e.evalNode(nil)
}
func (e *FuncExpr) getAggregate() (*Value, error) {
switch e.getFunctionName() {
case aggFnCount:
return FromFloat(float64(e.aggregate.runningCount)), nil
case aggFnAvg:
if e.aggregate.runningCount == 0 {
// No rows were seen by AVG.
return FromNull(), nil
}
err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
return e.aggregate.runningSum, err
case aggFnMin:
if !e.aggregate.seen {
// No rows were seen by MIN
return FromNull(), nil
}
return e.aggregate.runningMin, nil
case aggFnMax:
if !e.aggregate.seen {
// No rows were seen by MAX
return FromNull(), nil
}
return e.aggregate.runningMax, nil
case aggFnSum:
// TODO: check if returning 0 when no rows were seen
// by SUM is expected behavior.
return e.aggregate.runningSum, nil
default:
// TODO:
}
return nil, errInvalidAggregation
}

View File

@@ -0,0 +1,290 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
)
// Query analysis - The query is analyzed to determine if it involves
// aggregation.
//
// Aggregation functions - An expression that involves aggregation of
// rows in some manner. Requires all input rows to be processed,
// before a result is returned.
//
// Row function - An expression that depends on a value in the
// row. They have an output for each input row.
//
// Some types of a queries are not valid. For example, an aggregation
// function combined with a row function is meaningless ("AVG(s.Age) +
// s.Salary"). Analysis determines if such a scenario exists so an
// error can be returned.
var (
// Fatal error for query processing.
errNestedAggregation = errors.New("Cannot nest aggregations")
errFunctionNotImplemented = errors.New("Function is not yet implemented")
errUnexpectedInvalidNode = errors.New("Unexpected node value")
errInvalidKeypath = errors.New("A provided keypath is invalid")
)
// qProp contains analysis info about an SQL term.
type qProp struct {
isAggregation, isRowFunc bool
err error
}
// `combine` combines a pair of `qProp`s, so that errors are
// propagated correctly, and checks that an aggregation is not being
// combined with a row-function term.
func (p *qProp) combine(q qProp) {
switch {
case p.err != nil:
// Do nothing
case q.err != nil:
p.err = q.err
default:
p.isAggregation = p.isAggregation || q.isAggregation
p.isRowFunc = p.isRowFunc || q.isRowFunc
if p.isAggregation && p.isRowFunc {
p.err = errNestedAggregation
}
}
}
func (e *SelectExpression) analyze(s *Select) (result qProp) {
if e.All {
return qProp{isRowFunc: true}
}
for _, ex := range e.Expressions {
result.combine(ex.analyze(s))
}
return
}
func (e *AliasedExpression) analyze(s *Select) qProp {
return e.Expression.analyze(s)
}
func (e *Expression) analyze(s *Select) (result qProp) {
for _, ac := range e.And {
result.combine(ac.analyze(s))
}
return
}
func (e *AndCondition) analyze(s *Select) (result qProp) {
for _, ac := range e.Condition {
result.combine(ac.analyze(s))
}
return
}
func (e *Condition) analyze(s *Select) (result qProp) {
if e.Operand != nil {
result = e.Operand.analyze(s)
} else {
result = e.Not.analyze(s)
}
return
}
func (e *ConditionOperand) analyze(s *Select) (result qProp) {
if e.ConditionRHS == nil {
result = e.Operand.analyze(s)
} else {
result.combine(e.Operand.analyze(s))
result.combine(e.ConditionRHS.analyze(s))
}
return
}
func (e *ConditionRHS) analyze(s *Select) (result qProp) {
switch {
case e.Compare != nil:
result = e.Compare.Operand.analyze(s)
case e.Between != nil:
result.combine(e.Between.Start.analyze(s))
result.combine(e.Between.End.analyze(s))
case e.In != nil:
for _, elt := range e.In.Expressions {
result.combine(elt.analyze(s))
}
case e.Like != nil:
result.combine(e.Like.Pattern.analyze(s))
if e.Like.EscapeChar != nil {
result.combine(e.Like.EscapeChar.analyze(s))
}
default:
result = qProp{err: errUnexpectedInvalidNode}
}
return
}
func (e *Operand) analyze(s *Select) (result qProp) {
result.combine(e.Left.analyze(s))
for _, r := range e.Right {
result.combine(r.Right.analyze(s))
}
return
}
func (e *MultOp) analyze(s *Select) (result qProp) {
result.combine(e.Left.analyze(s))
for _, r := range e.Right {
result.combine(r.Right.analyze(s))
}
return
}
func (e *UnaryTerm) analyze(s *Select) (result qProp) {
if e.Negated != nil {
result = e.Negated.Term.analyze(s)
} else {
result = e.Primary.analyze(s)
}
return
}
func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
switch {
case e.Value != nil:
result = qProp{}
case e.JPathExpr != nil:
// Check if the path expression is valid
if len(e.JPathExpr.PathExpr) > 0 {
if e.JPathExpr.BaseKey.String() != s.From.As {
result = qProp{err: errInvalidKeypath}
return
}
}
result = qProp{isRowFunc: true}
case e.SubExpression != nil:
result = e.SubExpression.analyze(s)
case e.FuncCall != nil:
result = e.FuncCall.analyze(s)
default:
result = qProp{err: errUnexpectedInvalidNode}
}
return
}
func (e *FuncExpr) analyze(s *Select) (result qProp) {
funcName := e.getFunctionName()
switch funcName {
case sqlFnCast:
return e.Cast.Expr.analyze(s)
case sqlFnExtract:
return e.Extract.From.analyze(s)
// Handle aggregation function calls
case aggFnAvg, aggFnMax, aggFnMin, aggFnSum, aggFnCount:
// Initialize accumulator
e.aggregate = newAggVal(funcName)
var exprA qProp
if funcName == aggFnCount {
if e.Count.StarArg {
return qProp{isAggregation: true}
}
exprA = e.Count.ExprArg.analyze(s)
} else {
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s takes exactly one argument", funcName)}
}
exprA = e.SFunc.ArgsList[0].analyze(s)
}
if exprA.err != nil {
return exprA
}
if exprA.isAggregation {
return qProp{err: errNestedAggregation}
}
return qProp{isAggregation: true}
case sqlFnCoalesce:
if len(e.SFunc.ArgsList) == 0 {
return qProp{err: fmt.Errorf("%s needs at least one argument", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnNullIf:
if len(e.SFunc.ArgsList) != 2 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnCharLength, sqlFnCharacterLength:
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnLower, sqlFnUpper:
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnSubstring:
errVal := fmt.Errorf("Invalid argument(s) to %s", string(funcName))
result.combine(e.Substring.Expr.analyze(s))
switch {
case e.Substring.From != nil:
result.combine(e.Substring.From.analyze(s))
if e.Substring.For != nil {
result.combine(e.Substring.Expr.analyze(s))
}
case e.Substring.Arg2 != nil:
result.combine(e.Substring.Arg2.analyze(s))
if e.Substring.Arg3 != nil {
result.combine(e.Substring.Arg3.analyze(s))
}
default:
result.err = errVal
}
return result
}
// TODO: implement other functions
return qProp{err: errFunctionNotImplemented}
}

View File

@@ -1,175 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "fmt"
// ArithOperator - arithmetic operator.
type ArithOperator string
const (
// Add operator '+'.
Add ArithOperator = "+"
// Subtract operator '-'.
Subtract ArithOperator = "-"
// Multiply operator '*'.
Multiply ArithOperator = "*"
// Divide operator '/'.
Divide ArithOperator = "/"
// Modulo operator '%'.
Modulo ArithOperator = "%"
)
// arithExpr - arithmetic function.
type arithExpr struct {
left Expr
right Expr
operator ArithOperator
funcType Type
}
// String - returns string representation of this function.
func (f *arithExpr) String() string {
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
}
func (f *arithExpr) compute(lv, rv *Value) (*Value, error) {
leftValueType := lv.Type()
rightValueType := rv.Type()
if !leftValueType.isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValueType)
return nil, errExternalEvalException(err)
}
if !rightValueType.isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValueType)
return nil, errExternalEvalException(err)
}
leftValue := lv.FloatValue()
rightValue := rv.FloatValue()
var result float64
switch f.operator {
case Add:
result = leftValue + rightValue
case Subtract:
result = leftValue - rightValue
case Multiply:
result = leftValue * rightValue
case Divide:
result = leftValue / rightValue
case Modulo:
result = float64(int64(leftValue) % int64(rightValue))
}
if leftValueType == Float || rightValueType == Float {
return NewFloat(result), nil
}
return NewInt(int64(result)), nil
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *arithExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
return nil, nil
}
return f.compute(leftValue, rightValue)
}
// AggregateValue - returns aggregated value.
func (f *arithExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
lv, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
rv, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
return f.compute(lv, rv)
}
// Type - returns arithmeticFunction or aggregateFunction type.
func (f *arithExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Float as return type.
func (f *arithExpr) ReturnType() Type {
return Float
}
// newArithExpr - creates new arithmetic function.
func newArithExpr(operator ArithOperator, left, right Expr) (*arithExpr, error) {
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := arithmeticFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &arithExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
}

View File

@@ -1,636 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"regexp"
"strings"
)
// ComparisonOperator - comparison operator.
type ComparisonOperator string
const (
// Equal operator '='.
Equal ComparisonOperator = "="
// NotEqual operator '!=' or '<>'.
NotEqual ComparisonOperator = "!="
// LessThan operator '<'.
LessThan ComparisonOperator = "<"
// GreaterThan operator '>'.
GreaterThan ComparisonOperator = ">"
// LessThanEqual operator '<='.
LessThanEqual ComparisonOperator = "<="
// GreaterThanEqual operator '>='.
GreaterThanEqual ComparisonOperator = ">="
// Between operator 'BETWEEN'
Between ComparisonOperator = "between"
// In operator 'IN'
In ComparisonOperator = "in"
// Like operator 'LIKE'
Like ComparisonOperator = "like"
// NotBetween operator 'NOT BETWEEN'
NotBetween ComparisonOperator = "not between"
// NotIn operator 'NOT IN'
NotIn ComparisonOperator = "not in"
// NotLike operator 'NOT LIKE'
NotLike ComparisonOperator = "not like"
// IsNull operator 'IS NULL'
IsNull ComparisonOperator = "is null"
// IsNotNull operator 'IS NOT NULL'
IsNotNull ComparisonOperator = "is not null"
)
// String - returns string representation of this operator.
func (operator ComparisonOperator) String() string {
return strings.ToUpper((string(operator)))
}
func equal(leftValue, rightValue *Value) (bool, error) {
switch {
case leftValue.Type() == Null && rightValue.Type() == Null:
return true, nil
case leftValue.Type() == Bool && rightValue.Type() == Bool:
return leftValue.BoolValue() == rightValue.BoolValue(), nil
case (leftValue.Type() == Int || leftValue.Type() == Float) &&
(rightValue.Type() == Int || rightValue.Type() == Float):
return leftValue.FloatValue() == rightValue.FloatValue(), nil
case leftValue.Type() == String && rightValue.Type() == String:
return leftValue.StringValue() == rightValue.StringValue(), nil
case leftValue.Type() == Timestamp && rightValue.Type() == Timestamp:
return leftValue.TimeValue() == rightValue.TimeValue(), nil
}
return false, fmt.Errorf("left value type %v and right value type %v are incompatible for equality check", leftValue.Type(), rightValue.Type())
}
// comparisonExpr - comparison function.
type comparisonExpr struct {
left Expr
right Expr
to Expr
operator ComparisonOperator
funcType Type
}
// String - returns string representation of this function.
func (f *comparisonExpr) String() string {
switch f.operator {
case Equal, NotEqual, LessThan, GreaterThan, LessThanEqual, GreaterThanEqual, In, Like, NotIn, NotLike:
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
case Between, NotBetween:
return fmt.Sprintf("(%v %v %v AND %v)", f.left, f.operator, f.right, f.to)
}
return fmt.Sprintf("(%v %v %v %v)", f.left, f.right, f.to, f.operator)
}
func (f *comparisonExpr) equal(leftValue, rightValue *Value) (*Value, error) {
result, err := equal(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(result), nil
}
func (f *comparisonExpr) notEqual(leftValue, rightValue *Value) (*Value, error) {
result, err := equal(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(!result), nil
}
func (f *comparisonExpr) lessThan(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() < rightValue.FloatValue()), nil
}
func (f *comparisonExpr) greaterThan(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() > rightValue.FloatValue()), nil
}
func (f *comparisonExpr) lessThanEqual(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() <= rightValue.FloatValue()), nil
}
func (f *comparisonExpr) greaterThanEqual(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() >= rightValue.FloatValue()), nil
}
func (f *comparisonExpr) computeBetween(leftValue, fromValue, toValue *Value) (bool, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return false, errExternalEvalException(err)
}
if !fromValue.Type().isNumber() {
err := fmt.Errorf("%v: from side expression evaluated to %v; not to number", f, fromValue.Type())
return false, errExternalEvalException(err)
}
if !toValue.Type().isNumber() {
err := fmt.Errorf("%v: to side expression evaluated to %v; not to number", f, toValue.Type())
return false, errExternalEvalException(err)
}
return leftValue.FloatValue() >= fromValue.FloatValue() &&
leftValue.FloatValue() <= toValue.FloatValue(), nil
}
func (f *comparisonExpr) between(leftValue, fromValue, toValue *Value) (*Value, error) {
result, err := f.computeBetween(leftValue, fromValue, toValue)
if err != nil {
return nil, err
}
return NewBool(result), nil
}
func (f *comparisonExpr) notBetween(leftValue, fromValue, toValue *Value) (*Value, error) {
result, err := f.computeBetween(leftValue, fromValue, toValue)
if err != nil {
return nil, err
}
return NewBool(!result), nil
}
func (f *comparisonExpr) computeIn(leftValue, rightValue *Value) (found bool, err error) {
if rightValue.Type() != Array {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to Array", f, rightValue.Type())
return false, errExternalEvalException(err)
}
values := rightValue.ArrayValue()
for i := range values {
found, err = equal(leftValue, values[i])
if err != nil {
return false, err
}
if found {
return true, nil
}
}
return false, nil
}
func (f *comparisonExpr) in(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeIn(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(result), nil
}
func (f *comparisonExpr) notIn(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeIn(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(!result), nil
}
func (f *comparisonExpr) computeLike(leftValue, rightValue *Value) (matched bool, err error) {
if leftValue.Type() != String {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to string", f, leftValue.Type())
return false, errExternalEvalException(err)
}
if rightValue.Type() != String {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to string", f, rightValue.Type())
return false, errExternalEvalException(err)
}
matched, err = regexp.MatchString(rightValue.StringValue(), leftValue.StringValue())
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return false, errExternalEvalException(err)
}
return matched, nil
}
func (f *comparisonExpr) like(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeLike(leftValue, rightValue)
if err != nil {
return nil, err
}
return NewBool(result), nil
}
func (f *comparisonExpr) notLike(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeLike(leftValue, rightValue)
if err != nil {
return nil, err
}
return NewBool(!result), nil
}
func (f *comparisonExpr) compute(leftValue, rightValue, toValue *Value) (*Value, error) {
switch f.operator {
case Equal:
return f.equal(leftValue, rightValue)
case NotEqual:
return f.notEqual(leftValue, rightValue)
case LessThan:
return f.lessThan(leftValue, rightValue)
case GreaterThan:
return f.greaterThan(leftValue, rightValue)
case LessThanEqual:
return f.lessThanEqual(leftValue, rightValue)
case GreaterThanEqual:
return f.greaterThanEqual(leftValue, rightValue)
case Between:
return f.between(leftValue, rightValue, toValue)
case In:
return f.in(leftValue, rightValue)
case Like:
return f.like(leftValue, rightValue)
case NotBetween:
return f.notBetween(leftValue, rightValue, toValue)
case NotIn:
return f.notIn(leftValue, rightValue)
case NotLike:
return f.notLike(leftValue, rightValue)
}
panic(fmt.Errorf("unexpected expression %v", f))
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *comparisonExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
var toValue *Value
if f.to != nil {
toValue, err = f.to.Eval(record)
if err != nil {
return nil, err
}
}
if f.funcType == aggregateFunction {
return nil, nil
}
return f.compute(leftValue, rightValue, toValue)
}
// AggregateValue - returns aggregated value.
func (f *comparisonExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
var toValue *Value
if f.to != nil {
toValue, err = f.to.AggregateValue()
if err != nil {
return nil, err
}
}
return f.compute(leftValue, rightValue, toValue)
}
// Type - returns comparisonFunction or aggregateFunction type.
func (f *comparisonExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *comparisonExpr) ReturnType() Type {
return Bool
}
// newComparisonExpr - creates new comparison function.
func newComparisonExpr(operator ComparisonOperator, funcs ...Expr) (*comparisonExpr, error) {
funcType := comparisonFunction
switch operator {
case Equal, NotEqual:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if !right.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for equality check", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for equality check", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case LessThan, GreaterThan, LessThanEqual, GreaterThanEqual:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if !right.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case In, NotIn:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if right.ReturnType() != Array {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Array, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case Like, NotLike:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isStringKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not string", operator, left, left.ReturnType())
return nil, errLikeInvalidInputs(err)
}
right := funcs[1]
if !right.ReturnType().isStringKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v, not string", operator, right, right.ReturnType())
return nil, errLikeInvalidInputs(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case String, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case String, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case Between, NotBetween:
if len(funcs) != 3 {
panic(fmt.Sprintf("too many values in funcs %v", funcs))
}
left := funcs[0]
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
from := funcs[1]
if !from.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: from expression %v evaluate to %v, not number", operator, from, from.ReturnType())
return nil, errInvalidDataType(err)
}
to := funcs[2]
if !to.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: to expression %v evaluate to %v, not number", operator, to, to.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || from.Type() == aggregateFunction || to.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch from.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: from expression %v return type %v is incompatible for aggregate evaluation", operator, from, from.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch to.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: to expression %v return type %v is incompatible for aggregate evaluation", operator, to, to.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: from,
to: to,
operator: operator,
funcType: funcType,
}, nil
case IsNull, IsNotNull:
if len(funcs) != 1 {
panic(fmt.Sprintf("too many values in funcs %v", funcs))
}
if funcs[0].Type() == aggregateFunction {
funcType = aggregateFunction
}
if operator == IsNull {
operator = Equal
} else {
operator = NotEqual
}
return &comparisonExpr{
left: funcs[0],
right: newValueExpr(NewNull()),
operator: operator,
funcType: funcType,
}, nil
}
return nil, errParseUnknownOperator(fmt.Errorf("unknown operator %v", operator))
}

View File

@@ -43,42 +43,6 @@ func (err *s3Error) Error() string {
return err.message
}
func errUnsupportedSQLStructure(err error) *s3Error {
return &s3Error{
code: "UnsupportedSqlStructure",
message: "Encountered an unsupported SQL structure. Check the SQL Reference.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedSelect(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedSelect",
message: "The SQL expression contains an unsupported use of SELECT.",
statusCode: 400,
cause: err,
}
}
func errParseAsteriskIsNotAloneInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseAsteriskIsNotAloneInSelectList",
message: "Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseInvalidContextForWildcardInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseInvalidContextForWildcardInSelectList",
message: "Invalid use of * in SELECT list in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInvalidDataType(err error) *s3Error {
return &s3Error{
code: "InvalidDataType",
@@ -88,55 +52,10 @@ func errInvalidDataType(err error) *s3Error {
}
}
func errUnsupportedFunction(err error) *s3Error {
return &s3Error{
code: "UnsupportedFunction",
message: "Encountered an unsupported SQL function.",
statusCode: 400,
cause: err,
}
}
func errParseNonUnaryAgregateFunctionCall(err error) *s3Error {
return &s3Error{
code: "ParseNonUnaryAgregateFunctionCall",
message: "Only one argument is supported for aggregate functions in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errIncorrectSQLFunctionArgumentType(err error) *s3Error {
return &s3Error{
code: "IncorrectSqlFunctionArgumentType",
message: "Incorrect type of arguments in function call in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidArguments(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidArguments",
message: "Incorrect number of arguments in the function call in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedSQLOperation(err error) *s3Error {
return &s3Error{
code: "UnsupportedSqlOperation",
message: "Encountered an unsupported SQL operation.",
statusCode: 400,
cause: err,
}
}
func errParseUnknownOperator(err error) *s3Error {
return &s3Error{
code: "ParseUnknownOperator",
message: "The SQL expression contains an invalid operator.",
message: "Incorrect type of arguments in function call.",
statusCode: 400,
cause: err,
}
@@ -151,64 +70,28 @@ func errLikeInvalidInputs(err error) *s3Error {
}
}
func errExternalEvalException(err error) *s3Error {
func errQueryParseFailure(err error) *s3Error {
return &s3Error{
code: "ExternalEvalException",
message: "The query cannot be evaluated. Check the file and try again.",
code: "ParseSelectFailure",
message: err.Error(),
statusCode: 400,
cause: err,
}
}
func errValueParseFailure(err error) *s3Error {
func errQueryAnalysisFailure(err error) *s3Error {
return &s3Error{
code: "ValueParseFailure",
message: "Time stamp parse failure in the SQL expression.",
code: "InvalidQuery",
message: err.Error(),
statusCode: 400,
cause: err,
}
}
func errEvaluatorBindingDoesNotExist(err error) *s3Error {
func errBadTableName(err error) *s3Error {
return &s3Error{
code: "EvaluatorBindingDoesNotExist",
message: "A column name or a path provided does not exist in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInternalError(err error) *s3Error {
return &s3Error{
code: "InternalError",
message: "Encountered an internal error.",
statusCode: 500,
cause: err,
}
}
func errParseInvalidTypeParam(err error) *s3Error {
return &s3Error{
code: "ParseInvalidTypeParam",
message: "The SQL expression contains an invalid parameter value.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedSyntax(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedSyntax",
message: "The SQL expression contains unsupported syntax.",
statusCode: 400,
cause: err,
}
}
func errInvalidKeyPath(err error) *s3Error {
return &s3Error{
code: "InvalidKeyPath",
message: "Key path in the SQL expression is invalid.",
code: "BadTableName",
message: "The table name is not supported",
statusCode: 400,
cause: err,
}

View File

@@ -0,0 +1,361 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"strings"
)
var (
errInvalidASTNode = errors.New("invalid AST Node")
errExpectedBool = errors.New("expected bool")
errLikeNonStrArg = errors.New("LIKE clause requires string arguments")
errLikeInvalidEscape = errors.New("LIKE clause has invalid ESCAPE character")
errNotImplemented = errors.New("not implemented")
)
// AST Node Evaluation functions
//
// During evaluation, the query is known to be valid, as analysis is
// complete. The only errors possible are due to value type
// mismatches, etc.
//
// If an aggregation node is present as a descendant (when
// e.prop.isAggregation is true), we call evalNode on all child nodes,
// check for errors, but do not perform any combining of the results
// of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
return e.Expression.evalNode(r)
}
func (e *Expression) evalNode(r Record) (*Value, error) {
if len(e.And) == 1 {
// In this case, result is not required to be boolean
// type.
return e.And[0].evalNode(r)
}
// Compute OR of conditions
result := false
for _, ex := range e.And {
res, err := ex.evalNode(r)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result || b
}
return FromBool(result), nil
}
func (e *AndCondition) evalNode(r Record) (*Value, error) {
if len(e.Condition) == 1 {
// In this case, result does not have to be boolean
return e.Condition[0].evalNode(r)
}
// Compute AND of conditions
result := true
for _, ex := range e.Condition {
res, err := ex.evalNode(r)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result && b
}
return FromBool(result), nil
}
func (e *Condition) evalNode(r Record) (*Value, error) {
if e.Operand != nil {
// In this case, result does not have to be boolean
return e.Operand.evalNode(r)
}
// Compute NOT of condition
res, err := e.Not.evalNode(r)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
return FromBool(!b), nil
}
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r)
if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr
}
// Need to evaluate the ConditionRHS
switch {
case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
if cmpRErr != nil {
return nil, cmpRErr
}
b, err := opVal.compareOp(e.ConditionRHS.Compare.Operator, cmpRight)
return FromBool(b), err
case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal)
default:
return nil, errInvalidASTNode
}
}
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
stVal, stErr := e.Start.evalNode(r)
if stErr != nil {
return nil, stErr
}
endVal, endErr := e.End.evalNode(r)
if endErr != nil {
return nil, endErr
}
part1, err1 := stVal.compareOp(opLte, arg)
if err1 != nil {
return nil, err1
}
part2, err2 := arg.compareOp(opLte, endVal)
if err2 != nil {
return nil, err2
}
result := part1 && part2
if e.Not {
result = !result
}
return FromBool(result), nil
}
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
inferTypeAsString(arg)
s, ok := arg.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
pattern, err1 := e.Pattern.evalNode(r)
if err1 != nil {
return nil, err1
}
// Infer pattern as string (in case it is untyped)
inferTypeAsString(pattern)
patternStr, ok := pattern.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
escape := runeZero
if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r)
if err2 != nil {
return nil, err2
}
inferTypeAsString(escapeVal)
escapeStr, ok := escapeVal.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
if len([]rune(escapeStr)) > 1 {
err := errLikeInvalidEscape
return nil, errLikeInvalidInputs(err)
}
}
matchResult, err := evalSQLLike(s, patternStr, escape)
if err != nil {
return nil, err
}
if e.Not {
matchResult = !matchResult
}
return FromBool(matchResult), nil
}
func (e *In) evalInNode(r Record, arg *Value) (*Value, error) {
result := false
for _, elt := range e.Expressions {
eltVal, err := elt.evalNode(r)
if err != nil {
return nil, err
}
// FIXME: type inference?
// Types must match.
if arg.vType != eltVal.vType {
// match failed.
continue
}
if arg.value == eltVal.value {
result = true
break
}
}
return FromBool(result), nil
}
func (e *Operand) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process remaining child nodes - result must be
// numeric. This AST node is for terms separated by + or -
// symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *MultOp) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process other child nodes - result must be numeric. This
// AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
if e.Negated == nil {
return e.Primary.evalNode(r)
}
v, err := e.Negated.Term.evalNode(r)
if err != nil {
return nil, err
}
inferTypeForArithOp(v)
if ival, ok := v.ToInt(); ok {
return FromInt(-ival), nil
} else if fval, ok := v.ToFloat(); ok {
return FromFloat(-fval), nil
}
return nil, errArithMismatchedTypes
}
func (e *JSONPath) evalNode(r Record) (*Value, error) {
// Strip the table name from the keypath.
keypath := e.String()
ps := strings.SplitN(keypath, ".", 2)
if len(ps) == 2 {
keypath = ps[1]
}
return r.Get(keypath)
}
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
switch {
case e.Value != nil:
return e.Value.evalNode(r)
case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r)
case e.SubExpression != nil:
return e.SubExpression.evalNode(r)
case e.FuncCall != nil:
return e.FuncCall.evalNode(r)
}
return nil, errInvalidASTNode
}
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate()
default:
return e.evalSQLFnNode(r)
}
}
// evalNode on a literal value is independent of the node being an
// aggregation or a row function - it always returns a value.
func (e *LitValue) evalNode(_ Record) (res *Value, err error) {
switch {
case e.Number != nil:
return floatToValue(*e.Number), nil
case e.String != nil:
return FromString(string(*e.String)), nil
case e.Boolean != nil:
return FromBool(bool(*e.Boolean)), nil
}
return FromNull(), nil
}

View File

@@ -1,160 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
)
// Expr - a SQL expression type.
type Expr interface {
AggregateValue() (*Value, error)
Eval(record Record) (*Value, error)
ReturnType() Type
Type() Type
}
// aliasExpr - aliases expression by alias.
type aliasExpr struct {
alias string
expr Expr
}
// String - returns string representation of this expression.
func (expr *aliasExpr) String() string {
return fmt.Sprintf("(%v AS %v)", expr.expr, expr.alias)
}
// Eval - evaluates underlaying expression for given record and returns evaluated result.
func (expr *aliasExpr) Eval(record Record) (*Value, error) {
return expr.expr.Eval(record)
}
// AggregateValue - returns aggregated value from underlaying expression.
func (expr *aliasExpr) AggregateValue() (*Value, error) {
return expr.expr.AggregateValue()
}
// Type - returns underlaying expression type.
func (expr *aliasExpr) Type() Type {
return expr.expr.Type()
}
// ReturnType - returns underlaying expression's return type.
func (expr *aliasExpr) ReturnType() Type {
return expr.expr.ReturnType()
}
// newAliasExpr - creates new alias expression.
func newAliasExpr(alias string, expr Expr) *aliasExpr {
return &aliasExpr{alias, expr}
}
// starExpr - asterisk (*) expression.
type starExpr struct {
}
// String - returns string representation of this expression.
func (expr *starExpr) String() string {
return "*"
}
// Eval - returns given args as map value.
func (expr *starExpr) Eval(record Record) (*Value, error) {
return newRecordValue(record), nil
}
// AggregateValue - returns nil value.
func (expr *starExpr) AggregateValue() (*Value, error) {
return nil, nil
}
// Type - returns record type.
func (expr *starExpr) Type() Type {
return record
}
// ReturnType - returns record as return type.
func (expr *starExpr) ReturnType() Type {
return record
}
// newStarExpr - returns new asterisk (*) expression.
func newStarExpr() *starExpr {
return &starExpr{}
}
type valueExpr struct {
value *Value
}
func (expr *valueExpr) String() string {
return expr.value.String()
}
func (expr *valueExpr) Eval(record Record) (*Value, error) {
return expr.value, nil
}
func (expr *valueExpr) AggregateValue() (*Value, error) {
return expr.value, nil
}
func (expr *valueExpr) Type() Type {
return expr.value.Type()
}
func (expr *valueExpr) ReturnType() Type {
return expr.value.Type()
}
func newValueExpr(value *Value) *valueExpr {
return &valueExpr{value: value}
}
type columnExpr struct {
name string
}
func (expr *columnExpr) String() string {
return expr.name
}
func (expr *columnExpr) Eval(record Record) (*Value, error) {
value, err := record.Get(expr.name)
if err != nil {
return nil, errEvaluatorBindingDoesNotExist(err)
}
return value, nil
}
func (expr *columnExpr) AggregateValue() (*Value, error) {
return nil, nil
}
func (expr *columnExpr) Type() Type {
return column
}
func (expr *columnExpr) ReturnType() Type {
return column
}
func newColumnExpr(columnName string) *columnExpr {
return &columnExpr{name: columnName}
}

View File

@@ -0,0 +1,433 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
"strconv"
"strings"
)
// FuncName - SQL function name.
type FuncName string
// SQL Function name constants
const (
// Conditionals
sqlFnCoalesce FuncName = "COALESCE"
sqlFnNullIf FuncName = "NULLIF"
// Conversion
sqlFnCast FuncName = "CAST"
// Date and time
sqlFnDateAdd FuncName = "DATE_ADD"
sqlFnDateDiff FuncName = "DATE_DIFF"
sqlFnExtract FuncName = "EXTRACT"
sqlFnToString FuncName = "TO_STRING"
sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
sqlFnUTCNow FuncName = "UTCNOW"
// String
sqlFnCharLength FuncName = "CHAR_LENGTH"
sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
sqlFnLower FuncName = "LOWER"
sqlFnSubstring FuncName = "SUBSTRING"
sqlFnTrim FuncName = "TRIM"
sqlFnUpper FuncName = "UPPER"
)
// Allowed cast types
const (
castBool = "BOOL"
castInt = "INT"
castInteger = "INTEGER"
castString = "STRING"
castFloat = "FLOAT"
castDecimal = "DECIMAL"
castNumeric = "NUMERIC"
castTimestamp = "TIMESTAMP"
)
var (
errUnimplementedCast = errors.New("This cast not yet implemented")
errNonStringTrimArg = errors.New("TRIM() received a non-string argument")
)
func (e *FuncExpr) getFunctionName() FuncName {
switch {
case e.SFunc != nil:
return FuncName(strings.ToUpper(e.SFunc.FunctionName))
case e.Count != nil:
return FuncName(aggFnCount)
case e.Cast != nil:
return sqlFnCast
case e.Substring != nil:
return sqlFnSubstring
case e.Extract != nil:
return sqlFnExtract
case e.Trim != nil:
return sqlFnTrim
default:
return ""
}
}
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
// function.
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
// Handle functions that have phrase arguments
switch e.getFunctionName() {
case sqlFnCast:
expr := e.Cast.Expr
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType))
return
case sqlFnSubstring:
return handleSQLSubstring(r, e.Substring)
case sqlFnExtract:
return nil, errNotImplemented
case sqlFnTrim:
return handleSQLTrim(r, e.Trim)
}
// For all simple argument functions, we evaluate the arguments here
argVals := make([]*Value, len(e.SFunc.ArgsList))
for i, arg := range e.SFunc.ArgsList {
argVals[i], err = arg.evalNode(r)
if err != nil {
return nil, err
}
}
switch e.getFunctionName() {
case sqlFnCoalesce:
return coalesce(r, argVals)
case sqlFnNullIf:
return nullif(r, argVals[0], argVals[1])
case sqlFnCharLength, sqlFnCharacterLength:
return charlen(r, argVals[0])
case sqlFnLower:
return lowerCase(r, argVals[0])
case sqlFnUpper:
return upperCase(r, argVals[0])
case sqlFnDateAdd, sqlFnDateDiff, sqlFnToString, sqlFnToTimestamp, sqlFnUTCNow:
// TODO: implement
fallthrough
default:
return nil, errInvalidASTNode
}
}
func coalesce(r Record, args []*Value) (res *Value, err error) {
for _, arg := range args {
if arg.IsNull() {
continue
}
return arg, nil
}
return FromNull(), nil
}
func nullif(r Record, v1, v2 *Value) (res *Value, err error) {
// Handle Null cases
if v1.IsNull() || v2.IsNull() {
return v1, nil
}
err = inferTypesForCmp(v1, v2)
if err != nil {
return nil, err
}
atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
bothNumeric := v1.isNumeric() && v2.isNumeric()
if atleastOneNumeric || !bothNumeric {
return v1, nil
}
if v1.vType != v2.vType {
return v1, nil
}
cmpResult, cmpErr := v1.compareOp(opEq, v2)
if cmpErr != nil {
return nil, cmpErr
}
if cmpResult {
return FromNull(), nil
}
return v1, nil
}
func charlen(r Record, v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromInt(int64(len(s))), nil
}
func lowerCase(r Record, v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s expects a string argument", sqlFnLower)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromString(strings.ToLower(s)), nil
}
func upperCase(r Record, v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromString(strings.ToUpper(s)), nil
}
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
// SUBSTRING('abc', 2, 1) are supported.
// Evaluate the string argument
v1, err := e.Expr.evalNode(r)
if err != nil {
return nil, err
}
inferTypeAsString(v1)
s, ok := v1.ToString()
if !ok {
err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
// Assemble other arguments
arg2, arg3 := e.From, e.For
// Check if the second form of substring is being used
if e.From == nil {
arg2, arg3 = e.Arg2, e.Arg3
}
// Evaluate the FROM argument
v2, err := arg2.evalNode(r)
if err != nil {
return nil, err
}
inferTypeForArithOp(v2)
startIdx, ok := v2.ToInt()
if !ok {
err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
length := -1
// Evaluate the optional FOR argument
if arg3 != nil {
v3, err := arg3.evalNode(r)
if err != nil {
return nil, err
}
inferTypeForArithOp(v3)
lenInt, ok := v3.ToInt()
if !ok {
err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
length = int(lenInt)
if length < 0 {
err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
res, err := evalSQLSubstring(s, int(startIdx), length)
return FromString(res), err
}
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
charsV, cerr := e.TrimChars.evalNode(r)
if cerr != nil {
return nil, cerr
}
inferTypeAsString(charsV)
chars, ok := charsV.ToString()
if !ok {
return nil, errNonStringTrimArg
}
fromV, ferr := e.TrimFrom.evalNode(r)
if ferr != nil {
return nil, ferr
}
from, ok := fromV.ToString()
if !ok {
return nil, errNonStringTrimArg
}
result, terr := evalSQLTrim(e.TrimWhere, chars, from)
if terr != nil {
return nil, terr
}
return FromString(result), nil
}
func errUnsupportedCast(fromType, toType string) error {
return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
}
func errCastFailure(msg string) error {
return fmt.Errorf("Error casting: %s", msg)
}
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
v, err := e.evalNode(r)
if err != nil {
return nil, err
}
fmt.Println("Cast to ", castType)
switch castType {
case castInt, castInteger:
i, err := intCast(v)
return FromInt(i), err
case castFloat:
f, err := floatCast(v)
return FromFloat(f), err
case castString:
s, err := stringCast(v)
return FromString(s), err
case castBool, castDecimal, castNumeric, castTimestamp:
fallthrough
default:
return nil, errUnimplementedCast
}
}
func intCast(v *Value) (int64, error) {
// This conversion truncates floating point numbers to
// integer.
strToInt := func(s string) (int64, bool) {
i, errI := strconv.ParseInt(s, 10, 64)
if errI == nil {
return i, true
}
f, errF := strconv.ParseFloat(s, 64)
if errF == nil {
return int64(f), true
}
return 0, false
}
switch v.vType {
case typeFloat:
// Truncate fractional part
return int64(v.value.(float64)), nil
case typeInt:
return v.value.(int64), nil
case typeString:
// Parse as number, truncate floating point if
// needed.
s, _ := v.ToString()
res, ok := strToInt(s)
if !ok {
return 0, errCastFailure("could not parse as int")
}
return res, nil
case typeBytes:
// Parse as number, truncate floating point if
// needed.
b, _ := v.ToBytes()
s := string(b)
res, ok := strToInt(s)
if !ok {
return 0, errCastFailure("could not parse as int")
}
return res, nil
default:
return 0, errUnsupportedCast(v.GetTypeString(), castInt)
}
}
func floatCast(v *Value) (float64, error) {
switch v.vType {
case typeFloat:
return v.value.(float64), nil
case typeInt:
return float64(v.value.(int64)), nil
case typeString:
f, err := strconv.ParseFloat(v.value.(string), 64)
if err != nil {
return 0, errCastFailure("could not parse as float")
}
return f, nil
case typeBytes:
b, _ := v.ToBytes()
f, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return 0, errCastFailure("could not parse as float")
}
return f, nil
default:
return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
}
}
func stringCast(v *Value) (string, error) {
switch v.vType {
case typeFloat:
f, _ := v.ToFloat()
return fmt.Sprintf("%v", f), nil
case typeInt:
i, _ := v.ToInt()
return fmt.Sprintf("%v", i), nil
case typeString:
s, _ := v.ToString()
return s, nil
case typeBytes:
b, _ := v.ToBytes()
return string(b), nil
case typeBool:
b, _ := v.ToBool()
return fmt.Sprintf("%v", b), nil
case typeNull:
// FIXME: verify this case is correct
return fmt.Sprintf("NULL"), nil
}
// This does not happen
return "", nil
}

View File

@@ -1,550 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
"time"
)
// FuncName - SQL function name.
type FuncName string
const (
// Avg - aggregate SQL function AVG().
Avg FuncName = "AVG"
// Count - aggregate SQL function COUNT().
Count FuncName = "COUNT"
// Max - aggregate SQL function MAX().
Max FuncName = "MAX"
// Min - aggregate SQL function MIN().
Min FuncName = "MIN"
// Sum - aggregate SQL function SUM().
Sum FuncName = "SUM"
// Coalesce - conditional SQL function COALESCE().
Coalesce FuncName = "COALESCE"
// NullIf - conditional SQL function NULLIF().
NullIf FuncName = "NULLIF"
// ToTimestamp - conversion SQL function TO_TIMESTAMP().
ToTimestamp FuncName = "TO_TIMESTAMP"
// UTCNow - date SQL function UTCNOW().
UTCNow FuncName = "UTCNOW"
// CharLength - string SQL function CHAR_LENGTH().
CharLength FuncName = "CHAR_LENGTH"
// CharacterLength - string SQL function CHARACTER_LENGTH() same as CHAR_LENGTH().
CharacterLength FuncName = "CHARACTER_LENGTH"
// Lower - string SQL function LOWER().
Lower FuncName = "LOWER"
// Substring - string SQL function SUBSTRING().
Substring FuncName = "SUBSTRING"
// Trim - string SQL function TRIM().
Trim FuncName = "TRIM"
// Upper - string SQL function UPPER().
Upper FuncName = "UPPER"
// DateAdd FuncName = "DATE_ADD"
// DateDiff FuncName = "DATE_DIFF"
// Extract FuncName = "EXTRACT"
// ToString FuncName = "TO_STRING"
// Cast FuncName = "CAST" // CAST('2007-04-05T14:30Z' AS TIMESTAMP)
)
func isAggregateFuncName(s string) bool {
switch FuncName(s) {
case Avg, Count, Max, Min, Sum:
return true
}
return false
}
func callForNumber(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if !value.Type().isNumber() {
err := fmt.Errorf("%v evaluated to %v; not to number", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
func callForInt(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if value.Type() != Int {
err := fmt.Errorf("%v evaluated to %v; not to int", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
func callForString(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if value.Type() != String {
err := fmt.Errorf("%v evaluated to %v; not to string", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
// funcExpr - SQL function.
type funcExpr struct {
args []Expr
name FuncName
sumValue float64
countValue int64
maxValue float64
minValue float64
}
// String - returns string representation of this function.
func (f *funcExpr) String() string {
var argStrings []string
for _, arg := range f.args {
argStrings = append(argStrings, fmt.Sprintf("%v", arg))
}
return fmt.Sprintf("%v(%v)", f.name, strings.Join(argStrings, ","))
}
func (f *funcExpr) sum(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
f.sumValue += value.FloatValue()
f.countValue++
return nil, nil
}
func (f *funcExpr) count(record Record) (*Value, error) {
value, err := f.args[0].Eval(record)
if err != nil {
return nil, err
}
if value.valueType != Null {
f.countValue++
}
return nil, nil
}
func (f *funcExpr) max(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
v := value.FloatValue()
if v > f.maxValue {
f.maxValue = v
}
return nil, nil
}
func (f *funcExpr) min(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
v := value.FloatValue()
if v < f.minValue {
f.minValue = v
}
return nil, nil
}
func (f *funcExpr) charLength(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewInt(int64(len(value.StringValue()))), nil
}
func (f *funcExpr) trim(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.TrimSpace(value.StringValue())), nil
}
func (f *funcExpr) lower(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.ToLower(value.StringValue())), nil
}
func (f *funcExpr) upper(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.ToUpper(value.StringValue())), nil
}
func (f *funcExpr) substring(record Record) (*Value, error) {
stringValue, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
offsetValue, err := callForInt(f.args[1], record)
if err != nil {
return nil, err
}
var lengthValue *Value
if len(f.args) == 3 {
lengthValue, err = callForInt(f.args[2], record)
if err != nil {
return nil, err
}
}
value := stringValue.StringValue()
offset := int(offsetValue.FloatValue())
if offset < 0 || offset > len(value) {
offset = 0
}
length := len(value)
if lengthValue != nil {
length = int(lengthValue.FloatValue())
if length < 0 || length > len(value) {
length = len(value)
}
}
return NewString(value[offset:length]), nil
}
func (f *funcExpr) coalesce(record Record) (*Value, error) {
values := make([]*Value, len(f.args))
var err error
for i := range f.args {
values[i], err = f.args[i].Eval(record)
if err != nil {
return nil, err
}
}
for i := range values {
if values[i].Type() != Null {
return values[i], nil
}
}
return values[0], nil
}
func (f *funcExpr) nullIf(record Record) (*Value, error) {
value1, err := f.args[0].Eval(record)
if err != nil {
return nil, err
}
value2, err := f.args[1].Eval(record)
if err != nil {
return nil, err
}
result, err := equal(value1, value2)
if err != nil {
return nil, err
}
if result {
return NewNull(), nil
}
return value1, nil
}
func (f *funcExpr) toTimeStamp(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
t, err := time.Parse(time.RFC3339, value.StringValue())
if err != nil {
err := fmt.Errorf("%v: value '%v': %v", f, value, err)
return nil, errValueParseFailure(err)
}
return NewTime(t), nil
}
func (f *funcExpr) utcNow(record Record) (*Value, error) {
return NewTime(time.Now().UTC()), nil
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *funcExpr) Eval(record Record) (*Value, error) {
switch f.name {
case Avg, Sum:
return f.sum(record)
case Count:
return f.count(record)
case Max:
return f.max(record)
case Min:
return f.min(record)
case Coalesce:
return f.coalesce(record)
case NullIf:
return f.nullIf(record)
case ToTimestamp:
return f.toTimeStamp(record)
case UTCNow:
return f.utcNow(record)
case Substring:
return f.substring(record)
case CharLength, CharacterLength:
return f.charLength(record)
case Trim:
return f.trim(record)
case Lower:
return f.lower(record)
case Upper:
return f.upper(record)
}
panic(fmt.Sprintf("unsupported aggregate function %v", f.name))
}
// AggregateValue - returns aggregated value.
func (f *funcExpr) AggregateValue() (*Value, error) {
switch f.name {
case Avg:
return NewFloat(f.sumValue / float64(f.countValue)), nil
case Count:
return NewInt(f.countValue), nil
case Max:
return NewFloat(f.maxValue), nil
case Min:
return NewFloat(f.minValue), nil
case Sum:
return NewFloat(f.sumValue), nil
}
err := fmt.Errorf("%v is not aggreate function", f)
return nil, errExternalEvalException(err)
}
// Type - returns Function or aggregateFunction type.
func (f *funcExpr) Type() Type {
switch f.name {
case Avg, Count, Max, Min, Sum:
return aggregateFunction
}
return function
}
// ReturnType - returns respective primitive type depending on SQL function.
func (f *funcExpr) ReturnType() Type {
switch f.name {
case Avg, Max, Min, Sum:
return Float
case Count:
return Int
case CharLength, CharacterLength, Trim, Lower, Upper, Substring:
return String
case ToTimestamp, UTCNow:
return Timestamp
case Coalesce, NullIf:
return column
}
return function
}
// newFuncExpr - creates new SQL function.
func newFuncExpr(funcName FuncName, funcs ...Expr) (*funcExpr, error) {
switch funcName {
case Avg, Max, Min, Sum:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errParseNonUnaryAgregateFunctionCall(err)
}
if !funcs[0].ReturnType().isNumberKind() {
err := fmt.Errorf("%v(): argument %v evaluate to %v, not number", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Count:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errParseNonUnaryAgregateFunctionCall(err)
}
switch funcs[0].ReturnType() {
case Null, Bool, Int, Float, String, Timestamp, column, record:
default:
err := fmt.Errorf("%v(): argument %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case CharLength, CharacterLength, Trim, Lower, Upper, ToTimestamp:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isStringKind() {
err := fmt.Errorf("%v(): argument %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Coalesce:
if len(funcs) < 1 {
err := fmt.Errorf("%v(): one or more argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
for i := range funcs {
if !funcs[i].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-%v %v evaluate to %v is incompatible", funcName, i+1, funcs[i], funcs[i].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case NullIf:
if len(funcs) != 2 {
err := fmt.Errorf("%v(): exactly two arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if !funcs[1].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v is incompatible", funcName, funcs[1], funcs[1].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case UTCNow:
if len(funcs) != 0 {
err := fmt.Errorf("%v(): no argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Substring:
if len(funcs) < 2 || len(funcs) > 3 {
err := fmt.Errorf("%v(): exactly two or three arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isStringKind() {
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if !funcs[1].ReturnType().isIntKind() {
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v, not int", funcName, funcs[1], funcs[1].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if len(funcs) > 2 {
if !funcs[2].ReturnType().isIntKind() {
err := fmt.Errorf("%v(): argument-3 %v evaluate to %v, not int", funcName, funcs[2], funcs[2].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
}
return nil, errUnsupportedFunction(fmt.Errorf("unknown function name %v", funcName))
}

View File

@@ -1,336 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "fmt"
// andExpr - logical AND function.
type andExpr struct {
left Expr
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *andExpr) String() string {
return fmt.Sprintf("(%v AND %v)", f.left, f.right)
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *andExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
_, err = f.right.Eval(record)
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// AggregateValue - returns aggregated value.
func (f *andExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *andExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *andExpr) ReturnType() Type {
return Bool
}
// newAndExpr - creates new AND logical function.
func newAndExpr(left, right Expr) (*andExpr, error) {
if !left.ReturnType().isBoolKind() {
err := fmt.Errorf("operator AND: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator AND: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
if left.Type() == column {
err := fmt.Errorf("operator AND: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
if right.Type() == column {
err := fmt.Errorf("operator AND: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &andExpr{
left: left,
right: right,
funcType: funcType,
}, nil
}
// orExpr - logical OR function.
type orExpr struct {
left Expr
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *orExpr) String() string {
return fmt.Sprintf("(%v OR %v)", f.left, f.right)
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *orExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
_, err = f.right.Eval(record)
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// AggregateValue - returns aggregated value.
func (f *orExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *orExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *orExpr) ReturnType() Type {
return Bool
}
// newOrExpr - creates new OR logical function.
func newOrExpr(left, right Expr) (*orExpr, error) {
if !left.ReturnType().isBoolKind() {
err := fmt.Errorf("operator OR: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator OR: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
if left.Type() == column {
err := fmt.Errorf("operator OR: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
if right.Type() == column {
err := fmt.Errorf("operator OR: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &orExpr{
left: left,
right: right,
funcType: funcType,
}, nil
}
// notExpr - logical NOT function.
type notExpr struct {
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *notExpr) String() string {
return fmt.Sprintf("(%v)", f.right)
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *notExpr) Eval(record Record) (*Value, error) {
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
return nil, nil
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(!rightValue.BoolValue()), nil
}
// AggregateValue - returns aggregated value.
func (f *notExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(!rightValue.BoolValue()), nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *notExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *notExpr) ReturnType() Type {
return Bool
}
// newNotExpr - creates new NOT logical function.
func newNotExpr(right Expr) (*notExpr, error) {
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator NOT: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if right.Type() == aggregateFunction {
funcType = aggregateFunction
}
return &notExpr{
right: right,
funcType: funcType,
}, nil
}

329
pkg/s3select/sql/parser.go Normal file
View File

@@ -0,0 +1,329 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"strings"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
)
// Types with custom Capture interface for parsing
// Boolean is a type for a parsed Boolean literal
type Boolean bool
// Capture interface used by participle
func (b *Boolean) Capture(values []string) error {
*b = strings.ToLower(values[0]) == "true"
return nil
}
// LiteralString is a type for parsed SQL string literals
type LiteralString string
// Capture interface used by participle
func (ls *LiteralString) Capture(values []string) error {
// Remove enclosing single quote
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*ls = LiteralString(strings.Replace(r, "''", "'", -1))
return nil
}
// ObjectKey is a type for parsed strings occurring in key paths
type ObjectKey struct {
Lit *LiteralString `parser:" \"[\" @LitString \"]\""`
ID *Identifier `parser:"| \".\" @@"`
}
// QuotedIdentifier is a type for parsed strings that are double
// quoted.
type QuotedIdentifier string
// Capture inferface used by participle
func (qi *QuotedIdentifier) Capture(values []string) error {
// Remove enclosing quotes
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*qi = QuotedIdentifier(strings.Replace(r, `""`, `"`, -1))
return nil
}
// Types representing AST of SQL statement. Only SELECT is supported.
// Select is the top level AST node type
type Select struct {
Expression *SelectExpression `parser:"\"SELECT\" @@"`
From *TableExpression `parser:"\"FROM\" @@"`
Where *Expression `parser:"[ \"WHERE\" @@ ]"`
Limit *LitValue `parser:"[ \"LIMIT\" @@ ]"`
}
// SelectExpression represents the items requested in the select
// statement
type SelectExpression struct {
All bool `parser:" @\"*\""`
Expressions []*AliasedExpression `parser:"| @@ { \",\" @@ }"`
prop qProp
}
// TableExpression represents the FROM clause
type TableExpression struct {
Table *JSONPath `parser:"@@"`
As string `parser:"( \"AS\"? @Ident )?"`
}
// JSONPathElement represents a keypath component
type JSONPathElement struct {
Key *ObjectKey `parser:" @@"` // ['name'] and .name forms
Index *uint64 `parser:"| \"[\" @Number \"]\""` // [3] form
ObjectWildcard bool `parser:"| @\".*\""` // .* form
ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form
}
// JSONPath represents a keypath
type JSONPath struct {
BaseKey *Identifier `parser:" @@"`
PathExpr []*JSONPathElement `parser:"(@@)*"`
}
// AliasedExpression is an expression that can be optionally named
type AliasedExpression struct {
Expression *Expression `parser:"@@"`
As string `parser:"[ \"AS\" @Ident ]"`
}
// Grammar for Expression
//
// Expression → AndCondition ("OR" AndCondition)*
// AndCondition → Condition ("AND" Condition)*
// Condition → "NOT" Condition | ConditionExpression
// ConditionExpression → ValueExpression ("=" | "<>" | "<=" | ">=" | "<" | ">") ValueExpression
// | ValueExpression "LIKE" ValueExpression ("ESCAPE" LitString)?
// | ValueExpression ("NOT"? "BETWEEN" ValueExpression "AND" ValueExpression)
// | ValueExpression "IN" "(" Expression ("," Expression)* ")"
// | ValueExpression
// ValueExpression → Operand
//
// Operand grammar follows below
// Expression represents a logical disjunction of clauses
type Expression struct {
And []*AndCondition `parser:"@@ ( \"OR\" @@ )*"`
}
// AndCondition represents logical conjunction of clauses
type AndCondition struct {
Condition []*Condition `parser:"@@ ( \"AND\" @@ )*"`
}
// Condition represents a negation or a condition operand
type Condition struct {
Operand *ConditionOperand `parser:" @@"`
Not *Condition `parser:"| \"NOT\" @@"`
}
// ConditionOperand is a operand followed by an an optional operation
// expression
type ConditionOperand struct {
Operand *Operand `parser:"@@"`
ConditionRHS *ConditionRHS `parser:"@@?"`
}
// ConditionRHS represents the right-hand-side of Compare, Between, In
// or Like expressions.
type ConditionRHS struct {
Compare *Compare `parser:" @@"`
Between *Between `parser:"| @@"`
In *In `parser:"| \"IN\" \"(\" @@ \")\""`
Like *Like `parser:"| @@"`
}
// Compare represents the RHS of a comparison expression
type Compare struct {
Operator string `parser:"@( \"<>\" | \"<=\" | \">=\" | \"=\" | \"<\" | \">\" | \"!=\" )"`
Operand *Operand `parser:" @@"`
}
// Like represents the RHS of a LIKE expression
type Like struct {
Not bool `parser:" @\"NOT\"? "`
Pattern *Operand `parser:" \"LIKE\" @@ "`
EscapeChar *Operand `parser:" (\"ESCAPE\" @@)? "`
}
// Between represents the RHS of a BETWEEN expression
type Between struct {
Not bool `parser:" @\"NOT\"? "`
Start *Operand `parser:" \"BETWEEN\" @@ "`
End *Operand `parser:" \"AND\" @@ "`
}
// In represents the RHS of an IN expression
type In struct {
Expressions []*Expression `parser:"@@ ( \",\" @@ )*"`
}
// Grammar for Operand:
//
// operand → multOp ( ("-" | "+") multOp )*
// multOp → unary ( ("/" | "*" | "%") unary )*
// unary → "-" unary | primary
// primary → Value | Variable | "(" expression ")"
//
// An Operand is a single term followed by an optional sequence of
// terms separated by +/-
type Operand struct {
Left *MultOp `parser:"@@"`
Right []*OpFactor `parser:"(@@)*"`
}
// OpFactor represents the right-side of a +/- operation.
type OpFactor struct {
Op string `parser:"@(\"+\" | \"-\")"`
Right *MultOp `parser:"@@"`
}
// MultOp represents a single term followed by an optional sequence of
// terms separated by *, / or % operators.
type MultOp struct {
Left *UnaryTerm `parser:"@@"`
Right []*OpUnaryTerm `parser:"(@@)*"`
}
// OpUnaryTerm represents the right side of *, / or % binary operations.
type OpUnaryTerm struct {
Op string `parser:"@(\"*\" | \"/\" | \"%\")"`
Right *UnaryTerm `parser:"@@"`
}
// UnaryTerm represents a single negated term or a primary term
type UnaryTerm struct {
Negated *NegatedTerm `parser:" @@"`
Primary *PrimaryTerm `parser:"| @@"`
}
// NegatedTerm has a leading minus sign.
type NegatedTerm struct {
Term *PrimaryTerm `parser:"\"-\" @@"`
}
// PrimaryTerm represents a Value, Path expression, a Sub-expression
// or a function call.
type PrimaryTerm struct {
Value *LitValue `parser:" @@"`
JPathExpr *JSONPath `parser:"| @@"`
SubExpression *Expression `parser:"| \"(\" @@ \")\""`
// Include function expressions here.
FuncCall *FuncExpr `parser:"| @@"`
}
// FuncExpr represents a function call
type FuncExpr struct {
SFunc *SimpleArgFunc `parser:" @@"`
Count *CountFunc `parser:"| @@"`
Cast *CastFunc `parser:"| @@"`
Substring *SubstringFunc `parser:"| @@"`
Extract *ExtractFunc `parser:"| @@"`
Trim *TrimFunc `parser:"| @@"`
// Used during evaluation for aggregation funcs
aggregate *aggVal
}
// SimpleArgFunc represents functions with simple expression
// arguments.
type SimpleArgFunc struct {
FunctionName string `parser:" @(\"AVG\" | \"MAX\" | \"MIN\" | \"SUM\" | \"COALESCE\" | \"NULLIF\" | \"DATE_ADD\" | \"DATE_DIFF\" | \"TO_STRING\" | \"TO_TIMESTAMP\" | \"UTCNOW\" | \"CHAR_LENGTH\" | \"CHARACTER_LENGTH\" | \"LOWER\" | \"UPPER\") "`
ArgsList []*Expression `parser:"\"(\" (@@ (\",\" @@)*)?\")\""`
}
// CountFunc represents the COUNT sql function
type CountFunc struct {
StarArg bool `parser:" \"COUNT\" \"(\" ( @\"*\"?"`
ExprArg *Expression `parser:" @@? )! \")\""`
}
// CastFunc represents CAST sql function
type CastFunc struct {
Expr *Expression `parser:" \"CAST\" \"(\" @@ "`
CastType string `parser:" \"AS\" @(\"BOOL\" | \"INT\" | \"INTEGER\" | \"STRING\" | \"FLOAT\" | \"DECIMAL\" | \"NUMERIC\" | \"TIMESTAMP\") \")\" "`
}
// SubstringFunc represents SUBSTRING sql function
type SubstringFunc struct {
Expr *PrimaryTerm `parser:" \"SUBSTRING\" \"(\" @@ "`
From *Operand `parser:" ( \"FROM\" @@ "`
For *Operand `parser:" (\"FOR\" @@)? \")\" "`
Arg2 *Operand `parser:" | \",\" @@ "`
Arg3 *Operand `parser:" (\",\" @@)? \")\" )"`
}
// ExtractFunc represents EXTRACT sql function
type ExtractFunc struct {
Timeword string `parser:" \"EXTRACT\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword | \"TIMEZONE_HOUR\":Timeword | \"TIMEZONE_MINUTE\":Timeword ) "`
From *PrimaryTerm `parser:" \"FROM\" @@ \")\" "`
}
// TrimFunc represents TRIM sql function
type TrimFunc struct {
TrimWhere *string `parser:" \"TRIM\" \"(\" ( @( \"LEADING\" | \"TRAILING\" | \"BOTH\" ) "`
TrimChars *PrimaryTerm `parser:" @@? "`
TrimFrom *PrimaryTerm `parser:" \"FROM\" )? @@ \")\" "`
}
// LitValue represents a literal value parsed from the sql
type LitValue struct {
Number *float64 `parser:"( @Number"`
String *LiteralString `parser:" | @LitString"`
Boolean *Boolean `parser:" | @(\"TRUE\" | \"FALSE\")"`
Null bool `parser:" | @\"NULL\")"`
}
// Identifier represents a parsed identifier
type Identifier struct {
Unquoted *string `parser:" @Ident"`
Quoted *QuotedIdentifier `parser:"| @QuotIdent"`
}
var (
sqlLexer = lexer.Must(lexer.Regexp(`(\s+)` +
`|(?P<Timeword>(?i)\b(?:YEAR|MONTH|DAY|HOUR|MINUTE|SECOND|TIMEZONE_HOUR|TIMEZONE_MINUTE)\b)` +
`|(?P<Keyword>(?i)\b(?:SELECT|FROM|TOP|DISTINCT|ALL|WHERE|GROUP|BY|HAVING|UNION|MINUS|EXCEPT|INTERSECT|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|ANY|SOME|BETWEEN|AND|OR|LIKE|ESCAPE|AS|IN|BOOL|INT|INTEGER|STRING|FLOAT|DECIMAL|NUMERIC|TIMESTAMP|AVG|COUNT|MAX|MIN|SUM|COALESCE|NULLIF|CAST|DATE_ADD|DATE_DIFF|EXTRACT|TO_STRING|TO_TIMESTAMP|UTCNOW|CHAR_LENGTH|CHARACTER_LENGTH|LOWER|SUBSTRING|TRIM|UPPER|LEADING|TRAILING|BOTH|FOR)\b)` +
`|(?P<Ident>[a-zA-Z_][a-zA-Z0-9_]*)` +
`|(?P<QuotIdent>"([^"]*("")?)*")` +
`|(?P<Number>\d*\.?\d+([eE][-+]?\d+)?)` +
`|(?P<LitString>'([^']*('')?)*')` +
`|(?P<Operators><>|!=|<=|>=|\.\*|\[\*\]|[-+*/%,.()=<>\[\]])`,
))
// SQLParser is used to parse SQL statements
SQLParser = participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
)

View File

@@ -0,0 +1,383 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"bytes"
"testing"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
)
func TestJSONPathElement(t *testing.T) {
p := participle.MustBuild(
&JSONPathElement{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
j := JSONPathElement{}
cases := []string{
// Key
"['name']", ".name", `."name"`,
// Index
"[2]", "[0]", "[100]",
// Object wilcard
".*",
// array wildcard
"[*]",
}
for i, tc := range cases {
err := p.ParseString(tc, &j)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestJSONPath(t *testing.T) {
p := participle.MustBuild(
&JSONPath{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
j := JSONPath{}
cases := []string{
"S3Object",
"S3Object.id",
"S3Object.book.title",
"S3Object.id[1]",
"S3Object.id['abc']",
"S3Object.id['ab']",
"S3Object.words.*.id",
"S3Object.words.name[*].val",
"S3Object.words.name[*].val[*]",
"S3Object.words.name[*].val.*",
}
for i, tc := range cases {
err := p.ParseString(tc, &j)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestIdentifierParsing(t *testing.T) {
p := participle.MustBuild(
&Identifier{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
id := Identifier{}
validCases := []string{
"a",
"_a",
"abc_a",
"a2",
`"abc"`,
`"abc\a""ac"`,
}
for i, tc := range validCases {
err := p.ParseString(tc, &id)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(id, repr.Indent(" "), repr.OmitEmpty(true))
}
invalidCases := []string{
"+a",
"-a",
"1a",
`"ab`,
`abc"`,
`aa""a`,
`"a"a"`,
}
for i, tc := range invalidCases {
err := p.ParseString(tc, &id)
if err == nil {
t.Fatalf("%d: %v", i, err)
}
// fmt.Println(tc, err)
}
}
func TestLiteralStringParsing(t *testing.T) {
var k ObjectKey
p := participle.MustBuild(
&ObjectKey{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
validCases := []string{
"['abc']",
"['ab''c']",
"['a''b''c']",
"['abc-x_1##@(*&(#*))/\\']",
}
for i, tc := range validCases {
err := p.ParseString(tc, &k)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
if string(*k.Lit) == "" {
t.Fatalf("Incorrect parse %#v", k)
}
// repr.Println(k, repr.Indent(" "), repr.OmitEmpty(true))
}
invalidCases := []string{
"['abc'']",
"['-abc'sc']",
"[abc']",
"['ac]",
}
for i, tc := range invalidCases {
err := p.ParseString(tc, &k)
if err == nil {
t.Fatalf("%d: %v", i, err)
}
// fmt.Println(tc, err)
}
}
func TestFunctionParsing(t *testing.T) {
var fex FuncExpr
p := participle.MustBuild(
&FuncExpr{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
validCases := []string{
"count(*)",
"sum(2 + s.id)",
"sum(t)",
"avg(s.id[1])",
"coalesce(s.id[1], 2, 2 + 3)",
"cast(s as string)",
"cast(s AS INT)",
"cast(s as DECIMAL)",
"extract(YEAR from '2018-01-09')",
"extract(month from '2018-01-09')",
"extract(hour from '2018-01-09')",
"extract(day from '2018-01-09')",
"substring('abcd' from 2 for 2)",
"substring('abcd' from 2)",
"substring('abcd' , 2 , 2)",
"substring('abcd' , 22 )",
"trim(' aab ')",
"trim(leading from ' aab ')",
"trim(trailing from ' aab ')",
"trim(both from ' aab ')",
"trim(both '12' from ' aab ')",
"trim(leading '12' from ' aab ')",
"trim(trailing '12' from ' aab ')",
"count(23)",
}
for i, tc := range validCases {
err := p.ParseString(tc, &fex)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(fex, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSqlLexer(t *testing.T) {
// s := bytes.NewBuffer([]byte("s.['name'].*.[*].abc.[\"abc\"]"))
s := bytes.NewBuffer([]byte("S3Object.words.*.id"))
// s := bytes.NewBuffer([]byte("COUNT(Id)"))
lex, err := sqlLexer.Lex(s)
if err != nil {
t.Fatal(err)
}
tokens, err := lexer.ConsumeAll(lex)
if err != nil {
t.Fatal(err)
}
// for i, t := range tokens {
// fmt.Printf("%d: %#v\n", i, t)
// }
if len(tokens) != 7 {
t.Fatalf("Expected 7 got %d", len(tokens))
}
}
func TestSelectWhere(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object",
"select a, b from s3object s",
"select a, b from s3object as s",
"select a, b from s3object as s where a = 1",
"select a, b from s3object s where a = 1",
"select a, b from s3object where a = 1",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestLikeClause(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
`select * from s3object where Name like 'abcd'`,
`select Name like 'abc' from s3object`,
`select * from s3object where Name not like 'abc'`,
`select * from s3object where Name like 'abc' escape 't'`,
`select * from s3object where Name like 'a\%' escape '?'`,
`select * from s3object where Name not like 'abc\' escape '?'`,
`select * from s3object where Name like 'a\%' escape LOWER('?')`,
`select * from s3object where Name not like LOWER('Bc\') escape '?'`,
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Errorf("%d: %v", i, err)
}
}
}
func TestBetweenClause(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
`select * from s3object where Id between 1 and 2`,
`select * from s3object where Id between 1 and 2 and name = 'Ab'`,
`select * from s3object where Id not between 1 and 2`,
`select * from s3object where Id not between 1 and 2 and name = 'Bc'`,
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Errorf("%d: %v", i, err)
}
}
}
func TestFromClauseJSONPath(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object",
"select * from s3object[*].name",
"select * from s3object[*].books[*]",
"select * from s3object[*].books[*].name",
"select * from s3object where name > 2",
"select * from s3object[*].name where name > 2",
"select * from s3object[*].books[*] where name > 2",
"select * from s3object[*].books[*].name where name > 2",
"select * from s3object[*].books[*] s",
"select * from s3object[*].books[*].name as s",
"select * from s3object s where name > 2",
"select * from s3object[*].name as s where name > 2",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSelectParsing(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object where name > 2 or value > 1 or word > 2",
"select s.word.id + 2 from s3object s",
"select 1-2-3 from s3object s limit 1",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSqlLexerArithOps(t *testing.T) {
s := bytes.NewBuffer([]byte("year from select month hour distinct"))
lex, err := sqlLexer.Lex(s)
if err != nil {
t.Fatal(err)
}
tokens, err := lexer.ConsumeAll(lex)
if err != nil {
t.Fatal(err)
}
if len(tokens) != 7 {
t.Errorf("Expected 7 got %d", len(tokens))
}
// for i, t := range tokens {
// fmt.Printf("%d: %#v\n", i, t)
// }
}

View File

@@ -1,529 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
"github.com/xwb1989/sqlparser"
)
func getColumnName(colName *sqlparser.ColName) string {
columnName := colName.Qualifier.Name.String()
if qualifier := colName.Qualifier.Qualifier.String(); qualifier != "" {
columnName = qualifier + "." + columnName
}
if columnName == "" {
columnName = colName.Name.String()
} else {
columnName = columnName + "." + colName.Name.String()
}
return columnName
}
func newLiteralExpr(parserExpr sqlparser.Expr, tableAlias string) (Expr, error) {
switch parserExpr.(type) {
case *sqlparser.NullVal:
return newValueExpr(NewNull()), nil
case sqlparser.BoolVal:
return newValueExpr(NewBool((bool(parserExpr.(sqlparser.BoolVal))))), nil
case *sqlparser.SQLVal:
sqlValue := parserExpr.(*sqlparser.SQLVal)
value, err := NewValue(sqlValue)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
case *sqlparser.ColName:
columnName := getColumnName(parserExpr.(*sqlparser.ColName))
if tableAlias != "" {
if !strings.HasPrefix(columnName, tableAlias+".") {
err := fmt.Errorf("column name %v does not start with table alias %v", columnName, tableAlias)
return nil, errInvalidKeyPath(err)
}
columnName = strings.TrimPrefix(columnName, tableAlias+".")
}
return newColumnExpr(columnName), nil
case sqlparser.ValTuple:
var valueType Type
var values []*Value
for i, valExpr := range parserExpr.(sqlparser.ValTuple) {
sqlVal, ok := valExpr.(*sqlparser.SQLVal)
if !ok {
return nil, errParseInvalidTypeParam(fmt.Errorf("value %v in Tuple should be primitive value", i+1))
}
val, err := NewValue(sqlVal)
if err != nil {
return nil, err
}
if i == 0 {
valueType = val.Type()
} else if valueType != val.Type() {
return nil, errParseInvalidTypeParam(fmt.Errorf("mixed value type is not allowed in Tuple"))
}
values = append(values, val)
}
return newValueExpr(NewArray(values)), nil
}
return nil, nil
}
func isExprToComparisonExpr(parserExpr *sqlparser.IsExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func rangeCondToComparisonFunc(parserExpr *sqlparser.RangeCond, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
fromExpr, err := newExpr(parserExpr.From, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
toExpr, err := newExpr(parserExpr.To, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, fromExpr, toExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !fromExpr.Type().isBase() || !toExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toComparisonExpr(parserExpr *sqlparser.ComparisonExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, rightExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toArithExpr(parserExpr *sqlparser.BinaryExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newArithExpr(ArithOperator(parserExpr.Operator), leftExpr, rightExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toFuncExpr(parserExpr *sqlparser.FuncExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
funcName := strings.ToUpper(parserExpr.Name.String())
if !isSelectExpr && isAggregateFuncName(funcName) {
return nil, errUnsupportedSQLOperation(fmt.Errorf("%v() must be used in select expression", funcName))
}
funcs, aggregatedExprFound, err := newSelectExprs(parserExpr.Exprs, tableAlias)
if err != nil {
return nil, err
}
if aggregatedExprFound {
return nil, errIncorrectSQLFunctionArgumentType(fmt.Errorf("%v(): aggregated expression must not be used as argument", funcName))
}
return newFuncExpr(FuncName(funcName), funcs...)
}
func toAndExpr(parserExpr *sqlparser.AndExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newAndExpr(leftExpr, rightExpr)
if err != nil {
return nil, err
}
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toOrExpr(parserExpr *sqlparser.OrExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newOrExpr(leftExpr, rightExpr)
if err != nil {
return nil, err
}
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toNotExpr(parserExpr *sqlparser.NotExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
rightExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newNotExpr(rightExpr)
if err != nil {
return nil, err
}
if rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func newExpr(parserExpr sqlparser.Expr, tableAlias string, isSelectExpr bool) (Expr, error) {
f, err := newLiteralExpr(parserExpr, tableAlias)
if err != nil {
return nil, err
}
if f != nil {
return f, nil
}
switch parserExpr.(type) {
case *sqlparser.ParenExpr:
return newExpr(parserExpr.(*sqlparser.ParenExpr).Expr, tableAlias, isSelectExpr)
case *sqlparser.IsExpr:
return isExprToComparisonExpr(parserExpr.(*sqlparser.IsExpr), tableAlias, isSelectExpr)
case *sqlparser.RangeCond:
return rangeCondToComparisonFunc(parserExpr.(*sqlparser.RangeCond), tableAlias, isSelectExpr)
case *sqlparser.ComparisonExpr:
return toComparisonExpr(parserExpr.(*sqlparser.ComparisonExpr), tableAlias, isSelectExpr)
case *sqlparser.BinaryExpr:
return toArithExpr(parserExpr.(*sqlparser.BinaryExpr), tableAlias, isSelectExpr)
case *sqlparser.FuncExpr:
return toFuncExpr(parserExpr.(*sqlparser.FuncExpr), tableAlias, isSelectExpr)
case *sqlparser.AndExpr:
return toAndExpr(parserExpr.(*sqlparser.AndExpr), tableAlias, isSelectExpr)
case *sqlparser.OrExpr:
return toOrExpr(parserExpr.(*sqlparser.OrExpr), tableAlias, isSelectExpr)
case *sqlparser.NotExpr:
return toNotExpr(parserExpr.(*sqlparser.NotExpr), tableAlias, isSelectExpr)
}
return nil, errParseUnsupportedSyntax(fmt.Errorf("unknown expression type %T; %v", parserExpr, parserExpr))
}
func newSelectExprs(parserSelectExprs []sqlparser.SelectExpr, tableAlias string) ([]Expr, bool, error) {
var funcs []Expr
starExprFound := false
aggregatedExprFound := false
for _, selectExpr := range parserSelectExprs {
switch selectExpr.(type) {
case *sqlparser.AliasedExpr:
if starExprFound {
return nil, false, errParseAsteriskIsNotAloneInSelectList(nil)
}
aliasedExpr := selectExpr.(*sqlparser.AliasedExpr)
f, err := newExpr(aliasedExpr.Expr, tableAlias, true)
if err != nil {
return nil, false, err
}
if f.Type() == aggregateFunction {
if !aggregatedExprFound {
aggregatedExprFound = true
if len(funcs) > 0 {
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
}
}
} else if aggregatedExprFound {
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
}
alias := aliasedExpr.As.String()
if alias != "" {
f = newAliasExpr(alias, f)
}
funcs = append(funcs, f)
case *sqlparser.StarExpr:
if starExprFound {
err := fmt.Errorf("only single star expression allowed")
return nil, false, errParseInvalidContextForWildcardInSelectList(err)
}
starExprFound = true
funcs = append(funcs, newStarExpr())
default:
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("unknown select expression %v", selectExpr))
}
}
return funcs, aggregatedExprFound, nil
}
// Select - SQL Select statement.
type Select struct {
tableName string
tableAlias string
selectExprs []Expr
aggregatedExprFound bool
whereExpr Expr
}
// TableAlias - returns table alias name.
func (statement *Select) TableAlias() string {
return statement.tableAlias
}
// IsSelectAll - returns whether '*' is used in select expression or not.
func (statement *Select) IsSelectAll() bool {
if len(statement.selectExprs) == 1 {
_, ok := statement.selectExprs[0].(*starExpr)
return ok
}
return false
}
// IsAggregated - returns whether aggregated functions are used in select expression or not.
func (statement *Select) IsAggregated() bool {
return statement.aggregatedExprFound
}
// AggregateResult - returns aggregate result as record.
func (statement *Select) AggregateResult(output Record) error {
if !statement.aggregatedExprFound {
return nil
}
for i, expr := range statement.selectExprs {
value, err := expr.AggregateValue()
if err != nil {
return err
}
if value == nil {
return errInternalError(fmt.Errorf("%v returns <nil> for AggregateValue()", expr))
}
name := fmt.Sprintf("_%v", i+1)
if _, ok := expr.(*aliasExpr); ok {
name = expr.(*aliasExpr).alias
}
if err = output.Set(name, value); err != nil {
return errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
}
}
return nil
}
// Eval - evaluates this Select expressions for given record.
func (statement *Select) Eval(input, output Record) (Record, error) {
if statement.whereExpr != nil {
value, err := statement.whereExpr.Eval(input)
if err != nil {
return nil, err
}
if value == nil || value.valueType != Bool {
err = fmt.Errorf("WHERE expression %v returns invalid bool value %v", statement.whereExpr, value)
return nil, errInternalError(err)
}
if !value.BoolValue() {
return nil, nil
}
}
// Call selectExprs
for i, expr := range statement.selectExprs {
value, err := expr.Eval(input)
if err != nil {
return nil, err
}
if statement.aggregatedExprFound {
continue
}
name := fmt.Sprintf("_%v", i+1)
switch expr.(type) {
case *starExpr:
return value.recordValue(), nil
case *aliasExpr:
name = expr.(*aliasExpr).alias
case *columnExpr:
name = expr.(*columnExpr).name
}
if err = output.Set(name, value); err != nil {
return nil, errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
}
}
return output, nil
}
// NewSelect - creates new Select by parsing sql.
func NewSelect(sql string) (*Select, error) {
stmt, err := sqlparser.Parse(sql)
if err != nil {
return nil, errUnsupportedSQLStructure(err)
}
selectStmt, ok := stmt.(*sqlparser.Select)
if !ok {
return nil, errParseUnsupportedSelect(fmt.Errorf("unsupported SQL statement %v", sql))
}
var tableName, tableAlias string
for _, fromExpr := range selectStmt.From {
tableExpr := fromExpr.(*sqlparser.AliasedTableExpr)
tableName = tableExpr.Expr.(sqlparser.TableName).Name.String()
tableAlias = tableExpr.As.String()
}
selectExprs, aggregatedExprFound, err := newSelectExprs(selectStmt.SelectExprs, tableAlias)
if err != nil {
return nil, err
}
var whereExpr Expr
if selectStmt.Where != nil {
whereExpr, err = newExpr(selectStmt.Where.Expr, tableAlias, false)
if err != nil {
return nil, err
}
}
return &Select{
tableName: tableName,
tableAlias: tableAlias,
selectExprs: selectExprs,
aggregatedExprFound: aggregatedExprFound,
whereExpr: whereExpr,
}, nil
}

View File

@@ -0,0 +1,202 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
"strings"
)
var (
errBadLimitSpecified = errors.New("Limit value must be a positive integer")
)
// SelectStatement is the top level parsed and analyzed structure
type SelectStatement struct {
selectAST *Select
// Analysis result of the statement
selectQProp qProp
// Result of parsing the limit clause if one is present
// (otherwise -1)
limitValue int64
// Count of rows that have been output.
outputCount int64
}
// ParseSelectStatement - parses a select query from the given string
// and analyzes it.
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
var selectAST Select
err = SQLParser.ParseString(s, &selectAST)
if err != nil {
err = errQueryParseFailure(err)
return
}
stmt.selectAST = &selectAST
// Check the parsed limit value
stmt.limitValue, err = parseLimit(selectAST.Limit)
if err != nil {
err = errQueryAnalysisFailure(err)
return
}
// Analyze where clause
if selectAST.Where != nil {
whereQProp := selectAST.Where.analyze(&selectAST)
if whereQProp.err != nil {
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %v", whereQProp.err))
return
}
if whereQProp.isAggregation {
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
return
}
}
// Validate table name
tableString := strings.ToLower(selectAST.From.Table.String())
if !strings.HasPrefix(tableString, "s3object.") && tableString != "s3object" {
err = errBadTableName(errors.New("Table name must be s3object"))
return
}
// Analyze main select expression
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
err = stmt.selectQProp.err
if err != nil {
fmt.Println("Got Analysis err:", err)
err = errQueryAnalysisFailure(err)
}
return
}
func parseLimit(v *LitValue) (int64, error) {
switch {
case v == nil:
return -1, nil
case v.Number == nil:
return -1, errBadLimitSpecified
default:
r := int64(*v.Number)
if r < 0 {
return -1, errBadLimitSpecified
}
return r, nil
}
}
// IsAggregated returns if the statement involves SQL aggregation
func (e *SelectStatement) IsAggregated() bool {
return e.selectQProp.isAggregation
}
// AggregateResult - returns the aggregated result after all input
// records have been processed. Applies only to aggregation queries.
func (e *SelectStatement) AggregateResult(output Record) error {
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(nil)
if err != nil {
return err
}
output.Set(fmt.Sprintf("_%d", i+1), v)
}
return nil
}
// AggregateRow - aggregates the input record. Applies only to
// aggregation queries.
func (e *SelectStatement) AggregateRow(input Record) error {
for _, expr := range e.selectAST.Expression.Expressions {
err := expr.aggregateRow(input)
if err != nil {
return err
}
}
return nil
}
// Eval - evaluates the Select statement for the given record. It
// applies only to non-aggregation queries.
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
if whereExpr := e.selectAST.Where; whereExpr != nil {
value, err := whereExpr.evalNode(input)
if err != nil {
return nil, err
}
b, ok := value.ToBool()
if !ok {
err = fmt.Errorf("WHERE expression did not return bool")
return nil, err
}
if !b {
// Where clause is not satisfied by the row
return nil, nil
}
}
if e.selectAST.Expression.All {
// Return the input record for `SELECT * FROM
// .. WHERE ..`
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return input, nil
}
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input)
if err != nil {
return nil, err
}
// Pick output column names
if expr.As != "" {
output.Set(expr.As, v)
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
output.Set(comp, v)
} else {
output.Set(fmt.Sprintf("_%d", i+1), v)
}
}
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return output, nil
}
// LimitReached - returns true if the number of records output has
// reached the value of the `LIMIT` clause.
func (e *SelectStatement) LimitReached() bool {
if e.limitValue == -1 {
return false
}
return e.outputCount >= e.limitValue
}

View File

@@ -0,0 +1,188 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"strings"
)
var (
errMalformedEscapeSequence = errors.New("Malformed escape sequence in LIKE clause")
errInvalidTrimArg = errors.New("Trim argument is invalid - this should not happen")
errInvalidSubstringIndexLen = errors.New("Substring start index or length falls outside the string")
)
const (
percent rune = '%'
underscore rune = '_'
runeZero rune = 0
)
func evalSQLLike(text, pattern string, escape rune) (match bool, err error) {
s := []rune{}
prev := runeZero
hasLeadingPercent := false
patLen := len([]rune(pattern))
for i, r := range pattern {
if i > 0 && prev == escape {
switch r {
case percent, escape, underscore:
s = append(s, r)
prev = r
if r == escape {
prev = runeZero
}
default:
return false, errMalformedEscapeSequence
}
continue
}
prev = r
var ok bool
switch r {
case percent:
if len(s) == 0 {
hasLeadingPercent = true
continue
}
text, ok = matcher(text, string(s), hasLeadingPercent)
if !ok {
return false, nil
}
hasLeadingPercent = true
s = []rune{}
if i == patLen-1 {
// Last pattern character is a %, so
// we are done.
return true, nil
}
case underscore:
if len(s) == 0 {
text, ok = dropRune(text)
if !ok {
return false, nil
}
continue
}
text, ok = matcher(text, string(s), hasLeadingPercent)
if !ok {
return false, nil
}
hasLeadingPercent = false
text, ok = dropRune(text)
if !ok {
return false, nil
}
s = []rune{}
case escape:
if i == patLen-1 {
return false, errMalformedEscapeSequence
}
// Otherwise do nothing.
default:
s = append(s, r)
}
}
if hasLeadingPercent {
return strings.HasSuffix(text, string(s)), nil
}
return string(s) == text, nil
}
// matcher - Finds `pat` in `text`, and returns the part remainder of
// `text`, after the match. If leadingPercent is false, `pat` must be
// the prefix of `text`, otherwise it must be a substring.
func matcher(text, pat string, leadingPercent bool) (res string, match bool) {
if !leadingPercent {
res = strings.TrimPrefix(text, pat)
if len(text) == len(res) {
return "", false
}
} else {
parts := strings.SplitN(text, pat, 2)
if len(parts) == 1 {
return "", false
}
res = parts[1]
}
return res, true
}
func dropRune(text string) (res string, ok bool) {
r := []rune(text)
if len(r) == 0 {
return "", false
}
return string(r[1:]), true
}
func evalSQLSubstring(s string, startIdx, length int) (res string, err error) {
if startIdx <= 0 || startIdx > len(s) {
return "", errInvalidSubstringIndexLen
}
// StartIdx is 1-based in the input
startIdx--
rs := []rune(s)
endIdx := len(rs)
if length != -1 {
if length < 0 || startIdx+length > len(s) {
return "", errInvalidSubstringIndexLen
}
endIdx = startIdx + length
}
return string(rs[startIdx:endIdx]), nil
}
const (
trimLeading = "LEADING"
trimTrailing = "TRAILING"
trimBoth = "BOTH"
)
func evalSQLTrim(where *string, trimChars, text string) (result string, err error) {
cutSet := " "
if trimChars != "" {
cutSet = trimChars
}
trimFunc := strings.Trim
switch {
case where == nil:
case *where == trimBoth:
case *where == trimLeading:
trimFunc = strings.TrimLeft
case *where == trimTrailing:
trimFunc = strings.TrimRight
default:
return "", errInvalidTrimArg
}
return trimFunc(text, cutSet), nil
}

View File

@@ -0,0 +1,107 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"testing"
)
func TestEvalSQLLike(t *testing.T) {
dropCases := []struct {
input, resultExpected string
matchExpected bool
}{
{"", "", false},
{"a", "", true},
{"ab", "b", true},
{"தமிழ்", "மிழ்", true},
}
for i, tc := range dropCases {
res, ok := dropRune(tc.input)
if res != tc.resultExpected || ok != tc.matchExpected {
t.Errorf("DropRune Case %d failed", i)
}
}
matcherCases := []struct {
iText, iPat string
iHasLeadingPercent bool
resultExpected string
matchExpected bool
}{
{"abcd", "bcd", false, "", false},
{"abcd", "bcd", true, "", true},
{"abcd", "abcd", false, "", true},
{"abcd", "abcd", true, "", true},
{"abcd", "ab", false, "cd", true},
{"abcd", "ab", true, "cd", true},
{"abcd", "bc", false, "", false},
{"abcd", "bc", true, "d", true},
}
for i, tc := range matcherCases {
res, ok := matcher(tc.iText, tc.iPat, tc.iHasLeadingPercent)
if res != tc.resultExpected || ok != tc.matchExpected {
t.Errorf("Matcher Case %d failed", i)
}
}
evalCases := []struct {
iText, iPat string
iEsc rune
matchExpected bool
errExpected error
}{
{"abcd", "abc", runeZero, false, nil},
{"abcd", "abcd", runeZero, true, nil},
{"abcd", "abc_", runeZero, true, nil},
{"abcd", "_bdd", runeZero, false, nil},
{"abcd", "_b_d", runeZero, true, nil},
{"abcd", "____", runeZero, true, nil},
{"abcd", "____%", runeZero, true, nil},
{"abcd", "%____", runeZero, true, nil},
{"abcd", "%__", runeZero, true, nil},
{"abcd", "____", runeZero, true, nil},
{"", "_", runeZero, false, nil},
{"", "%", runeZero, true, nil},
{"abcd", "%%%%%", runeZero, true, nil},
{"abcd", "_____", runeZero, false, nil},
{"abcd", "%%%%%", runeZero, true, nil},
{"a%%d", `a\%\%d`, '\\', true, nil},
{"a%%d", `a\%d`, '\\', false, nil},
{`a%%\d`, `a\%\%\\d`, '\\', true, nil},
{`a%%\`, `a\%\%\\`, '\\', true, nil},
{`a%__%\`, `a\%\_\_\%\\`, '\\', true, nil},
{`a%__%\`, `a\%\_\_\%_`, '\\', true, nil},
{`a%__%\`, `a\%\_\__`, '\\', false, nil},
{`a%__%\`, `a\%\_\_%`, '\\', true, nil},
{`a%__%\`, `a?%?_?_?%\`, '?', true, nil},
}
for i, tc := range evalCases {
// fmt.Println("Case:", i)
res, err := evalSQLLike(tc.iText, tc.iPat, tc.iEsc)
if res != tc.matchExpected || err != tc.errExpected {
t.Errorf("Eval Case %d failed: %v %v", i, res, err)
}
}
}

View File

@@ -1,118 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
// Type - value type.
type Type string
const (
// Null - represents NULL value type.
Null Type = "null"
// Bool - represents boolean value type.
Bool Type = "bool"
// Int - represents integer value type.
Int Type = "int"
// Float - represents floating point value type.
Float Type = "float"
// String - represents string value type.
String Type = "string"
// Timestamp - represents time value type.
Timestamp Type = "timestamp"
// Array - represents array of values where each value type is one of above.
Array Type = "array"
column Type = "column"
record Type = "record"
function Type = "function"
aggregateFunction Type = "aggregatefunction"
arithmeticFunction Type = "arithmeticfunction"
comparisonFunction Type = "comparisonfunction"
logicalFunction Type = "logicalfunction"
// Integer Type = "integer" // Same as Int
// Decimal Type = "decimal" // Same as Float
// Numeric Type = "numeric" // Same as Float
)
func (t Type) isBase() bool {
switch t {
case Null, Bool, Int, Float, String, Timestamp:
return true
}
return false
}
func (t Type) isBaseKind() bool {
switch t {
case Null, Bool, Int, Float, String, Timestamp, column:
return true
}
return false
}
func (t Type) isNumber() bool {
switch t {
case Int, Float:
return true
}
return false
}
func (t Type) isNumberKind() bool {
switch t {
case Int, Float, column:
return true
}
return false
}
func (t Type) isIntKind() bool {
switch t {
case Int, column:
return true
}
return false
}
func (t Type) isBoolKind() bool {
switch t {
case Bool, column:
return true
}
return false
}
func (t Type) isStringKind() bool {
switch t {
case String, column:
return true
}
return false
}

87
pkg/s3select/sql/utils.go Normal file
View File

@@ -0,0 +1,87 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
)
// String functions
// String - returns the JSONPath representation
func (e *JSONPath) String() string {
parts := make([]string, len(e.PathExpr)+1)
parts[0] = e.BaseKey.String()
for i, pe := range e.PathExpr {
parts[i+1] = pe.String()
}
return strings.Join(parts, "")
}
func (e *JSONPathElement) String() string {
switch {
case e.Key != nil:
return e.Key.String()
case e.Index != nil:
return fmt.Sprintf("[%d]", *e.Index)
case e.ObjectWildcard:
return ".*"
case e.ArrayWildcard:
return "[*]"
}
return ""
}
// Removes double quotes in quoted identifiers
func (i *Identifier) String() string {
if i.Unquoted != nil {
return *i.Unquoted
}
return string(*i.Quoted)
}
func (o *ObjectKey) String() string {
if o.Lit != nil {
return fmt.Sprintf("['%s']", string(*o.Lit))
}
return fmt.Sprintf(".%s", o.ID.String())
}
// getLastKeypathComponent checks if the given expression is a path
// expression, and if so extracts the last dot separated component of
// the path. Otherwise it returns false.
func getLastKeypathComponent(e *Expression) (string, bool) {
if len(e.And) > 1 ||
len(e.And[0].Condition) > 1 ||
e.And[0].Condition[0].Not != nil ||
e.And[0].Condition[0].Operand.ConditionRHS != nil {
return "", false
}
operand := e.And[0].Condition[0].Operand.Operand
if operand.Right != nil ||
operand.Left.Right != nil ||
operand.Left.Left.Negated != nil ||
operand.Left.Left.Primary.JPathExpr == nil {
return "", false
}
keypath := operand.Left.Left.Primary.JPathExpr.String()
ps := strings.Split(keypath, ".")
return ps[len(ps)-1], true
}

View File

@@ -17,220 +17,710 @@
package sql
import (
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
"github.com/xwb1989/sqlparser"
)
// Value - represents any primitive value of bool, int, float, string and time.
var (
errArithMismatchedTypes = errors.New("cannot perform arithmetic on mismatched types")
errArithInvalidOperator = errors.New("invalid arithmetic operator")
errArithDivideByZero = errors.New("cannot divide by 0")
errCmpMismatchedTypes = errors.New("cannot compare values of different types")
errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments")
)
// vType represents the concrete type of a `Value`
type vType int
// Valid values for Type
const (
typeNull vType = iota + 1
typeBool
typeString
// 64-bit signed integer
typeInt
// 64-bit floating point
typeFloat
// This type refers to untyped values, e.g. as read from CSV
typeBytes
)
// Value represents a value of restricted type reduced from an
// expression represented by an ASTNode. Only one of the fields is
// non-nil.
//
// In cases where we are fetching data from a data source (like csv),
// the type may not be determined yet. In these cases, a byte-slice is
// used.
type Value struct {
value interface{}
valueType Type
value interface{}
vType vType
}
// String - represents value as string.
func (value *Value) String() string {
if value.value == nil {
if value.valueType == Null {
return "NULL"
}
return "<nil>"
// GetTypeString returns a string representation for vType
func (v *Value) GetTypeString() string {
switch v.vType {
case typeNull:
return "NULL"
case typeBool:
return "BOOL"
case typeString:
return "STRING"
case typeInt:
return "INT"
case typeFloat:
return "FLOAT"
case typeBytes:
return "BYTES"
}
switch value.valueType {
case String:
return fmt.Sprintf("'%v'", value.value)
case Array:
var valueStrings []string
for _, v := range value.value.([]*Value) {
valueStrings = append(valueStrings, fmt.Sprintf("%v", v))
}
return fmt.Sprintf("(%v)", strings.Join(valueStrings, ","))
}
return fmt.Sprintf("%v", value.value)
return "--"
}
// CSVString - encodes to CSV string.
func (value *Value) CSVString() string {
if value.valueType == Null {
// Repr returns a string representation of value.
func (v *Value) Repr() string {
switch v.vType {
case typeNull:
return ":NULL"
case typeBool, typeInt, typeFloat:
return fmt.Sprintf("%v:%s", v.value, v.GetTypeString())
case typeString:
return fmt.Sprintf("\"%s\":%s", v.value.(string), v.GetTypeString())
case typeBytes:
return fmt.Sprintf("\"%s\":BYTES", string(v.value.([]byte)))
default:
return fmt.Sprintf("%v:INVALID", v.value)
}
}
// FromFloat creates a Value from a number
func FromFloat(f float64) *Value {
return &Value{value: f, vType: typeFloat}
}
// FromInt creates a Value from an int
func FromInt(f int64) *Value {
return &Value{value: f, vType: typeInt}
}
// FromString creates a Value from a string
func FromString(str string) *Value {
return &Value{value: str, vType: typeString}
}
// FromBool creates a Value from a bool
func FromBool(b bool) *Value {
return &Value{value: b, vType: typeBool}
}
// FromNull creates a Value with Null value
func FromNull() *Value {
return &Value{vType: typeNull}
}
// FromBytes creates a Value from a []byte
func FromBytes(b []byte) *Value {
return &Value{value: b, vType: typeBytes}
}
// ToFloat works for int and float values
func (v *Value) ToFloat() (val float64, ok bool) {
switch v.vType {
case typeFloat:
val, ok = v.value.(float64)
case typeInt:
var i int64
i, ok = v.value.(int64)
val = float64(i)
default:
}
return
}
// ToInt converts value to int.
func (v *Value) ToInt() (val int64, ok bool) {
switch v.vType {
case typeInt:
val, ok = v.value.(int64)
default:
}
return
}
// ToString converts value to string.
func (v *Value) ToString() (val string, ok bool) {
switch v.vType {
case typeString:
val, ok = v.value.(string)
default:
}
return
}
// ToBool returns the bool value; second return value refers to if the bool
// conversion succeeded.
func (v *Value) ToBool() (val bool, ok bool) {
switch v.vType {
case typeBool:
return v.value.(bool), true
}
return false, false
}
// ToBytes converts Value to byte-slice.
func (v *Value) ToBytes() ([]byte, bool) {
switch v.vType {
case typeBytes:
return v.value.([]byte), true
}
return nil, false
}
// IsNull - checks if value is missing.
func (v *Value) IsNull() bool {
return v.vType == typeNull
}
func (v *Value) isNumeric() bool {
return v.vType == typeInt || v.vType == typeFloat
}
// setters used internally to mutate values
func (v *Value) setInt(i int64) {
v.vType = typeInt
v.value = i
}
func (v *Value) setFloat(f float64) {
v.vType = typeFloat
v.value = f
}
func (v *Value) setString(s string) {
v.vType = typeString
v.value = s
}
func (v *Value) setBool(b bool) {
v.vType = typeBool
v.value = b
}
// CSVString - convert to string for CSV serialization
func (v *Value) CSVString() string {
switch v.vType {
case typeNull:
return ""
case typeBool:
return fmt.Sprintf("%v", v.value.(bool))
case typeString:
return fmt.Sprintf("%s", v.value.(string))
case typeInt:
return fmt.Sprintf("%v", v.value.(int64))
case typeFloat:
return fmt.Sprintf("%v", v.value.(float64))
case typeBytes:
return fmt.Sprintf("%v", string(v.value.([]byte)))
default:
return "CSV serialization not implemented for this type"
}
}
// floatToValue converts a float into int representation if needed.
func floatToValue(f float64) *Value {
intPart, fracPart := math.Modf(f)
if fracPart == 0 {
return FromInt(int64(intPart))
}
return FromFloat(f)
}
// Value comparison functions: we do not expose them outside the
// module. Logical operators "<", ">", ">=", "<=" work on strings and
// numbers. Equality operators "=", "!=" work on strings,
// numbers and booleans.
// Supported comparison operators
const (
opLt = "<"
opLte = "<="
opGt = ">"
opGte = ">="
opEq = "="
opIneq = "!="
)
// When numeric types are compared, type promotions could happen. If
// values do not have types (e.g. when reading from CSV), for
// comparison operations, automatic type conversion happens by trying
// to check if the value is a number (first an integer, then a float),
// and falling back to string.
func (v *Value) compareOp(op string, a *Value) (res bool, err error) {
if !isValidComparisonOperator(op) {
return false, errArithInvalidOperator
}
return fmt.Sprintf("%v", value.value)
// Check if type conversion/inference is needed - it is needed
// if the Value is a byte-slice.
err = inferTypesForCmp(v, a)
if err != nil {
return false, err
}
isNumeric := v.isNumeric() && a.isNumeric()
if isNumeric {
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
if ok1i && ok2i {
return intCompare(op, intV, intA), nil
}
// If both values are numeric, then at least one is
// float since we got here, so we convert.
flV, _ := v.ToFloat()
flA, _ := a.ToFloat()
return floatCompare(op, flV, flA), nil
}
strV, ok1s := v.ToString()
strA, ok2s := a.ToString()
if ok1s && ok2s {
return stringCompare(op, strV, strA), nil
}
boolV, ok1b := v.ToBool()
boolA, ok2b := v.ToBool()
if ok1b && ok2b {
return boolCompare(op, boolV, boolA)
}
return false, errCmpMismatchedTypes
}
// MarshalJSON - encodes to JSON data.
func (value *Value) MarshalJSON() ([]byte, error) {
return json.Marshal(value.value)
func inferTypesForCmp(a *Value, b *Value) error {
_, okA := a.ToBytes()
_, okB := b.ToBytes()
switch {
case !okA && !okB:
// Both Values already have types
return nil
case okA && okB:
// Both Values are untyped so try the types in order:
// int, float, bool, string
// Check for numeric inference
iA, okAi := a.bytesToInt()
iB, okBi := b.bytesToInt()
if okAi && okBi {
a.setInt(iA)
b.setInt(iB)
return nil
}
fA, okAf := a.bytesToFloat()
fB, okBf := b.bytesToFloat()
if okAf && okBf {
a.setFloat(fA)
b.setFloat(fB)
return nil
}
// Check if they int and float combination.
if okAi && okBf {
a.setInt(iA)
b.setFloat(fA)
return nil
}
if okBi && okAf {
a.setFloat(fA)
b.setInt(iB)
return nil
}
// Not numeric types at this point.
// Check for bool inference
bA, okAb := a.bytesToBool()
bB, okBb := b.bytesToBool()
if okAb && okBb {
a.setBool(bA)
b.setBool(bB)
return nil
}
// Fallback to string
sA := a.bytesToString()
sB := b.bytesToString()
a.setString(sA)
b.setString(sB)
return nil
case okA && !okB:
// Here a has `a` is untyped, but `b` has a fixed
// type.
switch b.vType {
case typeString:
s := a.bytesToString()
a.setString(s)
case typeInt, typeFloat:
if iA, ok := a.bytesToInt(); ok {
a.setInt(iA)
} else if fA, ok := a.bytesToFloat(); ok {
a.setFloat(fA)
} else {
return fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
}
case typeBool:
if bA, ok := a.bytesToBool(); ok {
a.setBool(bA)
} else {
return fmt.Errorf("Could not convert %s to a boolean", string(a.value.([]byte)))
}
default:
return errCmpMismatchedTypes
}
return nil
case !okA && okB:
// swap arguments to avoid repeating code
return inferTypesForCmp(b, a)
default:
// Does not happen
return nil
}
}
// NullValue - returns underlying null value. It panics if value is not null type.
func (value *Value) NullValue() *struct{} {
if value.valueType == Null {
// Value arithmetic functions: we do not expose them outside the
// module. All arithmetic works only on numeric values with automatic
// promotion to the "larger" type that can represent the value. TODO:
// Add support for large number arithmetic.
// Supported arithmetic operators
const (
opPlus = "+"
opMinus = "-"
opDivide = "/"
opMultiply = "*"
opModulo = "%"
)
// For arithmetic operations, if both values are numeric then the
// operation shall succeed. If the types are unknown automatic type
// conversion to a number is attempted.
func (v *Value) arithOp(op string, a *Value) error {
err := inferTypeForArithOp(v)
if err != nil {
return err
}
err = inferTypeForArithOp(a)
if err != nil {
return err
}
if !v.isNumeric() || !a.isNumeric() {
return errInvalidDataType(errArithMismatchedTypes)
}
if !isValidArithOperator(op) {
return errInvalidDataType(errArithMismatchedTypes)
}
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
switch {
case ok1i && ok2i:
res, err := intArithOp(op, intV, intA)
v.setInt(res)
return err
default:
// Convert arguments to float
flV, _ := v.ToFloat()
flA, _ := a.ToFloat()
res, err := floatArithOp(op, flV, flA)
v.setFloat(res)
return err
}
}
func inferTypeForArithOp(a *Value) error {
if _, ok := a.ToBytes(); !ok {
return nil
}
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
}
// BoolValue - returns underlying bool value. It panics if value is not Bool type.
func (value *Value) BoolValue() bool {
if value.valueType == Bool {
return value.value.(bool)
if i, ok := a.bytesToInt(); ok {
a.setInt(i)
return nil
}
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
}
// IntValue - returns underlying int value. It panics if value is not Int type.
func (value *Value) IntValue() int64 {
if value.valueType == Int {
return value.value.(int64)
if f, ok := a.bytesToFloat(); ok {
a.setFloat(f)
return nil
}
panic(fmt.Sprintf("requested int value but found %T type", value.value))
err := fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
return errInvalidDataType(err)
}
// FloatValue - returns underlying int/float value as float64. It panics if value is not Int/Float type.
func (value *Value) FloatValue() float64 {
switch value.valueType {
case Int:
return float64(value.value.(int64))
case Float:
return value.value.(float64)
// All the bytesTo* functions defined below assume the value is a byte-slice.
// Converts untyped value into int. The bool return implies success -
// it returns false only if there is a conversion failure.
func (v *Value) bytesToInt() (int64, bool) {
bytes, _ := v.ToBytes()
i, err := strconv.ParseInt(string(bytes), 10, 64)
return i, err == nil
}
// Converts untyped value into float. The bool return implies success
// - it returns false only if there is a conversion failure.
func (v *Value) bytesToFloat() (float64, bool) {
bytes, _ := v.ToBytes()
i, err := strconv.ParseFloat(string(bytes), 64)
return i, err == nil
}
// Converts untyped value into bool. The second bool return implies
// success - it returns false in case of a conversion failure.
func (v *Value) bytesToBool() (val bool, ok bool) {
bytes, _ := v.ToBytes()
ok = true
switch strings.ToLower(string(bytes)) {
case "t", "true":
val = true
case "f", "false":
val = false
default:
ok = false
}
return val, ok
}
// bytesToString - never fails
func (v *Value) bytesToString() string {
bytes, _ := v.ToBytes()
return string(bytes)
}
// Calculates minimum or maximum of v and a and assigns the result to
// v - it works only on numeric arguments, where `v` is already
// assumed to be numeric. Attempts conversion to numeric type for `a`
// (first int, then float) only if the underlying values do not have a
// type.
func (v *Value) minmax(a *Value, isMax, isFirstRow bool) error {
err := inferTypeForArithOp(a)
if err != nil {
return err
}
panic(fmt.Sprintf("requested float value but found %T type", value.value))
}
// StringValue - returns underlying string value. It panics if value is not String type.
func (value *Value) StringValue() string {
if value.valueType == String {
return value.value.(string)
if !a.isNumeric() {
return errArithMismatchedTypes
}
panic(fmt.Sprintf("requested string value but found %T type", value.value))
}
// TimeValue - returns underlying time value. It panics if value is not Timestamp type.
func (value *Value) TimeValue() time.Time {
if value.valueType == Timestamp {
return value.value.(time.Time)
}
panic(fmt.Sprintf("requested time value but found %T type", value.value))
}
// ArrayValue - returns underlying value array. It panics if value is not Array type.
func (value *Value) ArrayValue() []*Value {
if value.valueType == Array {
return value.value.([]*Value)
}
panic(fmt.Sprintf("requested array value but found %T type", value.value))
}
func (value *Value) recordValue() Record {
if value.valueType == record {
return value.value.(Record)
}
panic(fmt.Sprintf("requested record value but found %T type", value.value))
}
// Type - returns value type.
func (value *Value) Type() Type {
return value.valueType
}
// Value - returns underneath value interface.
func (value *Value) Value() interface{} {
return value.value
}
// NewNull - creates new null value.
func NewNull() *Value {
return &Value{nil, Null}
}
// NewBool - creates new Bool value of b.
func NewBool(b bool) *Value {
return &Value{b, Bool}
}
// NewInt - creates new Int value of i.
func NewInt(i int64) *Value {
return &Value{i, Int}
}
// NewFloat - creates new Float value of f.
func NewFloat(f float64) *Value {
return &Value{f, Float}
}
// NewString - creates new Sring value of s.
func NewString(s string) *Value {
return &Value{s, String}
}
// NewTime - creates new Time value of t.
func NewTime(t time.Time) *Value {
return &Value{t, Timestamp}
}
// NewArray - creates new Array value of values.
func NewArray(values []*Value) *Value {
return &Value{values, Array}
}
func newRecordValue(r Record) *Value {
return &Value{r, record}
}
// NewValue - creates new Value from SQLVal v.
func NewValue(v *sqlparser.SQLVal) (*Value, error) {
switch v.Type {
case sqlparser.StrVal:
return NewString(string(v.Val)), nil
case sqlparser.IntVal:
i64, err := strconv.ParseInt(string(v.Val), 10, 64)
if err != nil {
return nil, err
// In case of first row, set v to a.
if isFirstRow {
intA, okI := a.ToInt()
if okI {
v.setInt(intA)
return nil
}
return NewInt(i64), nil
case sqlparser.FloatVal:
f64, err := strconv.ParseFloat(string(v.Val), 64)
if err != nil {
return nil, err
}
return NewFloat(f64), nil
case sqlparser.HexNum: // represented as 0xDD
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.HexVal: // represented as X'0DD'
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.BitVal: // represented as B'00'
i64, err := strconv.ParseInt(string(v.Val), 2, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.ValArg:
// FIXME: the format is unknown and not sure how to handle it.
floatA, _ := a.ToFloat()
v.setFloat(floatA)
return nil
}
return nil, fmt.Errorf("unknown SQL value %v; %v ", v, v.Type)
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
if ok1i && ok2i {
result := intV
if !isMax {
if intA < result {
result = intA
}
} else {
if intA > result {
result = intA
}
}
v.setInt(result)
return nil
}
floatV, _ := v.ToFloat()
floatA, _ := a.ToFloat()
var result float64
if !isMax {
result = math.Min(floatV, floatA)
} else {
result = math.Max(floatV, floatA)
}
v.setFloat(result)
return nil
}
// inferTypeAsString is used to convert untyped values to string - it
// is called when the caller requires a string context to proceed.
func inferTypeAsString(v *Value) {
b, ok := v.ToBytes()
if !ok {
return
}
v.setString(string(b))
}
func isValidComparisonOperator(op string) bool {
switch op {
case opLt:
case opLte:
case opGt:
case opGte:
case opEq:
case opIneq:
default:
return false
}
return true
}
func intCompare(op string, left, right int64) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func floatCompare(op string, left, right float64) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func stringCompare(op string, left, right string) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func boolCompare(op string, left, right bool) (bool, error) {
switch op {
case opEq:
return left == right, nil
case opIneq:
return left != right, nil
default:
return false, errCmpInvalidBoolOperator
}
}
func isValidArithOperator(op string) bool {
switch op {
case opPlus:
case opMinus:
case opDivide:
case opMultiply:
case opModulo:
default:
return false
}
return true
}
// Overflow errors are ignored.
func intArithOp(op string, left, right int64) (int64, error) {
switch op {
case opPlus:
return left + right, nil
case opMinus:
return left - right, nil
case opDivide:
if right == 0 {
return 0, errArithDivideByZero
}
return left / right, nil
case opMultiply:
return left * right, nil
case opModulo:
if right == 0 {
return 0, errArithDivideByZero
}
return left % right, nil
}
// This does not happen
return 0, nil
}
// Overflow errors are ignored.
func floatArithOp(op string, left, right float64) (float64, error) {
switch op {
case opPlus:
return left + right, nil
case opMinus:
return left - right, nil
case opDivide:
if right == 0 {
return 0, errArithDivideByZero
}
return left / right, nil
case opMultiply:
return left * right, nil
case opModulo:
if right == 0 {
return 0, errArithDivideByZero
}
return math.Mod(left, right), nil
}
// This does not happen
return 0, nil
}