mirror of
https://github.com/minio/minio.git
synced 2025-11-10 22:10:12 -05:00
Add new SQL parser to support S3 Select syntax (#7102)
- New parser written from scratch, allows easier and complete parsing of the full S3 Select SQL syntax. Parser definition is directly provided by the AST defined for the SQL grammar. - Bring support to parse and interpret SQL involving JSON path expressions; evaluation of JSON path expressions will be subsequently added. - Bring automatic type inference and conversion for untyped values (e.g. CSV data).
This commit is contained in:
committed by
Harshavardhana
parent
0a28c28a8c
commit
2786055df4
@@ -32,7 +32,11 @@ type Record struct {
|
||||
nameIndexMap map[string]int64
|
||||
}
|
||||
|
||||
// Get - gets the value for a column name.
|
||||
// Get - gets the value for a column name. CSV fields do not have any
|
||||
// defined type (other than the default string). So this function
|
||||
// always returns fields using sql.FromBytes so that the type
|
||||
// specified/implied by the query can be used, or can be automatically
|
||||
// converted based on the query.
|
||||
func (r *Record) Get(name string) (*sql.Value, error) {
|
||||
index, found := r.nameIndexMap[name]
|
||||
if !found {
|
||||
@@ -40,11 +44,12 @@ func (r *Record) Get(name string) (*sql.Value, error) {
|
||||
}
|
||||
|
||||
if index >= int64(len(r.csvRecord)) {
|
||||
// No value found for column 'name', hence return empty string for compatibility.
|
||||
return sql.NewString(""), nil
|
||||
// No value found for column 'name', hence return null
|
||||
// value
|
||||
return sql.FromNull(), nil
|
||||
}
|
||||
|
||||
return sql.NewString(r.csvRecord[index]), nil
|
||||
return sql.FromBytes([]byte(r.csvRecord[index])), nil
|
||||
}
|
||||
|
||||
// Set - sets the value for a column name.
|
||||
|
||||
@@ -37,15 +37,15 @@ func (r *Record) Get(name string) (*sql.Value, error) {
|
||||
result := gjson.GetBytes(r.data, name)
|
||||
switch result.Type {
|
||||
case gjson.Null:
|
||||
return sql.NewNull(), nil
|
||||
return sql.FromNull(), nil
|
||||
case gjson.False:
|
||||
return sql.NewBool(false), nil
|
||||
return sql.FromBool(false), nil
|
||||
case gjson.Number:
|
||||
return sql.NewFloat(result.Float()), nil
|
||||
return sql.FromFloat(result.Float()), nil
|
||||
case gjson.String:
|
||||
return sql.NewString(result.String()), nil
|
||||
return sql.FromString(result.String()), nil
|
||||
case gjson.True:
|
||||
return sql.NewBool(true), nil
|
||||
return sql.FromBool(true), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported gjson value %v; %v", result, result.Type)
|
||||
@@ -54,19 +54,20 @@ func (r *Record) Get(name string) (*sql.Value, error) {
|
||||
// Set - sets the value for a column name.
|
||||
func (r *Record) Set(name string, value *sql.Value) (err error) {
|
||||
var v interface{}
|
||||
switch value.Type() {
|
||||
case sql.Null:
|
||||
v = value.NullValue()
|
||||
case sql.Bool:
|
||||
v = value.BoolValue()
|
||||
case sql.Int:
|
||||
v = value.IntValue()
|
||||
case sql.Float:
|
||||
v = value.FloatValue()
|
||||
case sql.String:
|
||||
v = value.StringValue()
|
||||
default:
|
||||
return fmt.Errorf("unsupported sql value %v and type %v", value, value.Type())
|
||||
if b, ok := value.ToBool(); ok {
|
||||
v = b
|
||||
} else if f, ok := value.ToFloat(); ok {
|
||||
v = f
|
||||
} else if i, ok := value.ToInt(); ok {
|
||||
v = i
|
||||
} else if s, ok := value.ToString(); ok {
|
||||
v = s
|
||||
} else if value.IsNull() {
|
||||
v = nil
|
||||
} else if b, ok := value.ToBytes(); ok {
|
||||
v = string(b)
|
||||
} else {
|
||||
return fmt.Errorf("unsupported sql value %v and type %v", value, value.GetTypeString())
|
||||
}
|
||||
|
||||
name = strings.Replace(name, "*", "__ALL__", -1)
|
||||
|
||||
@@ -32,7 +32,7 @@ type Reader struct {
|
||||
}
|
||||
|
||||
// Read - reads single record.
|
||||
func (r *Reader) Read() (sql.Record, error) {
|
||||
func (r *Reader) Read() (rec sql.Record, rerr error) {
|
||||
parquetRecord, err := r.file.Read()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
@@ -43,39 +43,41 @@ func (r *Reader) Read() (sql.Record, error) {
|
||||
}
|
||||
|
||||
record := json.NewRecord()
|
||||
for name, v := range parquetRecord {
|
||||
f := func(name string, v parquetgo.Value) bool {
|
||||
if v.Value == nil {
|
||||
if err = record.Set(name, sql.NewNull()); err != nil {
|
||||
return nil, errParquetParsingError(err)
|
||||
if err := record.Set(name, sql.FromNull()); err != nil {
|
||||
rerr = errParquetParsingError(err)
|
||||
}
|
||||
|
||||
continue
|
||||
return rerr == nil
|
||||
}
|
||||
|
||||
var value *sql.Value
|
||||
switch v.Type {
|
||||
case parquetgen.Type_BOOLEAN:
|
||||
value = sql.NewBool(v.Value.(bool))
|
||||
value = sql.FromBool(v.Value.(bool))
|
||||
case parquetgen.Type_INT32:
|
||||
value = sql.NewInt(int64(v.Value.(int32)))
|
||||
value = sql.FromInt(int64(v.Value.(int32)))
|
||||
case parquetgen.Type_INT64:
|
||||
value = sql.NewInt(v.Value.(int64))
|
||||
value = sql.FromInt(int64(v.Value.(int64)))
|
||||
case parquetgen.Type_FLOAT:
|
||||
value = sql.NewFloat(float64(v.Value.(float32)))
|
||||
value = sql.FromFloat(float64(v.Value.(float32)))
|
||||
case parquetgen.Type_DOUBLE:
|
||||
value = sql.NewFloat(v.Value.(float64))
|
||||
value = sql.FromFloat(v.Value.(float64))
|
||||
case parquetgen.Type_INT96, parquetgen.Type_BYTE_ARRAY, parquetgen.Type_FIXED_LEN_BYTE_ARRAY:
|
||||
value = sql.NewString(string(v.Value.([]byte)))
|
||||
value = sql.FromString(string(v.Value.([]byte)))
|
||||
default:
|
||||
return nil, errParquetParsingError(nil)
|
||||
rerr = errParquetParsingError(nil)
|
||||
return false
|
||||
}
|
||||
|
||||
if err = record.Set(name, value); err != nil {
|
||||
return nil, errParquetParsingError(err)
|
||||
rerr = errParquetParsingError(err)
|
||||
}
|
||||
return rerr == nil
|
||||
}
|
||||
|
||||
return record, nil
|
||||
parquetRecord.Range(f)
|
||||
return record, rerr
|
||||
}
|
||||
|
||||
// Close - closes underlaying readers.
|
||||
|
||||
@@ -105,7 +105,7 @@ func (input *InputSerialization) UnmarshalXML(d *xml.Decoder, start xml.StartEle
|
||||
found++
|
||||
}
|
||||
if !parsedInput.ParquetArgs.IsEmpty() {
|
||||
if parsedInput.CompressionType != noneType {
|
||||
if parsedInput.CompressionType != "" && parsedInput.CompressionType != noneType {
|
||||
return errInvalidRequestParameter(fmt.Errorf("CompressionType must be NONE for Parquet format"))
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ type S3Select struct {
|
||||
Output OutputSerialization `xml:"OutputSerialization"`
|
||||
Progress RequestProgress `xml:"RequestProgress"`
|
||||
|
||||
statement *sql.Select
|
||||
statement *sql.SelectStatement
|
||||
progressReader *progressReader
|
||||
recordReader recordReader
|
||||
}
|
||||
@@ -209,12 +209,12 @@ func (s3Select *S3Select) UnmarshalXML(d *xml.Decoder, start xml.StartElement) e
|
||||
return errMissingRequiredParameter(fmt.Errorf("OutputSerialization must be provided"))
|
||||
}
|
||||
|
||||
statement, err := sql.NewSelect(parsedS3Select.Expression)
|
||||
statement, err := sql.ParseSelectStatement(parsedS3Select.Expression)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsedS3Select.statement = statement
|
||||
parsedS3Select.statement = &statement
|
||||
|
||||
*s3Select = S3Select(parsedS3Select)
|
||||
return nil
|
||||
@@ -334,6 +334,14 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
for {
|
||||
if s3Select.statement.LimitReached() {
|
||||
if err = writer.SendStats(s3Select.getProgress()); err != nil {
|
||||
// FIXME: log this error.
|
||||
err = nil
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if inputRecord, err = s3Select.recordReader.Read(); err != nil {
|
||||
if err != io.EOF {
|
||||
break
|
||||
@@ -358,19 +366,25 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
|
||||
break
|
||||
}
|
||||
|
||||
outputRecord = s3Select.outputRecord()
|
||||
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
|
||||
break
|
||||
}
|
||||
if s3Select.statement.IsAggregated() {
|
||||
if err = s3Select.statement.AggregateRow(inputRecord); err != nil {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
outputRecord = s3Select.outputRecord()
|
||||
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if !s3Select.statement.IsAggregated() {
|
||||
if !sendRecord() {
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("SQL Err: %#v\n", err)
|
||||
if serr := writer.SendError("InternalError", err.Error()); serr != nil {
|
||||
// FIXME: log errors.
|
||||
}
|
||||
|
||||
318
pkg/s3select/sql/aggregation.go
Normal file
318
pkg/s3select/sql/aggregation.go
Normal file
@@ -0,0 +1,318 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Aggregation Function name constants
|
||||
const (
|
||||
aggFnAvg FuncName = "AVG"
|
||||
aggFnCount FuncName = "COUNT"
|
||||
aggFnMax FuncName = "MAX"
|
||||
aggFnMin FuncName = "MIN"
|
||||
aggFnSum FuncName = "SUM"
|
||||
)
|
||||
|
||||
var (
|
||||
errNonNumericArg = func(fnStr FuncName) error {
|
||||
return fmt.Errorf("%s() requires a numeric argument", fnStr)
|
||||
}
|
||||
errInvalidAggregation = errors.New("Invalid aggregation seen")
|
||||
)
|
||||
|
||||
type aggVal struct {
|
||||
runningSum *Value
|
||||
runningCount int64
|
||||
runningMax, runningMin *Value
|
||||
|
||||
// Stores if at least one record has been seen
|
||||
seen bool
|
||||
}
|
||||
|
||||
func newAggVal(fn FuncName) *aggVal {
|
||||
switch fn {
|
||||
case aggFnAvg, aggFnSum:
|
||||
return &aggVal{runningSum: FromInt(0)}
|
||||
case aggFnMin:
|
||||
return &aggVal{runningMin: FromInt(0)}
|
||||
case aggFnMax:
|
||||
return &aggVal{runningMax: FromInt(0)}
|
||||
default:
|
||||
return &aggVal{}
|
||||
}
|
||||
}
|
||||
|
||||
// evalAggregationNode - performs partial computation using the
|
||||
// current row and stores the result.
|
||||
//
|
||||
// On success, it returns (nil, nil).
|
||||
func (e *FuncExpr) evalAggregationNode(r Record) error {
|
||||
// It is assumed that this function is called only when
|
||||
// `e` is an aggregation function.
|
||||
|
||||
var val *Value
|
||||
var err error
|
||||
funcName := e.getFunctionName()
|
||||
if aggFnCount == funcName {
|
||||
if e.Count.StarArg {
|
||||
// Handle COUNT(*)
|
||||
e.aggregate.runningCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
val, err = e.Count.ExprArg.evalNode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Evaluate the (only) argument
|
||||
val, err = e.SFunc.ArgsList[0].evalNode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if val.IsNull() {
|
||||
// E.g. the column or field does not exist in the
|
||||
// record - in all such cases the aggregation is not
|
||||
// updated.
|
||||
return nil
|
||||
}
|
||||
|
||||
argVal := val
|
||||
if funcName != aggFnCount {
|
||||
// All aggregation functions, except COUNT require a
|
||||
// numeric argument.
|
||||
|
||||
// Here, we diverge from Amazon S3 behavior by
|
||||
// inferring untyped values are numbers.
|
||||
if i, ok := argVal.bytesToInt(); ok {
|
||||
argVal.setInt(i)
|
||||
} else if f, ok := argVal.bytesToFloat(); ok {
|
||||
argVal.setFloat(f)
|
||||
} else {
|
||||
return errNonNumericArg(funcName)
|
||||
}
|
||||
}
|
||||
|
||||
// Mark that we have seen one non-null value.
|
||||
isFirstRow := false
|
||||
if !e.aggregate.seen {
|
||||
e.aggregate.seen = true
|
||||
isFirstRow = true
|
||||
}
|
||||
|
||||
switch funcName {
|
||||
case aggFnCount:
|
||||
// For all non-null values, the count is incremented.
|
||||
e.aggregate.runningCount++
|
||||
|
||||
case aggFnAvg:
|
||||
e.aggregate.runningCount++
|
||||
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
|
||||
|
||||
case aggFnMin:
|
||||
err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
|
||||
|
||||
case aggFnMax:
|
||||
err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
|
||||
|
||||
case aggFnSum:
|
||||
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
|
||||
|
||||
default:
|
||||
err = errInvalidAggregation
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *AliasedExpression) aggregateRow(r Record) error {
|
||||
return e.Expression.aggregateRow(r)
|
||||
}
|
||||
|
||||
func (e *Expression) aggregateRow(r Record) error {
|
||||
for _, ex := range e.And {
|
||||
err := ex.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *AndCondition) aggregateRow(r Record) error {
|
||||
for _, ex := range e.Condition {
|
||||
err := ex.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Condition) aggregateRow(r Record) error {
|
||||
if e.Operand != nil {
|
||||
return e.Operand.aggregateRow(r)
|
||||
}
|
||||
return e.Not.aggregateRow(r)
|
||||
}
|
||||
|
||||
func (e *ConditionOperand) aggregateRow(r Record) error {
|
||||
err := e.Operand.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.ConditionRHS == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case e.ConditionRHS.Compare != nil:
|
||||
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
|
||||
case e.ConditionRHS.Between != nil:
|
||||
err = e.ConditionRHS.Between.Start.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.ConditionRHS.Between.End.aggregateRow(r)
|
||||
case e.ConditionRHS.In != nil:
|
||||
for _, elt := range e.ConditionRHS.In.Expressions {
|
||||
err = elt.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case e.ConditionRHS.Like != nil:
|
||||
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
|
||||
default:
|
||||
return errInvalidASTNode
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Operand) aggregateRow(r Record) error {
|
||||
err := e.Left.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rt := range e.Right {
|
||||
err = rt.Right.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *MultOp) aggregateRow(r Record) error {
|
||||
err := e.Left.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rt := range e.Right {
|
||||
err = rt.Right.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *UnaryTerm) aggregateRow(r Record) error {
|
||||
if e.Negated != nil {
|
||||
return e.Negated.Term.aggregateRow(r)
|
||||
}
|
||||
return e.Primary.aggregateRow(r)
|
||||
}
|
||||
|
||||
func (e *PrimaryTerm) aggregateRow(r Record) error {
|
||||
switch {
|
||||
case e.SubExpression != nil:
|
||||
return e.SubExpression.aggregateRow(r)
|
||||
case e.FuncCall != nil:
|
||||
return e.FuncCall.aggregateRow(r)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *FuncExpr) aggregateRow(r Record) error {
|
||||
switch e.getFunctionName() {
|
||||
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
|
||||
return e.evalAggregationNode(r)
|
||||
default:
|
||||
// TODO: traverse arguments and call aggregateRow on
|
||||
// them if they could be an ancestor of an
|
||||
// aggregation.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAggregate() implementation for each AST node follows. This is
|
||||
// called after calling aggregateRow() on each input row, to calculate
|
||||
// the final aggregate result.
|
||||
|
||||
func (e *Expression) getAggregate() (*Value, error) {
|
||||
return e.evalNode(nil)
|
||||
}
|
||||
|
||||
func (e *FuncExpr) getAggregate() (*Value, error) {
|
||||
switch e.getFunctionName() {
|
||||
case aggFnCount:
|
||||
return FromFloat(float64(e.aggregate.runningCount)), nil
|
||||
|
||||
case aggFnAvg:
|
||||
if e.aggregate.runningCount == 0 {
|
||||
// No rows were seen by AVG.
|
||||
return FromNull(), nil
|
||||
}
|
||||
err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
|
||||
return e.aggregate.runningSum, err
|
||||
|
||||
case aggFnMin:
|
||||
if !e.aggregate.seen {
|
||||
// No rows were seen by MIN
|
||||
return FromNull(), nil
|
||||
}
|
||||
return e.aggregate.runningMin, nil
|
||||
|
||||
case aggFnMax:
|
||||
if !e.aggregate.seen {
|
||||
// No rows were seen by MAX
|
||||
return FromNull(), nil
|
||||
}
|
||||
return e.aggregate.runningMax, nil
|
||||
|
||||
case aggFnSum:
|
||||
// TODO: check if returning 0 when no rows were seen
|
||||
// by SUM is expected behavior.
|
||||
return e.aggregate.runningSum, nil
|
||||
|
||||
default:
|
||||
// TODO:
|
||||
}
|
||||
|
||||
return nil, errInvalidAggregation
|
||||
}
|
||||
290
pkg/s3select/sql/analysis.go
Normal file
290
pkg/s3select/sql/analysis.go
Normal file
@@ -0,0 +1,290 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Query analysis - The query is analyzed to determine if it involves
|
||||
// aggregation.
|
||||
//
|
||||
// Aggregation functions - An expression that involves aggregation of
|
||||
// rows in some manner. Requires all input rows to be processed,
|
||||
// before a result is returned.
|
||||
//
|
||||
// Row function - An expression that depends on a value in the
|
||||
// row. They have an output for each input row.
|
||||
//
|
||||
// Some types of a queries are not valid. For example, an aggregation
|
||||
// function combined with a row function is meaningless ("AVG(s.Age) +
|
||||
// s.Salary"). Analysis determines if such a scenario exists so an
|
||||
// error can be returned.
|
||||
|
||||
var (
|
||||
// Fatal error for query processing.
|
||||
errNestedAggregation = errors.New("Cannot nest aggregations")
|
||||
errFunctionNotImplemented = errors.New("Function is not yet implemented")
|
||||
errUnexpectedInvalidNode = errors.New("Unexpected node value")
|
||||
errInvalidKeypath = errors.New("A provided keypath is invalid")
|
||||
)
|
||||
|
||||
// qProp contains analysis info about an SQL term.
|
||||
type qProp struct {
|
||||
isAggregation, isRowFunc bool
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
// `combine` combines a pair of `qProp`s, so that errors are
|
||||
// propagated correctly, and checks that an aggregation is not being
|
||||
// combined with a row-function term.
|
||||
func (p *qProp) combine(q qProp) {
|
||||
switch {
|
||||
case p.err != nil:
|
||||
// Do nothing
|
||||
case q.err != nil:
|
||||
p.err = q.err
|
||||
default:
|
||||
p.isAggregation = p.isAggregation || q.isAggregation
|
||||
p.isRowFunc = p.isRowFunc || q.isRowFunc
|
||||
if p.isAggregation && p.isRowFunc {
|
||||
p.err = errNestedAggregation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *SelectExpression) analyze(s *Select) (result qProp) {
|
||||
if e.All {
|
||||
return qProp{isRowFunc: true}
|
||||
}
|
||||
|
||||
for _, ex := range e.Expressions {
|
||||
result.combine(ex.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *AliasedExpression) analyze(s *Select) qProp {
|
||||
return e.Expression.analyze(s)
|
||||
}
|
||||
|
||||
func (e *Expression) analyze(s *Select) (result qProp) {
|
||||
for _, ac := range e.And {
|
||||
result.combine(ac.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *AndCondition) analyze(s *Select) (result qProp) {
|
||||
for _, ac := range e.Condition {
|
||||
result.combine(ac.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *Condition) analyze(s *Select) (result qProp) {
|
||||
if e.Operand != nil {
|
||||
result = e.Operand.analyze(s)
|
||||
} else {
|
||||
result = e.Not.analyze(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *ConditionOperand) analyze(s *Select) (result qProp) {
|
||||
if e.ConditionRHS == nil {
|
||||
result = e.Operand.analyze(s)
|
||||
} else {
|
||||
result.combine(e.Operand.analyze(s))
|
||||
result.combine(e.ConditionRHS.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *ConditionRHS) analyze(s *Select) (result qProp) {
|
||||
switch {
|
||||
case e.Compare != nil:
|
||||
result = e.Compare.Operand.analyze(s)
|
||||
case e.Between != nil:
|
||||
result.combine(e.Between.Start.analyze(s))
|
||||
result.combine(e.Between.End.analyze(s))
|
||||
case e.In != nil:
|
||||
for _, elt := range e.In.Expressions {
|
||||
result.combine(elt.analyze(s))
|
||||
}
|
||||
case e.Like != nil:
|
||||
result.combine(e.Like.Pattern.analyze(s))
|
||||
if e.Like.EscapeChar != nil {
|
||||
result.combine(e.Like.EscapeChar.analyze(s))
|
||||
}
|
||||
default:
|
||||
result = qProp{err: errUnexpectedInvalidNode}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *Operand) analyze(s *Select) (result qProp) {
|
||||
result.combine(e.Left.analyze(s))
|
||||
for _, r := range e.Right {
|
||||
result.combine(r.Right.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *MultOp) analyze(s *Select) (result qProp) {
|
||||
result.combine(e.Left.analyze(s))
|
||||
for _, r := range e.Right {
|
||||
result.combine(r.Right.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *UnaryTerm) analyze(s *Select) (result qProp) {
|
||||
if e.Negated != nil {
|
||||
result = e.Negated.Term.analyze(s)
|
||||
} else {
|
||||
result = e.Primary.analyze(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
|
||||
switch {
|
||||
case e.Value != nil:
|
||||
result = qProp{}
|
||||
|
||||
case e.JPathExpr != nil:
|
||||
// Check if the path expression is valid
|
||||
if len(e.JPathExpr.PathExpr) > 0 {
|
||||
if e.JPathExpr.BaseKey.String() != s.From.As {
|
||||
result = qProp{err: errInvalidKeypath}
|
||||
return
|
||||
}
|
||||
}
|
||||
result = qProp{isRowFunc: true}
|
||||
|
||||
case e.SubExpression != nil:
|
||||
result = e.SubExpression.analyze(s)
|
||||
|
||||
case e.FuncCall != nil:
|
||||
result = e.FuncCall.analyze(s)
|
||||
|
||||
default:
|
||||
result = qProp{err: errUnexpectedInvalidNode}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *FuncExpr) analyze(s *Select) (result qProp) {
|
||||
funcName := e.getFunctionName()
|
||||
|
||||
switch funcName {
|
||||
case sqlFnCast:
|
||||
return e.Cast.Expr.analyze(s)
|
||||
|
||||
case sqlFnExtract:
|
||||
return e.Extract.From.analyze(s)
|
||||
|
||||
// Handle aggregation function calls
|
||||
case aggFnAvg, aggFnMax, aggFnMin, aggFnSum, aggFnCount:
|
||||
// Initialize accumulator
|
||||
e.aggregate = newAggVal(funcName)
|
||||
|
||||
var exprA qProp
|
||||
if funcName == aggFnCount {
|
||||
if e.Count.StarArg {
|
||||
return qProp{isAggregation: true}
|
||||
}
|
||||
|
||||
exprA = e.Count.ExprArg.analyze(s)
|
||||
} else {
|
||||
if len(e.SFunc.ArgsList) != 1 {
|
||||
return qProp{err: fmt.Errorf("%s takes exactly one argument", funcName)}
|
||||
}
|
||||
exprA = e.SFunc.ArgsList[0].analyze(s)
|
||||
}
|
||||
|
||||
if exprA.err != nil {
|
||||
return exprA
|
||||
}
|
||||
if exprA.isAggregation {
|
||||
return qProp{err: errNestedAggregation}
|
||||
}
|
||||
return qProp{isAggregation: true}
|
||||
|
||||
case sqlFnCoalesce:
|
||||
if len(e.SFunc.ArgsList) == 0 {
|
||||
return qProp{err: fmt.Errorf("%s needs at least one argument", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnNullIf:
|
||||
if len(e.SFunc.ArgsList) != 2 {
|
||||
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnCharLength, sqlFnCharacterLength:
|
||||
if len(e.SFunc.ArgsList) != 1 {
|
||||
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnLower, sqlFnUpper:
|
||||
if len(e.SFunc.ArgsList) != 1 {
|
||||
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnSubstring:
|
||||
errVal := fmt.Errorf("Invalid argument(s) to %s", string(funcName))
|
||||
result.combine(e.Substring.Expr.analyze(s))
|
||||
switch {
|
||||
case e.Substring.From != nil:
|
||||
result.combine(e.Substring.From.analyze(s))
|
||||
if e.Substring.For != nil {
|
||||
result.combine(e.Substring.Expr.analyze(s))
|
||||
}
|
||||
case e.Substring.Arg2 != nil:
|
||||
result.combine(e.Substring.Arg2.analyze(s))
|
||||
if e.Substring.Arg3 != nil {
|
||||
result.combine(e.Substring.Arg3.analyze(s))
|
||||
}
|
||||
default:
|
||||
result.err = errVal
|
||||
}
|
||||
return result
|
||||
|
||||
}
|
||||
|
||||
// TODO: implement other functions
|
||||
return qProp{err: errFunctionNotImplemented}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ArithOperator - arithmetic operator.
|
||||
type ArithOperator string
|
||||
|
||||
const (
|
||||
// Add operator '+'.
|
||||
Add ArithOperator = "+"
|
||||
|
||||
// Subtract operator '-'.
|
||||
Subtract ArithOperator = "-"
|
||||
|
||||
// Multiply operator '*'.
|
||||
Multiply ArithOperator = "*"
|
||||
|
||||
// Divide operator '/'.
|
||||
Divide ArithOperator = "/"
|
||||
|
||||
// Modulo operator '%'.
|
||||
Modulo ArithOperator = "%"
|
||||
)
|
||||
|
||||
// arithExpr - arithmetic function.
|
||||
type arithExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
operator ArithOperator
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *arithExpr) String() string {
|
||||
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
|
||||
}
|
||||
|
||||
func (f *arithExpr) compute(lv, rv *Value) (*Value, error) {
|
||||
leftValueType := lv.Type()
|
||||
rightValueType := rv.Type()
|
||||
if !leftValueType.isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValueType)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValueType.isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValueType)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue := lv.FloatValue()
|
||||
rightValue := rv.FloatValue()
|
||||
|
||||
var result float64
|
||||
switch f.operator {
|
||||
case Add:
|
||||
result = leftValue + rightValue
|
||||
case Subtract:
|
||||
result = leftValue - rightValue
|
||||
case Multiply:
|
||||
result = leftValue * rightValue
|
||||
case Divide:
|
||||
result = leftValue / rightValue
|
||||
case Modulo:
|
||||
result = float64(int64(leftValue) % int64(rightValue))
|
||||
}
|
||||
|
||||
if leftValueType == Float || rightValueType == Float {
|
||||
return NewFloat(result), nil
|
||||
}
|
||||
|
||||
return NewInt(int64(result)), nil
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *arithExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return f.compute(leftValue, rightValue)
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *arithExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
lv, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rv, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f.compute(lv, rv)
|
||||
}
|
||||
|
||||
// Type - returns arithmeticFunction or aggregateFunction type.
|
||||
func (f *arithExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Float as return type.
|
||||
func (f *arithExpr) ReturnType() Type {
|
||||
return Float
|
||||
}
|
||||
|
||||
// newArithExpr - creates new arithmetic function.
|
||||
func newArithExpr(operator ArithOperator, left, right Expr) (*arithExpr, error) {
|
||||
if !left.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if !right.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := arithmeticFunction
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
switch right.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &arithExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,636 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ComparisonOperator - comparison operator.
|
||||
type ComparisonOperator string
|
||||
|
||||
const (
|
||||
// Equal operator '='.
|
||||
Equal ComparisonOperator = "="
|
||||
|
||||
// NotEqual operator '!=' or '<>'.
|
||||
NotEqual ComparisonOperator = "!="
|
||||
|
||||
// LessThan operator '<'.
|
||||
LessThan ComparisonOperator = "<"
|
||||
|
||||
// GreaterThan operator '>'.
|
||||
GreaterThan ComparisonOperator = ">"
|
||||
|
||||
// LessThanEqual operator '<='.
|
||||
LessThanEqual ComparisonOperator = "<="
|
||||
|
||||
// GreaterThanEqual operator '>='.
|
||||
GreaterThanEqual ComparisonOperator = ">="
|
||||
|
||||
// Between operator 'BETWEEN'
|
||||
Between ComparisonOperator = "between"
|
||||
|
||||
// In operator 'IN'
|
||||
In ComparisonOperator = "in"
|
||||
|
||||
// Like operator 'LIKE'
|
||||
Like ComparisonOperator = "like"
|
||||
|
||||
// NotBetween operator 'NOT BETWEEN'
|
||||
NotBetween ComparisonOperator = "not between"
|
||||
|
||||
// NotIn operator 'NOT IN'
|
||||
NotIn ComparisonOperator = "not in"
|
||||
|
||||
// NotLike operator 'NOT LIKE'
|
||||
NotLike ComparisonOperator = "not like"
|
||||
|
||||
// IsNull operator 'IS NULL'
|
||||
IsNull ComparisonOperator = "is null"
|
||||
|
||||
// IsNotNull operator 'IS NOT NULL'
|
||||
IsNotNull ComparisonOperator = "is not null"
|
||||
)
|
||||
|
||||
// String - returns string representation of this operator.
|
||||
func (operator ComparisonOperator) String() string {
|
||||
return strings.ToUpper((string(operator)))
|
||||
}
|
||||
|
||||
func equal(leftValue, rightValue *Value) (bool, error) {
|
||||
switch {
|
||||
case leftValue.Type() == Null && rightValue.Type() == Null:
|
||||
return true, nil
|
||||
case leftValue.Type() == Bool && rightValue.Type() == Bool:
|
||||
return leftValue.BoolValue() == rightValue.BoolValue(), nil
|
||||
case (leftValue.Type() == Int || leftValue.Type() == Float) &&
|
||||
(rightValue.Type() == Int || rightValue.Type() == Float):
|
||||
return leftValue.FloatValue() == rightValue.FloatValue(), nil
|
||||
case leftValue.Type() == String && rightValue.Type() == String:
|
||||
return leftValue.StringValue() == rightValue.StringValue(), nil
|
||||
case leftValue.Type() == Timestamp && rightValue.Type() == Timestamp:
|
||||
return leftValue.TimeValue() == rightValue.TimeValue(), nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("left value type %v and right value type %v are incompatible for equality check", leftValue.Type(), rightValue.Type())
|
||||
}
|
||||
|
||||
// comparisonExpr - comparison function.
|
||||
type comparisonExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
to Expr
|
||||
operator ComparisonOperator
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *comparisonExpr) String() string {
|
||||
switch f.operator {
|
||||
case Equal, NotEqual, LessThan, GreaterThan, LessThanEqual, GreaterThanEqual, In, Like, NotIn, NotLike:
|
||||
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
|
||||
case Between, NotBetween:
|
||||
return fmt.Sprintf("(%v %v %v AND %v)", f.left, f.operator, f.right, f.to)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(%v %v %v %v)", f.left, f.right, f.to, f.operator)
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) equal(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := equal(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notEqual(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := equal(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) lessThan(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() < rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) greaterThan(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() > rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) lessThanEqual(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() <= rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) greaterThanEqual(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() >= rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) computeBetween(leftValue, fromValue, toValue *Value) (bool, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
if !fromValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: from side expression evaluated to %v; not to number", f, fromValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
if !toValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: to side expression evaluated to %v; not to number", f, toValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return leftValue.FloatValue() >= fromValue.FloatValue() &&
|
||||
leftValue.FloatValue() <= toValue.FloatValue(), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) between(leftValue, fromValue, toValue *Value) (*Value, error) {
|
||||
result, err := f.computeBetween(leftValue, fromValue, toValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notBetween(leftValue, fromValue, toValue *Value) (*Value, error) {
|
||||
result, err := f.computeBetween(leftValue, fromValue, toValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) computeIn(leftValue, rightValue *Value) (found bool, err error) {
|
||||
if rightValue.Type() != Array {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to Array", f, rightValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
values := rightValue.ArrayValue()
|
||||
|
||||
for i := range values {
|
||||
found, err = equal(leftValue, values[i])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if found {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) in(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeIn(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notIn(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeIn(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) computeLike(leftValue, rightValue *Value) (matched bool, err error) {
|
||||
if leftValue.Type() != String {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to string", f, leftValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
if rightValue.Type() != String {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to string", f, rightValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
matched, err = regexp.MatchString(rightValue.StringValue(), leftValue.StringValue())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) like(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeLike(leftValue, rightValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notLike(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeLike(leftValue, rightValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) compute(leftValue, rightValue, toValue *Value) (*Value, error) {
|
||||
switch f.operator {
|
||||
case Equal:
|
||||
return f.equal(leftValue, rightValue)
|
||||
case NotEqual:
|
||||
return f.notEqual(leftValue, rightValue)
|
||||
case LessThan:
|
||||
return f.lessThan(leftValue, rightValue)
|
||||
case GreaterThan:
|
||||
return f.greaterThan(leftValue, rightValue)
|
||||
case LessThanEqual:
|
||||
return f.lessThanEqual(leftValue, rightValue)
|
||||
case GreaterThanEqual:
|
||||
return f.greaterThanEqual(leftValue, rightValue)
|
||||
case Between:
|
||||
return f.between(leftValue, rightValue, toValue)
|
||||
case In:
|
||||
return f.in(leftValue, rightValue)
|
||||
case Like:
|
||||
return f.like(leftValue, rightValue)
|
||||
case NotBetween:
|
||||
return f.notBetween(leftValue, rightValue, toValue)
|
||||
case NotIn:
|
||||
return f.notIn(leftValue, rightValue)
|
||||
case NotLike:
|
||||
return f.notLike(leftValue, rightValue)
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("unexpected expression %v", f))
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *comparisonExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toValue *Value
|
||||
if f.to != nil {
|
||||
toValue, err = f.to.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return f.compute(leftValue, rightValue, toValue)
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *comparisonExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toValue *Value
|
||||
if f.to != nil {
|
||||
toValue, err = f.to.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return f.compute(leftValue, rightValue, toValue)
|
||||
}
|
||||
|
||||
// Type - returns comparisonFunction or aggregateFunction type.
|
||||
func (f *comparisonExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *comparisonExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newComparisonExpr - creates new comparison function.
|
||||
func newComparisonExpr(operator ComparisonOperator, funcs ...Expr) (*comparisonExpr, error) {
|
||||
funcType := comparisonFunction
|
||||
switch operator {
|
||||
case Equal, NotEqual:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if !right.ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for equality check", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch right.Type() {
|
||||
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for equality check", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
|
||||
case LessThan, GreaterThan, LessThanEqual, GreaterThanEqual:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if !right.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
switch right.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
|
||||
case In, NotIn:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if right.ReturnType() != Array {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch right.Type() {
|
||||
case Array, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
|
||||
case Like, NotLike:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not string", operator, left, left.ReturnType())
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if !right.ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v, not string", operator, right, right.ReturnType())
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case String, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch right.Type() {
|
||||
case String, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
case Between, NotBetween:
|
||||
if len(funcs) != 3 {
|
||||
panic(fmt.Sprintf("too many values in funcs %v", funcs))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
from := funcs[1]
|
||||
if !from.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: from expression %v evaluate to %v, not number", operator, from, from.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
to := funcs[2]
|
||||
if !to.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: to expression %v evaluate to %v, not number", operator, to, to.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || from.Type() == aggregateFunction || to.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch from.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: from expression %v return type %v is incompatible for aggregate evaluation", operator, from, from.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch to.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: to expression %v return type %v is incompatible for aggregate evaluation", operator, to, to.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: from,
|
||||
to: to,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
case IsNull, IsNotNull:
|
||||
if len(funcs) != 1 {
|
||||
panic(fmt.Sprintf("too many values in funcs %v", funcs))
|
||||
}
|
||||
|
||||
if funcs[0].Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
}
|
||||
|
||||
if operator == IsNull {
|
||||
operator = Equal
|
||||
} else {
|
||||
operator = NotEqual
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: funcs[0],
|
||||
right: newValueExpr(NewNull()),
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errParseUnknownOperator(fmt.Errorf("unknown operator %v", operator))
|
||||
}
|
||||
@@ -43,42 +43,6 @@ func (err *s3Error) Error() string {
|
||||
return err.message
|
||||
}
|
||||
|
||||
func errUnsupportedSQLStructure(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "UnsupportedSqlStructure",
|
||||
message: "Encountered an unsupported SQL structure. Check the SQL Reference.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseUnsupportedSelect(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseUnsupportedSelect",
|
||||
message: "The SQL expression contains an unsupported use of SELECT.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseAsteriskIsNotAloneInSelectList(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseAsteriskIsNotAloneInSelectList",
|
||||
message: "Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseInvalidContextForWildcardInSelectList(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseInvalidContextForWildcardInSelectList",
|
||||
message: "Invalid use of * in SELECT list in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errInvalidDataType(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "InvalidDataType",
|
||||
@@ -88,55 +52,10 @@ func errInvalidDataType(err error) *s3Error {
|
||||
}
|
||||
}
|
||||
|
||||
func errUnsupportedFunction(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "UnsupportedFunction",
|
||||
message: "Encountered an unsupported SQL function.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseNonUnaryAgregateFunctionCall(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseNonUnaryAgregateFunctionCall",
|
||||
message: "Only one argument is supported for aggregate functions in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errIncorrectSQLFunctionArgumentType(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "IncorrectSqlFunctionArgumentType",
|
||||
message: "Incorrect type of arguments in function call in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errEvaluatorInvalidArguments(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "EvaluatorInvalidArguments",
|
||||
message: "Incorrect number of arguments in the function call in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errUnsupportedSQLOperation(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "UnsupportedSqlOperation",
|
||||
message: "Encountered an unsupported SQL operation.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseUnknownOperator(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseUnknownOperator",
|
||||
message: "The SQL expression contains an invalid operator.",
|
||||
message: "Incorrect type of arguments in function call.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
@@ -151,64 +70,28 @@ func errLikeInvalidInputs(err error) *s3Error {
|
||||
}
|
||||
}
|
||||
|
||||
func errExternalEvalException(err error) *s3Error {
|
||||
func errQueryParseFailure(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ExternalEvalException",
|
||||
message: "The query cannot be evaluated. Check the file and try again.",
|
||||
code: "ParseSelectFailure",
|
||||
message: err.Error(),
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errValueParseFailure(err error) *s3Error {
|
||||
func errQueryAnalysisFailure(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ValueParseFailure",
|
||||
message: "Time stamp parse failure in the SQL expression.",
|
||||
code: "InvalidQuery",
|
||||
message: err.Error(),
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errEvaluatorBindingDoesNotExist(err error) *s3Error {
|
||||
func errBadTableName(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "EvaluatorBindingDoesNotExist",
|
||||
message: "A column name or a path provided does not exist in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errInternalError(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "InternalError",
|
||||
message: "Encountered an internal error.",
|
||||
statusCode: 500,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseInvalidTypeParam(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseInvalidTypeParam",
|
||||
message: "The SQL expression contains an invalid parameter value.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseUnsupportedSyntax(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseUnsupportedSyntax",
|
||||
message: "The SQL expression contains unsupported syntax.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errInvalidKeyPath(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "InvalidKeyPath",
|
||||
message: "Key path in the SQL expression is invalid.",
|
||||
code: "BadTableName",
|
||||
message: "The table name is not supported",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
|
||||
361
pkg/s3select/sql/evaluate.go
Normal file
361
pkg/s3select/sql/evaluate.go
Normal file
@@ -0,0 +1,361 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidASTNode = errors.New("invalid AST Node")
|
||||
errExpectedBool = errors.New("expected bool")
|
||||
errLikeNonStrArg = errors.New("LIKE clause requires string arguments")
|
||||
errLikeInvalidEscape = errors.New("LIKE clause has invalid ESCAPE character")
|
||||
errNotImplemented = errors.New("not implemented")
|
||||
)
|
||||
|
||||
// AST Node Evaluation functions
|
||||
//
|
||||
// During evaluation, the query is known to be valid, as analysis is
|
||||
// complete. The only errors possible are due to value type
|
||||
// mismatches, etc.
|
||||
//
|
||||
// If an aggregation node is present as a descendant (when
|
||||
// e.prop.isAggregation is true), we call evalNode on all child nodes,
|
||||
// check for errors, but do not perform any combining of the results
|
||||
// of child nodes. The final result row is returned after all rows are
|
||||
// processed, and the `getAggregate` function is called.
|
||||
|
||||
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
|
||||
return e.Expression.evalNode(r)
|
||||
}
|
||||
|
||||
func (e *Expression) evalNode(r Record) (*Value, error) {
|
||||
if len(e.And) == 1 {
|
||||
// In this case, result is not required to be boolean
|
||||
// type.
|
||||
return e.And[0].evalNode(r)
|
||||
}
|
||||
|
||||
// Compute OR of conditions
|
||||
result := false
|
||||
for _, ex := range e.And {
|
||||
res, err := ex.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, ok := res.ToBool()
|
||||
if !ok {
|
||||
return nil, errExpectedBool
|
||||
}
|
||||
result = result || b
|
||||
}
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *AndCondition) evalNode(r Record) (*Value, error) {
|
||||
if len(e.Condition) == 1 {
|
||||
// In this case, result does not have to be boolean
|
||||
return e.Condition[0].evalNode(r)
|
||||
}
|
||||
|
||||
// Compute AND of conditions
|
||||
result := true
|
||||
for _, ex := range e.Condition {
|
||||
res, err := ex.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, ok := res.ToBool()
|
||||
if !ok {
|
||||
return nil, errExpectedBool
|
||||
}
|
||||
result = result && b
|
||||
}
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *Condition) evalNode(r Record) (*Value, error) {
|
||||
if e.Operand != nil {
|
||||
// In this case, result does not have to be boolean
|
||||
return e.Operand.evalNode(r)
|
||||
}
|
||||
|
||||
// Compute NOT of condition
|
||||
res, err := e.Not.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, ok := res.ToBool()
|
||||
if !ok {
|
||||
return nil, errExpectedBool
|
||||
}
|
||||
return FromBool(!b), nil
|
||||
}
|
||||
|
||||
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
|
||||
opVal, opErr := e.Operand.evalNode(r)
|
||||
if opErr != nil || e.ConditionRHS == nil {
|
||||
return opVal, opErr
|
||||
}
|
||||
|
||||
// Need to evaluate the ConditionRHS
|
||||
switch {
|
||||
case e.ConditionRHS.Compare != nil:
|
||||
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
|
||||
if cmpRErr != nil {
|
||||
return nil, cmpRErr
|
||||
}
|
||||
|
||||
b, err := opVal.compareOp(e.ConditionRHS.Compare.Operator, cmpRight)
|
||||
return FromBool(b), err
|
||||
|
||||
case e.ConditionRHS.Between != nil:
|
||||
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
|
||||
|
||||
case e.ConditionRHS.Like != nil:
|
||||
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
|
||||
|
||||
case e.ConditionRHS.In != nil:
|
||||
return e.ConditionRHS.In.evalInNode(r, opVal)
|
||||
|
||||
default:
|
||||
return nil, errInvalidASTNode
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
|
||||
stVal, stErr := e.Start.evalNode(r)
|
||||
if stErr != nil {
|
||||
return nil, stErr
|
||||
}
|
||||
|
||||
endVal, endErr := e.End.evalNode(r)
|
||||
if endErr != nil {
|
||||
return nil, endErr
|
||||
}
|
||||
|
||||
part1, err1 := stVal.compareOp(opLte, arg)
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
|
||||
part2, err2 := arg.compareOp(opLte, endVal)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
|
||||
result := part1 && part2
|
||||
if e.Not {
|
||||
result = !result
|
||||
}
|
||||
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
|
||||
inferTypeAsString(arg)
|
||||
|
||||
s, ok := arg.ToString()
|
||||
if !ok {
|
||||
err := errLikeNonStrArg
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
pattern, err1 := e.Pattern.evalNode(r)
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
|
||||
// Infer pattern as string (in case it is untyped)
|
||||
inferTypeAsString(pattern)
|
||||
|
||||
patternStr, ok := pattern.ToString()
|
||||
if !ok {
|
||||
err := errLikeNonStrArg
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
escape := runeZero
|
||||
if e.EscapeChar != nil {
|
||||
escapeVal, err2 := e.EscapeChar.evalNode(r)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
|
||||
inferTypeAsString(escapeVal)
|
||||
|
||||
escapeStr, ok := escapeVal.ToString()
|
||||
if !ok {
|
||||
err := errLikeNonStrArg
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
if len([]rune(escapeStr)) > 1 {
|
||||
err := errLikeInvalidEscape
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
}
|
||||
|
||||
matchResult, err := evalSQLLike(s, patternStr, escape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.Not {
|
||||
matchResult = !matchResult
|
||||
}
|
||||
|
||||
return FromBool(matchResult), nil
|
||||
}
|
||||
|
||||
func (e *In) evalInNode(r Record, arg *Value) (*Value, error) {
|
||||
result := false
|
||||
for _, elt := range e.Expressions {
|
||||
eltVal, err := elt.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// FIXME: type inference?
|
||||
|
||||
// Types must match.
|
||||
if arg.vType != eltVal.vType {
|
||||
// match failed.
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.value == eltVal.value {
|
||||
result = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *Operand) evalNode(r Record) (*Value, error) {
|
||||
lval, lerr := e.Left.evalNode(r)
|
||||
if lerr != nil || len(e.Right) == 0 {
|
||||
return lval, lerr
|
||||
}
|
||||
|
||||
// Process remaining child nodes - result must be
|
||||
// numeric. This AST node is for terms separated by + or -
|
||||
// symbols.
|
||||
for _, rightTerm := range e.Right {
|
||||
op := rightTerm.Op
|
||||
rval, rerr := rightTerm.Right.evalNode(r)
|
||||
if rerr != nil {
|
||||
return nil, rerr
|
||||
}
|
||||
err := lval.arithOp(op, rval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return lval, nil
|
||||
}
|
||||
|
||||
func (e *MultOp) evalNode(r Record) (*Value, error) {
|
||||
lval, lerr := e.Left.evalNode(r)
|
||||
if lerr != nil || len(e.Right) == 0 {
|
||||
return lval, lerr
|
||||
}
|
||||
|
||||
// Process other child nodes - result must be numeric. This
|
||||
// AST node is for terms separated by *, / or % symbols.
|
||||
for _, rightTerm := range e.Right {
|
||||
op := rightTerm.Op
|
||||
rval, rerr := rightTerm.Right.evalNode(r)
|
||||
if rerr != nil {
|
||||
return nil, rerr
|
||||
}
|
||||
|
||||
err := lval.arithOp(op, rval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return lval, nil
|
||||
}
|
||||
|
||||
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
|
||||
if e.Negated == nil {
|
||||
return e.Primary.evalNode(r)
|
||||
}
|
||||
|
||||
v, err := e.Negated.Term.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inferTypeForArithOp(v)
|
||||
if ival, ok := v.ToInt(); ok {
|
||||
return FromInt(-ival), nil
|
||||
} else if fval, ok := v.ToFloat(); ok {
|
||||
return FromFloat(-fval), nil
|
||||
}
|
||||
return nil, errArithMismatchedTypes
|
||||
}
|
||||
|
||||
func (e *JSONPath) evalNode(r Record) (*Value, error) {
|
||||
// Strip the table name from the keypath.
|
||||
keypath := e.String()
|
||||
ps := strings.SplitN(keypath, ".", 2)
|
||||
if len(ps) == 2 {
|
||||
keypath = ps[1]
|
||||
}
|
||||
return r.Get(keypath)
|
||||
}
|
||||
|
||||
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
|
||||
switch {
|
||||
case e.Value != nil:
|
||||
return e.Value.evalNode(r)
|
||||
case e.JPathExpr != nil:
|
||||
return e.JPathExpr.evalNode(r)
|
||||
case e.SubExpression != nil:
|
||||
return e.SubExpression.evalNode(r)
|
||||
case e.FuncCall != nil:
|
||||
return e.FuncCall.evalNode(r)
|
||||
}
|
||||
return nil, errInvalidASTNode
|
||||
}
|
||||
|
||||
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
|
||||
switch e.getFunctionName() {
|
||||
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
|
||||
return e.getAggregate()
|
||||
default:
|
||||
return e.evalSQLFnNode(r)
|
||||
}
|
||||
}
|
||||
|
||||
// evalNode on a literal value is independent of the node being an
|
||||
// aggregation or a row function - it always returns a value.
|
||||
func (e *LitValue) evalNode(_ Record) (res *Value, err error) {
|
||||
switch {
|
||||
case e.Number != nil:
|
||||
return floatToValue(*e.Number), nil
|
||||
case e.String != nil:
|
||||
return FromString(string(*e.String)), nil
|
||||
case e.Boolean != nil:
|
||||
return FromBool(bool(*e.Boolean)), nil
|
||||
}
|
||||
return FromNull(), nil
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Expr - a SQL expression type.
|
||||
type Expr interface {
|
||||
AggregateValue() (*Value, error)
|
||||
Eval(record Record) (*Value, error)
|
||||
ReturnType() Type
|
||||
Type() Type
|
||||
}
|
||||
|
||||
// aliasExpr - aliases expression by alias.
|
||||
type aliasExpr struct {
|
||||
alias string
|
||||
expr Expr
|
||||
}
|
||||
|
||||
// String - returns string representation of this expression.
|
||||
func (expr *aliasExpr) String() string {
|
||||
return fmt.Sprintf("(%v AS %v)", expr.expr, expr.alias)
|
||||
}
|
||||
|
||||
// Eval - evaluates underlaying expression for given record and returns evaluated result.
|
||||
func (expr *aliasExpr) Eval(record Record) (*Value, error) {
|
||||
return expr.expr.Eval(record)
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value from underlaying expression.
|
||||
func (expr *aliasExpr) AggregateValue() (*Value, error) {
|
||||
return expr.expr.AggregateValue()
|
||||
}
|
||||
|
||||
// Type - returns underlaying expression type.
|
||||
func (expr *aliasExpr) Type() Type {
|
||||
return expr.expr.Type()
|
||||
}
|
||||
|
||||
// ReturnType - returns underlaying expression's return type.
|
||||
func (expr *aliasExpr) ReturnType() Type {
|
||||
return expr.expr.ReturnType()
|
||||
}
|
||||
|
||||
// newAliasExpr - creates new alias expression.
|
||||
func newAliasExpr(alias string, expr Expr) *aliasExpr {
|
||||
return &aliasExpr{alias, expr}
|
||||
}
|
||||
|
||||
// starExpr - asterisk (*) expression.
|
||||
type starExpr struct {
|
||||
}
|
||||
|
||||
// String - returns string representation of this expression.
|
||||
func (expr *starExpr) String() string {
|
||||
return "*"
|
||||
}
|
||||
|
||||
// Eval - returns given args as map value.
|
||||
func (expr *starExpr) Eval(record Record) (*Value, error) {
|
||||
return newRecordValue(record), nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns nil value.
|
||||
func (expr *starExpr) AggregateValue() (*Value, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Type - returns record type.
|
||||
func (expr *starExpr) Type() Type {
|
||||
return record
|
||||
}
|
||||
|
||||
// ReturnType - returns record as return type.
|
||||
func (expr *starExpr) ReturnType() Type {
|
||||
return record
|
||||
}
|
||||
|
||||
// newStarExpr - returns new asterisk (*) expression.
|
||||
func newStarExpr() *starExpr {
|
||||
return &starExpr{}
|
||||
}
|
||||
|
||||
type valueExpr struct {
|
||||
value *Value
|
||||
}
|
||||
|
||||
func (expr *valueExpr) String() string {
|
||||
return expr.value.String()
|
||||
}
|
||||
|
||||
func (expr *valueExpr) Eval(record Record) (*Value, error) {
|
||||
return expr.value, nil
|
||||
}
|
||||
|
||||
func (expr *valueExpr) AggregateValue() (*Value, error) {
|
||||
return expr.value, nil
|
||||
}
|
||||
|
||||
func (expr *valueExpr) Type() Type {
|
||||
return expr.value.Type()
|
||||
}
|
||||
|
||||
func (expr *valueExpr) ReturnType() Type {
|
||||
return expr.value.Type()
|
||||
}
|
||||
|
||||
func newValueExpr(value *Value) *valueExpr {
|
||||
return &valueExpr{value: value}
|
||||
}
|
||||
|
||||
type columnExpr struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (expr *columnExpr) String() string {
|
||||
return expr.name
|
||||
}
|
||||
|
||||
func (expr *columnExpr) Eval(record Record) (*Value, error) {
|
||||
value, err := record.Get(expr.name)
|
||||
if err != nil {
|
||||
return nil, errEvaluatorBindingDoesNotExist(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (expr *columnExpr) AggregateValue() (*Value, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (expr *columnExpr) Type() Type {
|
||||
return column
|
||||
}
|
||||
|
||||
func (expr *columnExpr) ReturnType() Type {
|
||||
return column
|
||||
}
|
||||
|
||||
func newColumnExpr(columnName string) *columnExpr {
|
||||
return &columnExpr{name: columnName}
|
||||
}
|
||||
433
pkg/s3select/sql/funceval.go
Normal file
433
pkg/s3select/sql/funceval.go
Normal file
@@ -0,0 +1,433 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FuncName - SQL function name.
|
||||
type FuncName string
|
||||
|
||||
// SQL Function name constants
|
||||
const (
|
||||
// Conditionals
|
||||
sqlFnCoalesce FuncName = "COALESCE"
|
||||
sqlFnNullIf FuncName = "NULLIF"
|
||||
|
||||
// Conversion
|
||||
sqlFnCast FuncName = "CAST"
|
||||
|
||||
// Date and time
|
||||
sqlFnDateAdd FuncName = "DATE_ADD"
|
||||
sqlFnDateDiff FuncName = "DATE_DIFF"
|
||||
sqlFnExtract FuncName = "EXTRACT"
|
||||
sqlFnToString FuncName = "TO_STRING"
|
||||
sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
|
||||
sqlFnUTCNow FuncName = "UTCNOW"
|
||||
|
||||
// String
|
||||
sqlFnCharLength FuncName = "CHAR_LENGTH"
|
||||
sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
|
||||
sqlFnLower FuncName = "LOWER"
|
||||
sqlFnSubstring FuncName = "SUBSTRING"
|
||||
sqlFnTrim FuncName = "TRIM"
|
||||
sqlFnUpper FuncName = "UPPER"
|
||||
)
|
||||
|
||||
// Allowed cast types
|
||||
const (
|
||||
castBool = "BOOL"
|
||||
castInt = "INT"
|
||||
castInteger = "INTEGER"
|
||||
castString = "STRING"
|
||||
castFloat = "FLOAT"
|
||||
castDecimal = "DECIMAL"
|
||||
castNumeric = "NUMERIC"
|
||||
castTimestamp = "TIMESTAMP"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnimplementedCast = errors.New("This cast not yet implemented")
|
||||
errNonStringTrimArg = errors.New("TRIM() received a non-string argument")
|
||||
)
|
||||
|
||||
func (e *FuncExpr) getFunctionName() FuncName {
|
||||
switch {
|
||||
case e.SFunc != nil:
|
||||
return FuncName(strings.ToUpper(e.SFunc.FunctionName))
|
||||
case e.Count != nil:
|
||||
return FuncName(aggFnCount)
|
||||
case e.Cast != nil:
|
||||
return sqlFnCast
|
||||
case e.Substring != nil:
|
||||
return sqlFnSubstring
|
||||
case e.Extract != nil:
|
||||
return sqlFnExtract
|
||||
case e.Trim != nil:
|
||||
return sqlFnTrim
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
|
||||
// function.
|
||||
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
|
||||
// Handle functions that have phrase arguments
|
||||
switch e.getFunctionName() {
|
||||
case sqlFnCast:
|
||||
expr := e.Cast.Expr
|
||||
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType))
|
||||
return
|
||||
|
||||
case sqlFnSubstring:
|
||||
return handleSQLSubstring(r, e.Substring)
|
||||
|
||||
case sqlFnExtract:
|
||||
return nil, errNotImplemented
|
||||
|
||||
case sqlFnTrim:
|
||||
return handleSQLTrim(r, e.Trim)
|
||||
}
|
||||
|
||||
// For all simple argument functions, we evaluate the arguments here
|
||||
argVals := make([]*Value, len(e.SFunc.ArgsList))
|
||||
for i, arg := range e.SFunc.ArgsList {
|
||||
argVals[i], err = arg.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch e.getFunctionName() {
|
||||
case sqlFnCoalesce:
|
||||
return coalesce(r, argVals)
|
||||
|
||||
case sqlFnNullIf:
|
||||
return nullif(r, argVals[0], argVals[1])
|
||||
|
||||
case sqlFnCharLength, sqlFnCharacterLength:
|
||||
return charlen(r, argVals[0])
|
||||
|
||||
case sqlFnLower:
|
||||
return lowerCase(r, argVals[0])
|
||||
|
||||
case sqlFnUpper:
|
||||
return upperCase(r, argVals[0])
|
||||
|
||||
case sqlFnDateAdd, sqlFnDateDiff, sqlFnToString, sqlFnToTimestamp, sqlFnUTCNow:
|
||||
// TODO: implement
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
return nil, errInvalidASTNode
|
||||
}
|
||||
}
|
||||
|
||||
func coalesce(r Record, args []*Value) (res *Value, err error) {
|
||||
for _, arg := range args {
|
||||
if arg.IsNull() {
|
||||
continue
|
||||
}
|
||||
return arg, nil
|
||||
}
|
||||
return FromNull(), nil
|
||||
}
|
||||
|
||||
func nullif(r Record, v1, v2 *Value) (res *Value, err error) {
|
||||
// Handle Null cases
|
||||
if v1.IsNull() || v2.IsNull() {
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
err = inferTypesForCmp(v1, v2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
|
||||
bothNumeric := v1.isNumeric() && v2.isNumeric()
|
||||
if atleastOneNumeric || !bothNumeric {
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
if v1.vType != v2.vType {
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
cmpResult, cmpErr := v1.compareOp(opEq, v2)
|
||||
if cmpErr != nil {
|
||||
return nil, cmpErr
|
||||
}
|
||||
|
||||
if cmpResult {
|
||||
return FromNull(), nil
|
||||
}
|
||||
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
func charlen(r Record, v *Value) (*Value, error) {
|
||||
inferTypeAsString(v)
|
||||
s, ok := v.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
return FromInt(int64(len(s))), nil
|
||||
}
|
||||
|
||||
func lowerCase(r Record, v *Value) (*Value, error) {
|
||||
inferTypeAsString(v)
|
||||
s, ok := v.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s expects a string argument", sqlFnLower)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
return FromString(strings.ToLower(s)), nil
|
||||
}
|
||||
|
||||
func upperCase(r Record, v *Value) (*Value, error) {
|
||||
inferTypeAsString(v)
|
||||
s, ok := v.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
return FromString(strings.ToUpper(s)), nil
|
||||
}
|
||||
|
||||
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
|
||||
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
|
||||
// SUBSTRING('abc', 2, 1) are supported.
|
||||
|
||||
// Evaluate the string argument
|
||||
v1, err := e.Expr.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inferTypeAsString(v1)
|
||||
s, ok := v1.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
// Assemble other arguments
|
||||
arg2, arg3 := e.From, e.For
|
||||
// Check if the second form of substring is being used
|
||||
if e.From == nil {
|
||||
arg2, arg3 = e.Arg2, e.Arg3
|
||||
}
|
||||
|
||||
// Evaluate the FROM argument
|
||||
v2, err := arg2.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inferTypeForArithOp(v2)
|
||||
startIdx, ok := v2.ToInt()
|
||||
if !ok {
|
||||
err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
length := -1
|
||||
// Evaluate the optional FOR argument
|
||||
if arg3 != nil {
|
||||
v3, err := arg3.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inferTypeForArithOp(v3)
|
||||
lenInt, ok := v3.ToInt()
|
||||
if !ok {
|
||||
err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
length = int(lenInt)
|
||||
if length < 0 {
|
||||
err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := evalSQLSubstring(s, int(startIdx), length)
|
||||
return FromString(res), err
|
||||
}
|
||||
|
||||
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
|
||||
charsV, cerr := e.TrimChars.evalNode(r)
|
||||
if cerr != nil {
|
||||
return nil, cerr
|
||||
}
|
||||
inferTypeAsString(charsV)
|
||||
chars, ok := charsV.ToString()
|
||||
if !ok {
|
||||
return nil, errNonStringTrimArg
|
||||
}
|
||||
|
||||
fromV, ferr := e.TrimFrom.evalNode(r)
|
||||
if ferr != nil {
|
||||
return nil, ferr
|
||||
}
|
||||
from, ok := fromV.ToString()
|
||||
if !ok {
|
||||
return nil, errNonStringTrimArg
|
||||
}
|
||||
|
||||
result, terr := evalSQLTrim(e.TrimWhere, chars, from)
|
||||
if terr != nil {
|
||||
return nil, terr
|
||||
}
|
||||
return FromString(result), nil
|
||||
}
|
||||
|
||||
func errUnsupportedCast(fromType, toType string) error {
|
||||
return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
|
||||
}
|
||||
|
||||
func errCastFailure(msg string) error {
|
||||
return fmt.Errorf("Error casting: %s", msg)
|
||||
}
|
||||
|
||||
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
|
||||
v, err := e.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println("Cast to ", castType)
|
||||
|
||||
switch castType {
|
||||
case castInt, castInteger:
|
||||
i, err := intCast(v)
|
||||
return FromInt(i), err
|
||||
|
||||
case castFloat:
|
||||
f, err := floatCast(v)
|
||||
return FromFloat(f), err
|
||||
|
||||
case castString:
|
||||
s, err := stringCast(v)
|
||||
return FromString(s), err
|
||||
|
||||
case castBool, castDecimal, castNumeric, castTimestamp:
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
return nil, errUnimplementedCast
|
||||
}
|
||||
}
|
||||
|
||||
func intCast(v *Value) (int64, error) {
|
||||
// This conversion truncates floating point numbers to
|
||||
// integer.
|
||||
strToInt := func(s string) (int64, bool) {
|
||||
i, errI := strconv.ParseInt(s, 10, 64)
|
||||
if errI == nil {
|
||||
return i, true
|
||||
}
|
||||
f, errF := strconv.ParseFloat(s, 64)
|
||||
if errF == nil {
|
||||
return int64(f), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
// Truncate fractional part
|
||||
return int64(v.value.(float64)), nil
|
||||
case typeInt:
|
||||
return v.value.(int64), nil
|
||||
case typeString:
|
||||
// Parse as number, truncate floating point if
|
||||
// needed.
|
||||
s, _ := v.ToString()
|
||||
res, ok := strToInt(s)
|
||||
if !ok {
|
||||
return 0, errCastFailure("could not parse as int")
|
||||
}
|
||||
return res, nil
|
||||
case typeBytes:
|
||||
// Parse as number, truncate floating point if
|
||||
// needed.
|
||||
b, _ := v.ToBytes()
|
||||
s := string(b)
|
||||
res, ok := strToInt(s)
|
||||
if !ok {
|
||||
return 0, errCastFailure("could not parse as int")
|
||||
}
|
||||
return res, nil
|
||||
|
||||
default:
|
||||
return 0, errUnsupportedCast(v.GetTypeString(), castInt)
|
||||
}
|
||||
}
|
||||
|
||||
func floatCast(v *Value) (float64, error) {
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
return v.value.(float64), nil
|
||||
case typeInt:
|
||||
return float64(v.value.(int64)), nil
|
||||
case typeString:
|
||||
f, err := strconv.ParseFloat(v.value.(string), 64)
|
||||
if err != nil {
|
||||
return 0, errCastFailure("could not parse as float")
|
||||
}
|
||||
return f, nil
|
||||
case typeBytes:
|
||||
b, _ := v.ToBytes()
|
||||
f, err := strconv.ParseFloat(string(b), 64)
|
||||
if err != nil {
|
||||
return 0, errCastFailure("could not parse as float")
|
||||
}
|
||||
return f, nil
|
||||
default:
|
||||
return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
|
||||
}
|
||||
}
|
||||
|
||||
func stringCast(v *Value) (string, error) {
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
f, _ := v.ToFloat()
|
||||
return fmt.Sprintf("%v", f), nil
|
||||
case typeInt:
|
||||
i, _ := v.ToInt()
|
||||
return fmt.Sprintf("%v", i), nil
|
||||
case typeString:
|
||||
s, _ := v.ToString()
|
||||
return s, nil
|
||||
case typeBytes:
|
||||
b, _ := v.ToBytes()
|
||||
return string(b), nil
|
||||
case typeBool:
|
||||
b, _ := v.ToBool()
|
||||
return fmt.Sprintf("%v", b), nil
|
||||
case typeNull:
|
||||
// FIXME: verify this case is correct
|
||||
return fmt.Sprintf("NULL"), nil
|
||||
}
|
||||
// This does not happen
|
||||
return "", nil
|
||||
}
|
||||
@@ -1,550 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FuncName - SQL function name.
|
||||
type FuncName string
|
||||
|
||||
const (
|
||||
// Avg - aggregate SQL function AVG().
|
||||
Avg FuncName = "AVG"
|
||||
|
||||
// Count - aggregate SQL function COUNT().
|
||||
Count FuncName = "COUNT"
|
||||
|
||||
// Max - aggregate SQL function MAX().
|
||||
Max FuncName = "MAX"
|
||||
|
||||
// Min - aggregate SQL function MIN().
|
||||
Min FuncName = "MIN"
|
||||
|
||||
// Sum - aggregate SQL function SUM().
|
||||
Sum FuncName = "SUM"
|
||||
|
||||
// Coalesce - conditional SQL function COALESCE().
|
||||
Coalesce FuncName = "COALESCE"
|
||||
|
||||
// NullIf - conditional SQL function NULLIF().
|
||||
NullIf FuncName = "NULLIF"
|
||||
|
||||
// ToTimestamp - conversion SQL function TO_TIMESTAMP().
|
||||
ToTimestamp FuncName = "TO_TIMESTAMP"
|
||||
|
||||
// UTCNow - date SQL function UTCNOW().
|
||||
UTCNow FuncName = "UTCNOW"
|
||||
|
||||
// CharLength - string SQL function CHAR_LENGTH().
|
||||
CharLength FuncName = "CHAR_LENGTH"
|
||||
|
||||
// CharacterLength - string SQL function CHARACTER_LENGTH() same as CHAR_LENGTH().
|
||||
CharacterLength FuncName = "CHARACTER_LENGTH"
|
||||
|
||||
// Lower - string SQL function LOWER().
|
||||
Lower FuncName = "LOWER"
|
||||
|
||||
// Substring - string SQL function SUBSTRING().
|
||||
Substring FuncName = "SUBSTRING"
|
||||
|
||||
// Trim - string SQL function TRIM().
|
||||
Trim FuncName = "TRIM"
|
||||
|
||||
// Upper - string SQL function UPPER().
|
||||
Upper FuncName = "UPPER"
|
||||
|
||||
// DateAdd FuncName = "DATE_ADD"
|
||||
// DateDiff FuncName = "DATE_DIFF"
|
||||
// Extract FuncName = "EXTRACT"
|
||||
// ToString FuncName = "TO_STRING"
|
||||
// Cast FuncName = "CAST" // CAST('2007-04-05T14:30Z' AS TIMESTAMP)
|
||||
)
|
||||
|
||||
func isAggregateFuncName(s string) bool {
|
||||
switch FuncName(s) {
|
||||
case Avg, Count, Max, Min, Sum:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func callForNumber(f Expr, record Record) (*Value, error) {
|
||||
value, err := f.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !value.Type().isNumber() {
|
||||
err := fmt.Errorf("%v evaluated to %v; not to number", f, value.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func callForInt(f Expr, record Record) (*Value, error) {
|
||||
value, err := f.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value.Type() != Int {
|
||||
err := fmt.Errorf("%v evaluated to %v; not to int", f, value.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func callForString(f Expr, record Record) (*Value, error) {
|
||||
value, err := f.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value.Type() != String {
|
||||
err := fmt.Errorf("%v evaluated to %v; not to string", f, value.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// funcExpr - SQL function.
|
||||
type funcExpr struct {
|
||||
args []Expr
|
||||
name FuncName
|
||||
|
||||
sumValue float64
|
||||
countValue int64
|
||||
maxValue float64
|
||||
minValue float64
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *funcExpr) String() string {
|
||||
var argStrings []string
|
||||
for _, arg := range f.args {
|
||||
argStrings = append(argStrings, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v(%v)", f.name, strings.Join(argStrings, ","))
|
||||
}
|
||||
|
||||
func (f *funcExpr) sum(record Record) (*Value, error) {
|
||||
value, err := callForNumber(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.sumValue += value.FloatValue()
|
||||
f.countValue++
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) count(record Record) (*Value, error) {
|
||||
value, err := f.args[0].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value.valueType != Null {
|
||||
f.countValue++
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) max(record Record) (*Value, error) {
|
||||
value, err := callForNumber(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := value.FloatValue()
|
||||
if v > f.maxValue {
|
||||
f.maxValue = v
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) min(record Record) (*Value, error) {
|
||||
value, err := callForNumber(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := value.FloatValue()
|
||||
if v < f.minValue {
|
||||
f.minValue = v
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) charLength(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewInt(int64(len(value.StringValue()))), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) trim(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewString(strings.TrimSpace(value.StringValue())), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) lower(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewString(strings.ToLower(value.StringValue())), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) upper(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewString(strings.ToUpper(value.StringValue())), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) substring(record Record) (*Value, error) {
|
||||
stringValue, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offsetValue, err := callForInt(f.args[1], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lengthValue *Value
|
||||
if len(f.args) == 3 {
|
||||
lengthValue, err = callForInt(f.args[2], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
value := stringValue.StringValue()
|
||||
offset := int(offsetValue.FloatValue())
|
||||
if offset < 0 || offset > len(value) {
|
||||
offset = 0
|
||||
}
|
||||
length := len(value)
|
||||
if lengthValue != nil {
|
||||
length = int(lengthValue.FloatValue())
|
||||
if length < 0 || length > len(value) {
|
||||
length = len(value)
|
||||
}
|
||||
}
|
||||
|
||||
return NewString(value[offset:length]), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) coalesce(record Record) (*Value, error) {
|
||||
values := make([]*Value, len(f.args))
|
||||
var err error
|
||||
for i := range f.args {
|
||||
values[i], err = f.args[i].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
if values[i].Type() != Null {
|
||||
return values[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) nullIf(record Record) (*Value, error) {
|
||||
value1, err := f.args[0].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value2, err := f.args[1].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := equal(value1, value2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result {
|
||||
return NewNull(), nil
|
||||
}
|
||||
|
||||
return value1, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) toTimeStamp(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err := time.Parse(time.RFC3339, value.StringValue())
|
||||
if err != nil {
|
||||
err := fmt.Errorf("%v: value '%v': %v", f, value, err)
|
||||
return nil, errValueParseFailure(err)
|
||||
}
|
||||
|
||||
return NewTime(t), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) utcNow(record Record) (*Value, error) {
|
||||
return NewTime(time.Now().UTC()), nil
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *funcExpr) Eval(record Record) (*Value, error) {
|
||||
switch f.name {
|
||||
case Avg, Sum:
|
||||
return f.sum(record)
|
||||
case Count:
|
||||
return f.count(record)
|
||||
case Max:
|
||||
return f.max(record)
|
||||
case Min:
|
||||
return f.min(record)
|
||||
case Coalesce:
|
||||
return f.coalesce(record)
|
||||
case NullIf:
|
||||
return f.nullIf(record)
|
||||
case ToTimestamp:
|
||||
return f.toTimeStamp(record)
|
||||
case UTCNow:
|
||||
return f.utcNow(record)
|
||||
case Substring:
|
||||
return f.substring(record)
|
||||
case CharLength, CharacterLength:
|
||||
return f.charLength(record)
|
||||
case Trim:
|
||||
return f.trim(record)
|
||||
case Lower:
|
||||
return f.lower(record)
|
||||
case Upper:
|
||||
return f.upper(record)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("unsupported aggregate function %v", f.name))
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *funcExpr) AggregateValue() (*Value, error) {
|
||||
switch f.name {
|
||||
case Avg:
|
||||
return NewFloat(f.sumValue / float64(f.countValue)), nil
|
||||
case Count:
|
||||
return NewInt(f.countValue), nil
|
||||
case Max:
|
||||
return NewFloat(f.maxValue), nil
|
||||
case Min:
|
||||
return NewFloat(f.minValue), nil
|
||||
case Sum:
|
||||
return NewFloat(f.sumValue), nil
|
||||
}
|
||||
|
||||
err := fmt.Errorf("%v is not aggreate function", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
// Type - returns Function or aggregateFunction type.
|
||||
func (f *funcExpr) Type() Type {
|
||||
switch f.name {
|
||||
case Avg, Count, Max, Min, Sum:
|
||||
return aggregateFunction
|
||||
}
|
||||
|
||||
return function
|
||||
}
|
||||
|
||||
// ReturnType - returns respective primitive type depending on SQL function.
|
||||
func (f *funcExpr) ReturnType() Type {
|
||||
switch f.name {
|
||||
case Avg, Max, Min, Sum:
|
||||
return Float
|
||||
case Count:
|
||||
return Int
|
||||
case CharLength, CharacterLength, Trim, Lower, Upper, Substring:
|
||||
return String
|
||||
case ToTimestamp, UTCNow:
|
||||
return Timestamp
|
||||
case Coalesce, NullIf:
|
||||
return column
|
||||
}
|
||||
|
||||
return function
|
||||
}
|
||||
|
||||
// newFuncExpr - creates new SQL function.
|
||||
func newFuncExpr(funcName FuncName, funcs ...Expr) (*funcExpr, error) {
|
||||
switch funcName {
|
||||
case Avg, Max, Min, Sum:
|
||||
if len(funcs) != 1 {
|
||||
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errParseNonUnaryAgregateFunctionCall(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("%v(): argument %v evaluate to %v, not number", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case Count:
|
||||
if len(funcs) != 1 {
|
||||
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errParseNonUnaryAgregateFunctionCall(err)
|
||||
}
|
||||
|
||||
switch funcs[0].ReturnType() {
|
||||
case Null, Bool, Int, Float, String, Timestamp, column, record:
|
||||
default:
|
||||
err := fmt.Errorf("%v(): argument %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case CharLength, CharacterLength, Trim, Lower, Upper, ToTimestamp:
|
||||
if len(funcs) != 1 {
|
||||
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("%v(): argument %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case Coalesce:
|
||||
if len(funcs) < 1 {
|
||||
err := fmt.Errorf("%v(): one or more argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
for i := range funcs {
|
||||
if !funcs[i].ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("%v(): argument-%v %v evaluate to %v is incompatible", funcName, i+1, funcs[i], funcs[i].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case NullIf:
|
||||
if len(funcs) != 2 {
|
||||
err := fmt.Errorf("%v(): exactly two arguments expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
if !funcs[1].ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v is incompatible", funcName, funcs[1], funcs[1].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case UTCNow:
|
||||
if len(funcs) != 0 {
|
||||
err := fmt.Errorf("%v(): no argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case Substring:
|
||||
if len(funcs) < 2 || len(funcs) > 3 {
|
||||
err := fmt.Errorf("%v(): exactly two or three arguments expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
if !funcs[1].ReturnType().isIntKind() {
|
||||
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v, not int", funcName, funcs[1], funcs[1].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
if len(funcs) > 2 {
|
||||
if !funcs[2].ReturnType().isIntKind() {
|
||||
err := fmt.Errorf("%v(): argument-3 %v evaluate to %v, not int", funcName, funcs[2], funcs[2].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errUnsupportedFunction(fmt.Errorf("unknown function name %v", funcName))
|
||||
}
|
||||
@@ -1,336 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import "fmt"
|
||||
|
||||
// andExpr - logical AND function.
|
||||
type andExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *andExpr) String() string {
|
||||
return fmt.Sprintf("(%v AND %v)", f.left, f.right)
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *andExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
_, err = f.right.Eval(record)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if !leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *andExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if !leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// Type - returns logicalFunction or aggregateFunction type.
|
||||
func (f *andExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *andExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newAndExpr - creates new AND logical function.
|
||||
func newAndExpr(left, right Expr) (*andExpr, error) {
|
||||
if !left.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator AND: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if !right.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator AND: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := logicalFunction
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
if left.Type() == column {
|
||||
err := fmt.Errorf("operator AND: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
if right.Type() == column {
|
||||
err := fmt.Errorf("operator AND: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &andExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// orExpr - logical OR function.
|
||||
type orExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *orExpr) String() string {
|
||||
return fmt.Sprintf("(%v OR %v)", f.left, f.right)
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *orExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
_, err = f.right.Eval(record)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *orExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// Type - returns logicalFunction or aggregateFunction type.
|
||||
func (f *orExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *orExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newOrExpr - creates new OR logical function.
|
||||
func newOrExpr(left, right Expr) (*orExpr, error) {
|
||||
if !left.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator OR: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if !right.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator OR: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := logicalFunction
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
if left.Type() == column {
|
||||
err := fmt.Errorf("operator OR: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
if right.Type() == column {
|
||||
err := fmt.Errorf("operator OR: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &orExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// notExpr - logical NOT function.
|
||||
type notExpr struct {
|
||||
right Expr
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *notExpr) String() string {
|
||||
return fmt.Sprintf("(%v)", f.right)
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *notExpr) Eval(record Record) (*Value, error) {
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!rightValue.BoolValue()), nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *notExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!rightValue.BoolValue()), nil
|
||||
}
|
||||
|
||||
// Type - returns logicalFunction or aggregateFunction type.
|
||||
func (f *notExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *notExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newNotExpr - creates new NOT logical function.
|
||||
func newNotExpr(right Expr) (*notExpr, error) {
|
||||
if !right.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator NOT: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := logicalFunction
|
||||
if right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
}
|
||||
|
||||
return ¬Expr{
|
||||
right: right,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
329
pkg/s3select/sql/parser.go
Normal file
329
pkg/s3select/sql/parser.go
Normal file
@@ -0,0 +1,329 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle"
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
// Types with custom Capture interface for parsing
|
||||
|
||||
// Boolean is a type for a parsed Boolean literal
|
||||
type Boolean bool
|
||||
|
||||
// Capture interface used by participle
|
||||
func (b *Boolean) Capture(values []string) error {
|
||||
*b = strings.ToLower(values[0]) == "true"
|
||||
return nil
|
||||
}
|
||||
|
||||
// LiteralString is a type for parsed SQL string literals
|
||||
type LiteralString string
|
||||
|
||||
// Capture interface used by participle
|
||||
func (ls *LiteralString) Capture(values []string) error {
|
||||
// Remove enclosing single quote
|
||||
n := len(values[0])
|
||||
r := values[0][1 : n-1]
|
||||
// Translate doubled quotes
|
||||
*ls = LiteralString(strings.Replace(r, "''", "'", -1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObjectKey is a type for parsed strings occurring in key paths
|
||||
type ObjectKey struct {
|
||||
Lit *LiteralString `parser:" \"[\" @LitString \"]\""`
|
||||
ID *Identifier `parser:"| \".\" @@"`
|
||||
}
|
||||
|
||||
// QuotedIdentifier is a type for parsed strings that are double
|
||||
// quoted.
|
||||
type QuotedIdentifier string
|
||||
|
||||
// Capture inferface used by participle
|
||||
func (qi *QuotedIdentifier) Capture(values []string) error {
|
||||
// Remove enclosing quotes
|
||||
n := len(values[0])
|
||||
r := values[0][1 : n-1]
|
||||
|
||||
// Translate doubled quotes
|
||||
*qi = QuotedIdentifier(strings.Replace(r, `""`, `"`, -1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Types representing AST of SQL statement. Only SELECT is supported.
|
||||
|
||||
// Select is the top level AST node type
|
||||
type Select struct {
|
||||
Expression *SelectExpression `parser:"\"SELECT\" @@"`
|
||||
From *TableExpression `parser:"\"FROM\" @@"`
|
||||
Where *Expression `parser:"[ \"WHERE\" @@ ]"`
|
||||
Limit *LitValue `parser:"[ \"LIMIT\" @@ ]"`
|
||||
}
|
||||
|
||||
// SelectExpression represents the items requested in the select
|
||||
// statement
|
||||
type SelectExpression struct {
|
||||
All bool `parser:" @\"*\""`
|
||||
Expressions []*AliasedExpression `parser:"| @@ { \",\" @@ }"`
|
||||
|
||||
prop qProp
|
||||
}
|
||||
|
||||
// TableExpression represents the FROM clause
|
||||
type TableExpression struct {
|
||||
Table *JSONPath `parser:"@@"`
|
||||
As string `parser:"( \"AS\"? @Ident )?"`
|
||||
}
|
||||
|
||||
// JSONPathElement represents a keypath component
|
||||
type JSONPathElement struct {
|
||||
Key *ObjectKey `parser:" @@"` // ['name'] and .name forms
|
||||
Index *uint64 `parser:"| \"[\" @Number \"]\""` // [3] form
|
||||
ObjectWildcard bool `parser:"| @\".*\""` // .* form
|
||||
ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form
|
||||
}
|
||||
|
||||
// JSONPath represents a keypath
|
||||
type JSONPath struct {
|
||||
BaseKey *Identifier `parser:" @@"`
|
||||
PathExpr []*JSONPathElement `parser:"(@@)*"`
|
||||
}
|
||||
|
||||
// AliasedExpression is an expression that can be optionally named
|
||||
type AliasedExpression struct {
|
||||
Expression *Expression `parser:"@@"`
|
||||
As string `parser:"[ \"AS\" @Ident ]"`
|
||||
}
|
||||
|
||||
// Grammar for Expression
|
||||
//
|
||||
// Expression → AndCondition ("OR" AndCondition)*
|
||||
// AndCondition → Condition ("AND" Condition)*
|
||||
// Condition → "NOT" Condition | ConditionExpression
|
||||
// ConditionExpression → ValueExpression ("=" | "<>" | "<=" | ">=" | "<" | ">") ValueExpression
|
||||
// | ValueExpression "LIKE" ValueExpression ("ESCAPE" LitString)?
|
||||
// | ValueExpression ("NOT"? "BETWEEN" ValueExpression "AND" ValueExpression)
|
||||
// | ValueExpression "IN" "(" Expression ("," Expression)* ")"
|
||||
// | ValueExpression
|
||||
// ValueExpression → Operand
|
||||
//
|
||||
// Operand grammar follows below
|
||||
|
||||
// Expression represents a logical disjunction of clauses
|
||||
type Expression struct {
|
||||
And []*AndCondition `parser:"@@ ( \"OR\" @@ )*"`
|
||||
}
|
||||
|
||||
// AndCondition represents logical conjunction of clauses
|
||||
type AndCondition struct {
|
||||
Condition []*Condition `parser:"@@ ( \"AND\" @@ )*"`
|
||||
}
|
||||
|
||||
// Condition represents a negation or a condition operand
|
||||
type Condition struct {
|
||||
Operand *ConditionOperand `parser:" @@"`
|
||||
Not *Condition `parser:"| \"NOT\" @@"`
|
||||
}
|
||||
|
||||
// ConditionOperand is a operand followed by an an optional operation
|
||||
// expression
|
||||
type ConditionOperand struct {
|
||||
Operand *Operand `parser:"@@"`
|
||||
ConditionRHS *ConditionRHS `parser:"@@?"`
|
||||
}
|
||||
|
||||
// ConditionRHS represents the right-hand-side of Compare, Between, In
|
||||
// or Like expressions.
|
||||
type ConditionRHS struct {
|
||||
Compare *Compare `parser:" @@"`
|
||||
Between *Between `parser:"| @@"`
|
||||
In *In `parser:"| \"IN\" \"(\" @@ \")\""`
|
||||
Like *Like `parser:"| @@"`
|
||||
}
|
||||
|
||||
// Compare represents the RHS of a comparison expression
|
||||
type Compare struct {
|
||||
Operator string `parser:"@( \"<>\" | \"<=\" | \">=\" | \"=\" | \"<\" | \">\" | \"!=\" )"`
|
||||
Operand *Operand `parser:" @@"`
|
||||
}
|
||||
|
||||
// Like represents the RHS of a LIKE expression
|
||||
type Like struct {
|
||||
Not bool `parser:" @\"NOT\"? "`
|
||||
Pattern *Operand `parser:" \"LIKE\" @@ "`
|
||||
EscapeChar *Operand `parser:" (\"ESCAPE\" @@)? "`
|
||||
}
|
||||
|
||||
// Between represents the RHS of a BETWEEN expression
|
||||
type Between struct {
|
||||
Not bool `parser:" @\"NOT\"? "`
|
||||
Start *Operand `parser:" \"BETWEEN\" @@ "`
|
||||
End *Operand `parser:" \"AND\" @@ "`
|
||||
}
|
||||
|
||||
// In represents the RHS of an IN expression
|
||||
type In struct {
|
||||
Expressions []*Expression `parser:"@@ ( \",\" @@ )*"`
|
||||
}
|
||||
|
||||
// Grammar for Operand:
|
||||
//
|
||||
// operand → multOp ( ("-" | "+") multOp )*
|
||||
// multOp → unary ( ("/" | "*" | "%") unary )*
|
||||
// unary → "-" unary | primary
|
||||
// primary → Value | Variable | "(" expression ")"
|
||||
//
|
||||
|
||||
// An Operand is a single term followed by an optional sequence of
|
||||
// terms separated by +/-
|
||||
type Operand struct {
|
||||
Left *MultOp `parser:"@@"`
|
||||
Right []*OpFactor `parser:"(@@)*"`
|
||||
}
|
||||
|
||||
// OpFactor represents the right-side of a +/- operation.
|
||||
type OpFactor struct {
|
||||
Op string `parser:"@(\"+\" | \"-\")"`
|
||||
Right *MultOp `parser:"@@"`
|
||||
}
|
||||
|
||||
// MultOp represents a single term followed by an optional sequence of
|
||||
// terms separated by *, / or % operators.
|
||||
type MultOp struct {
|
||||
Left *UnaryTerm `parser:"@@"`
|
||||
Right []*OpUnaryTerm `parser:"(@@)*"`
|
||||
}
|
||||
|
||||
// OpUnaryTerm represents the right side of *, / or % binary operations.
|
||||
type OpUnaryTerm struct {
|
||||
Op string `parser:"@(\"*\" | \"/\" | \"%\")"`
|
||||
Right *UnaryTerm `parser:"@@"`
|
||||
}
|
||||
|
||||
// UnaryTerm represents a single negated term or a primary term
|
||||
type UnaryTerm struct {
|
||||
Negated *NegatedTerm `parser:" @@"`
|
||||
Primary *PrimaryTerm `parser:"| @@"`
|
||||
}
|
||||
|
||||
// NegatedTerm has a leading minus sign.
|
||||
type NegatedTerm struct {
|
||||
Term *PrimaryTerm `parser:"\"-\" @@"`
|
||||
}
|
||||
|
||||
// PrimaryTerm represents a Value, Path expression, a Sub-expression
|
||||
// or a function call.
|
||||
type PrimaryTerm struct {
|
||||
Value *LitValue `parser:" @@"`
|
||||
JPathExpr *JSONPath `parser:"| @@"`
|
||||
SubExpression *Expression `parser:"| \"(\" @@ \")\""`
|
||||
// Include function expressions here.
|
||||
FuncCall *FuncExpr `parser:"| @@"`
|
||||
}
|
||||
|
||||
// FuncExpr represents a function call
|
||||
type FuncExpr struct {
|
||||
SFunc *SimpleArgFunc `parser:" @@"`
|
||||
Count *CountFunc `parser:"| @@"`
|
||||
Cast *CastFunc `parser:"| @@"`
|
||||
Substring *SubstringFunc `parser:"| @@"`
|
||||
Extract *ExtractFunc `parser:"| @@"`
|
||||
Trim *TrimFunc `parser:"| @@"`
|
||||
|
||||
// Used during evaluation for aggregation funcs
|
||||
aggregate *aggVal
|
||||
}
|
||||
|
||||
// SimpleArgFunc represents functions with simple expression
|
||||
// arguments.
|
||||
type SimpleArgFunc struct {
|
||||
FunctionName string `parser:" @(\"AVG\" | \"MAX\" | \"MIN\" | \"SUM\" | \"COALESCE\" | \"NULLIF\" | \"DATE_ADD\" | \"DATE_DIFF\" | \"TO_STRING\" | \"TO_TIMESTAMP\" | \"UTCNOW\" | \"CHAR_LENGTH\" | \"CHARACTER_LENGTH\" | \"LOWER\" | \"UPPER\") "`
|
||||
|
||||
ArgsList []*Expression `parser:"\"(\" (@@ (\",\" @@)*)?\")\""`
|
||||
}
|
||||
|
||||
// CountFunc represents the COUNT sql function
|
||||
type CountFunc struct {
|
||||
StarArg bool `parser:" \"COUNT\" \"(\" ( @\"*\"?"`
|
||||
ExprArg *Expression `parser:" @@? )! \")\""`
|
||||
}
|
||||
|
||||
// CastFunc represents CAST sql function
|
||||
type CastFunc struct {
|
||||
Expr *Expression `parser:" \"CAST\" \"(\" @@ "`
|
||||
CastType string `parser:" \"AS\" @(\"BOOL\" | \"INT\" | \"INTEGER\" | \"STRING\" | \"FLOAT\" | \"DECIMAL\" | \"NUMERIC\" | \"TIMESTAMP\") \")\" "`
|
||||
}
|
||||
|
||||
// SubstringFunc represents SUBSTRING sql function
|
||||
type SubstringFunc struct {
|
||||
Expr *PrimaryTerm `parser:" \"SUBSTRING\" \"(\" @@ "`
|
||||
From *Operand `parser:" ( \"FROM\" @@ "`
|
||||
For *Operand `parser:" (\"FOR\" @@)? \")\" "`
|
||||
Arg2 *Operand `parser:" | \",\" @@ "`
|
||||
Arg3 *Operand `parser:" (\",\" @@)? \")\" )"`
|
||||
}
|
||||
|
||||
// ExtractFunc represents EXTRACT sql function
|
||||
type ExtractFunc struct {
|
||||
Timeword string `parser:" \"EXTRACT\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword | \"TIMEZONE_HOUR\":Timeword | \"TIMEZONE_MINUTE\":Timeword ) "`
|
||||
From *PrimaryTerm `parser:" \"FROM\" @@ \")\" "`
|
||||
}
|
||||
|
||||
// TrimFunc represents TRIM sql function
|
||||
type TrimFunc struct {
|
||||
TrimWhere *string `parser:" \"TRIM\" \"(\" ( @( \"LEADING\" | \"TRAILING\" | \"BOTH\" ) "`
|
||||
TrimChars *PrimaryTerm `parser:" @@? "`
|
||||
TrimFrom *PrimaryTerm `parser:" \"FROM\" )? @@ \")\" "`
|
||||
}
|
||||
|
||||
// LitValue represents a literal value parsed from the sql
|
||||
type LitValue struct {
|
||||
Number *float64 `parser:"( @Number"`
|
||||
String *LiteralString `parser:" | @LitString"`
|
||||
Boolean *Boolean `parser:" | @(\"TRUE\" | \"FALSE\")"`
|
||||
Null bool `parser:" | @\"NULL\")"`
|
||||
}
|
||||
|
||||
// Identifier represents a parsed identifier
|
||||
type Identifier struct {
|
||||
Unquoted *string `parser:" @Ident"`
|
||||
Quoted *QuotedIdentifier `parser:"| @QuotIdent"`
|
||||
}
|
||||
|
||||
var (
|
||||
sqlLexer = lexer.Must(lexer.Regexp(`(\s+)` +
|
||||
`|(?P<Timeword>(?i)\b(?:YEAR|MONTH|DAY|HOUR|MINUTE|SECOND|TIMEZONE_HOUR|TIMEZONE_MINUTE)\b)` +
|
||||
`|(?P<Keyword>(?i)\b(?:SELECT|FROM|TOP|DISTINCT|ALL|WHERE|GROUP|BY|HAVING|UNION|MINUS|EXCEPT|INTERSECT|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|ANY|SOME|BETWEEN|AND|OR|LIKE|ESCAPE|AS|IN|BOOL|INT|INTEGER|STRING|FLOAT|DECIMAL|NUMERIC|TIMESTAMP|AVG|COUNT|MAX|MIN|SUM|COALESCE|NULLIF|CAST|DATE_ADD|DATE_DIFF|EXTRACT|TO_STRING|TO_TIMESTAMP|UTCNOW|CHAR_LENGTH|CHARACTER_LENGTH|LOWER|SUBSTRING|TRIM|UPPER|LEADING|TRAILING|BOTH|FOR)\b)` +
|
||||
`|(?P<Ident>[a-zA-Z_][a-zA-Z0-9_]*)` +
|
||||
`|(?P<QuotIdent>"([^"]*("")?)*")` +
|
||||
`|(?P<Number>\d*\.?\d+([eE][-+]?\d+)?)` +
|
||||
`|(?P<LitString>'([^']*('')?)*')` +
|
||||
`|(?P<Operators><>|!=|<=|>=|\.\*|\[\*\]|[-+*/%,.()=<>\[\]])`,
|
||||
))
|
||||
|
||||
// SQLParser is used to parse SQL statements
|
||||
SQLParser = participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
participle.CaseInsensitive("Timeword"),
|
||||
)
|
||||
)
|
||||
383
pkg/s3select/sql/parser_test.go
Normal file
383
pkg/s3select/sql/parser_test.go
Normal file
@@ -0,0 +1,383 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/alecthomas/participle"
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
func TestJSONPathElement(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&JSONPathElement{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
j := JSONPathElement{}
|
||||
cases := []string{
|
||||
// Key
|
||||
"['name']", ".name", `."name"`,
|
||||
|
||||
// Index
|
||||
"[2]", "[0]", "[100]",
|
||||
|
||||
// Object wilcard
|
||||
".*",
|
||||
|
||||
// array wildcard
|
||||
"[*]",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &j)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONPath(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&JSONPath{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
j := JSONPath{}
|
||||
cases := []string{
|
||||
"S3Object",
|
||||
"S3Object.id",
|
||||
"S3Object.book.title",
|
||||
"S3Object.id[1]",
|
||||
"S3Object.id['abc']",
|
||||
"S3Object.id['ab']",
|
||||
"S3Object.words.*.id",
|
||||
"S3Object.words.name[*].val",
|
||||
"S3Object.words.name[*].val[*]",
|
||||
"S3Object.words.name[*].val.*",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &j)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestIdentifierParsing(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Identifier{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
id := Identifier{}
|
||||
validCases := []string{
|
||||
"a",
|
||||
"_a",
|
||||
"abc_a",
|
||||
"a2",
|
||||
`"abc"`,
|
||||
`"abc\a""ac"`,
|
||||
}
|
||||
for i, tc := range validCases {
|
||||
err := p.ParseString(tc, &id)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(id, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
invalidCases := []string{
|
||||
"+a",
|
||||
"-a",
|
||||
"1a",
|
||||
`"ab`,
|
||||
`abc"`,
|
||||
`aa""a`,
|
||||
`"a"a"`,
|
||||
}
|
||||
for i, tc := range invalidCases {
|
||||
err := p.ParseString(tc, &id)
|
||||
if err == nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// fmt.Println(tc, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLiteralStringParsing(t *testing.T) {
|
||||
var k ObjectKey
|
||||
p := participle.MustBuild(
|
||||
&ObjectKey{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
validCases := []string{
|
||||
"['abc']",
|
||||
"['ab''c']",
|
||||
"['a''b''c']",
|
||||
"['abc-x_1##@(*&(#*))/\\']",
|
||||
}
|
||||
for i, tc := range validCases {
|
||||
err := p.ParseString(tc, &k)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
if string(*k.Lit) == "" {
|
||||
t.Fatalf("Incorrect parse %#v", k)
|
||||
}
|
||||
// repr.Println(k, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
invalidCases := []string{
|
||||
"['abc'']",
|
||||
"['-abc'sc']",
|
||||
"[abc']",
|
||||
"['ac]",
|
||||
}
|
||||
for i, tc := range invalidCases {
|
||||
err := p.ParseString(tc, &k)
|
||||
if err == nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// fmt.Println(tc, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionParsing(t *testing.T) {
|
||||
var fex FuncExpr
|
||||
p := participle.MustBuild(
|
||||
&FuncExpr{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
participle.CaseInsensitive("Timeword"),
|
||||
)
|
||||
|
||||
validCases := []string{
|
||||
"count(*)",
|
||||
"sum(2 + s.id)",
|
||||
"sum(t)",
|
||||
"avg(s.id[1])",
|
||||
"coalesce(s.id[1], 2, 2 + 3)",
|
||||
|
||||
"cast(s as string)",
|
||||
"cast(s AS INT)",
|
||||
"cast(s as DECIMAL)",
|
||||
"extract(YEAR from '2018-01-09')",
|
||||
"extract(month from '2018-01-09')",
|
||||
|
||||
"extract(hour from '2018-01-09')",
|
||||
"extract(day from '2018-01-09')",
|
||||
"substring('abcd' from 2 for 2)",
|
||||
"substring('abcd' from 2)",
|
||||
"substring('abcd' , 2 , 2)",
|
||||
|
||||
"substring('abcd' , 22 )",
|
||||
"trim(' aab ')",
|
||||
"trim(leading from ' aab ')",
|
||||
"trim(trailing from ' aab ')",
|
||||
"trim(both from ' aab ')",
|
||||
|
||||
"trim(both '12' from ' aab ')",
|
||||
"trim(leading '12' from ' aab ')",
|
||||
"trim(trailing '12' from ' aab ')",
|
||||
"count(23)",
|
||||
}
|
||||
for i, tc := range validCases {
|
||||
err := p.ParseString(tc, &fex)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(fex, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlLexer(t *testing.T) {
|
||||
// s := bytes.NewBuffer([]byte("s.['name'].*.[*].abc.[\"abc\"]"))
|
||||
s := bytes.NewBuffer([]byte("S3Object.words.*.id"))
|
||||
// s := bytes.NewBuffer([]byte("COUNT(Id)"))
|
||||
lex, err := sqlLexer.Lex(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tokens, err := lexer.ConsumeAll(lex)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// for i, t := range tokens {
|
||||
// fmt.Printf("%d: %#v\n", i, t)
|
||||
// }
|
||||
if len(tokens) != 7 {
|
||||
t.Fatalf("Expected 7 got %d", len(tokens))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWhere(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
"select * from s3object",
|
||||
"select a, b from s3object s",
|
||||
"select a, b from s3object as s",
|
||||
"select a, b from s3object as s where a = 1",
|
||||
"select a, b from s3object s where a = 1",
|
||||
"select a, b from s3object where a = 1",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLikeClause(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
`select * from s3object where Name like 'abcd'`,
|
||||
`select Name like 'abc' from s3object`,
|
||||
`select * from s3object where Name not like 'abc'`,
|
||||
`select * from s3object where Name like 'abc' escape 't'`,
|
||||
`select * from s3object where Name like 'a\%' escape '?'`,
|
||||
`select * from s3object where Name not like 'abc\' escape '?'`,
|
||||
`select * from s3object where Name like 'a\%' escape LOWER('?')`,
|
||||
`select * from s3object where Name not like LOWER('Bc\') escape '?'`,
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBetweenClause(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
`select * from s3object where Id between 1 and 2`,
|
||||
`select * from s3object where Id between 1 and 2 and name = 'Ab'`,
|
||||
`select * from s3object where Id not between 1 and 2`,
|
||||
`select * from s3object where Id not between 1 and 2 and name = 'Bc'`,
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromClauseJSONPath(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
"select * from s3object",
|
||||
"select * from s3object[*].name",
|
||||
"select * from s3object[*].books[*]",
|
||||
"select * from s3object[*].books[*].name",
|
||||
"select * from s3object where name > 2",
|
||||
"select * from s3object[*].name where name > 2",
|
||||
"select * from s3object[*].books[*] where name > 2",
|
||||
"select * from s3object[*].books[*].name where name > 2",
|
||||
"select * from s3object[*].books[*] s",
|
||||
"select * from s3object[*].books[*].name as s",
|
||||
"select * from s3object s where name > 2",
|
||||
"select * from s3object[*].name as s where name > 2",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSelectParsing(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
"select * from s3object where name > 2 or value > 1 or word > 2",
|
||||
"select s.word.id + 2 from s3object s",
|
||||
"select 1-2-3 from s3object s limit 1",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlLexerArithOps(t *testing.T) {
|
||||
s := bytes.NewBuffer([]byte("year from select month hour distinct"))
|
||||
lex, err := sqlLexer.Lex(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tokens, err := lexer.ConsumeAll(lex)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(tokens) != 7 {
|
||||
t.Errorf("Expected 7 got %d", len(tokens))
|
||||
}
|
||||
// for i, t := range tokens {
|
||||
// fmt.Printf("%d: %#v\n", i, t)
|
||||
// }
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
func getColumnName(colName *sqlparser.ColName) string {
|
||||
columnName := colName.Qualifier.Name.String()
|
||||
if qualifier := colName.Qualifier.Qualifier.String(); qualifier != "" {
|
||||
columnName = qualifier + "." + columnName
|
||||
}
|
||||
|
||||
if columnName == "" {
|
||||
columnName = colName.Name.String()
|
||||
} else {
|
||||
columnName = columnName + "." + colName.Name.String()
|
||||
}
|
||||
|
||||
return columnName
|
||||
}
|
||||
|
||||
func newLiteralExpr(parserExpr sqlparser.Expr, tableAlias string) (Expr, error) {
|
||||
switch parserExpr.(type) {
|
||||
case *sqlparser.NullVal:
|
||||
return newValueExpr(NewNull()), nil
|
||||
case sqlparser.BoolVal:
|
||||
return newValueExpr(NewBool((bool(parserExpr.(sqlparser.BoolVal))))), nil
|
||||
case *sqlparser.SQLVal:
|
||||
sqlValue := parserExpr.(*sqlparser.SQLVal)
|
||||
value, err := NewValue(sqlValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newValueExpr(value), nil
|
||||
case *sqlparser.ColName:
|
||||
columnName := getColumnName(parserExpr.(*sqlparser.ColName))
|
||||
if tableAlias != "" {
|
||||
if !strings.HasPrefix(columnName, tableAlias+".") {
|
||||
err := fmt.Errorf("column name %v does not start with table alias %v", columnName, tableAlias)
|
||||
return nil, errInvalidKeyPath(err)
|
||||
}
|
||||
columnName = strings.TrimPrefix(columnName, tableAlias+".")
|
||||
}
|
||||
|
||||
return newColumnExpr(columnName), nil
|
||||
case sqlparser.ValTuple:
|
||||
var valueType Type
|
||||
var values []*Value
|
||||
for i, valExpr := range parserExpr.(sqlparser.ValTuple) {
|
||||
sqlVal, ok := valExpr.(*sqlparser.SQLVal)
|
||||
if !ok {
|
||||
return nil, errParseInvalidTypeParam(fmt.Errorf("value %v in Tuple should be primitive value", i+1))
|
||||
}
|
||||
|
||||
val, err := NewValue(sqlVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
valueType = val.Type()
|
||||
} else if valueType != val.Type() {
|
||||
return nil, errParseInvalidTypeParam(fmt.Errorf("mixed value type is not allowed in Tuple"))
|
||||
}
|
||||
|
||||
values = append(values, val)
|
||||
}
|
||||
|
||||
return newValueExpr(NewArray(values)), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func isExprToComparisonExpr(parserExpr *sqlparser.IsExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func rangeCondToComparisonFunc(parserExpr *sqlparser.RangeCond, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fromExpr, err := newExpr(parserExpr.From, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toExpr, err := newExpr(parserExpr.To, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, fromExpr, toExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() || !fromExpr.Type().isBase() || !toExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toComparisonExpr(parserExpr *sqlparser.ComparisonExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toArithExpr(parserExpr *sqlparser.BinaryExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newArithExpr(ArithOperator(parserExpr.Operator), leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toFuncExpr(parserExpr *sqlparser.FuncExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
funcName := strings.ToUpper(parserExpr.Name.String())
|
||||
if !isSelectExpr && isAggregateFuncName(funcName) {
|
||||
return nil, errUnsupportedSQLOperation(fmt.Errorf("%v() must be used in select expression", funcName))
|
||||
}
|
||||
funcs, aggregatedExprFound, err := newSelectExprs(parserExpr.Exprs, tableAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if aggregatedExprFound {
|
||||
return nil, errIncorrectSQLFunctionArgumentType(fmt.Errorf("%v(): aggregated expression must not be used as argument", funcName))
|
||||
}
|
||||
|
||||
return newFuncExpr(FuncName(funcName), funcs...)
|
||||
}
|
||||
|
||||
func toAndExpr(parserExpr *sqlparser.AndExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newAndExpr(leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toOrExpr(parserExpr *sqlparser.OrExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newOrExpr(leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toNotExpr(parserExpr *sqlparser.NotExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
rightExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newNotExpr(rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rightExpr.Type() != Bool {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func newExpr(parserExpr sqlparser.Expr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
f, err := newLiteralExpr(parserExpr, tableAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f != nil {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
switch parserExpr.(type) {
|
||||
case *sqlparser.ParenExpr:
|
||||
return newExpr(parserExpr.(*sqlparser.ParenExpr).Expr, tableAlias, isSelectExpr)
|
||||
case *sqlparser.IsExpr:
|
||||
return isExprToComparisonExpr(parserExpr.(*sqlparser.IsExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.RangeCond:
|
||||
return rangeCondToComparisonFunc(parserExpr.(*sqlparser.RangeCond), tableAlias, isSelectExpr)
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return toComparisonExpr(parserExpr.(*sqlparser.ComparisonExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.BinaryExpr:
|
||||
return toArithExpr(parserExpr.(*sqlparser.BinaryExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.FuncExpr:
|
||||
return toFuncExpr(parserExpr.(*sqlparser.FuncExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.AndExpr:
|
||||
return toAndExpr(parserExpr.(*sqlparser.AndExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.OrExpr:
|
||||
return toOrExpr(parserExpr.(*sqlparser.OrExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.NotExpr:
|
||||
return toNotExpr(parserExpr.(*sqlparser.NotExpr), tableAlias, isSelectExpr)
|
||||
}
|
||||
|
||||
return nil, errParseUnsupportedSyntax(fmt.Errorf("unknown expression type %T; %v", parserExpr, parserExpr))
|
||||
}
|
||||
|
||||
func newSelectExprs(parserSelectExprs []sqlparser.SelectExpr, tableAlias string) ([]Expr, bool, error) {
|
||||
var funcs []Expr
|
||||
starExprFound := false
|
||||
aggregatedExprFound := false
|
||||
|
||||
for _, selectExpr := range parserSelectExprs {
|
||||
switch selectExpr.(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
if starExprFound {
|
||||
return nil, false, errParseAsteriskIsNotAloneInSelectList(nil)
|
||||
}
|
||||
|
||||
aliasedExpr := selectExpr.(*sqlparser.AliasedExpr)
|
||||
f, err := newExpr(aliasedExpr.Expr, tableAlias, true)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if f.Type() == aggregateFunction {
|
||||
if !aggregatedExprFound {
|
||||
aggregatedExprFound = true
|
||||
if len(funcs) > 0 {
|
||||
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
|
||||
}
|
||||
}
|
||||
} else if aggregatedExprFound {
|
||||
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
|
||||
}
|
||||
|
||||
alias := aliasedExpr.As.String()
|
||||
if alias != "" {
|
||||
f = newAliasExpr(alias, f)
|
||||
}
|
||||
|
||||
funcs = append(funcs, f)
|
||||
case *sqlparser.StarExpr:
|
||||
if starExprFound {
|
||||
err := fmt.Errorf("only single star expression allowed")
|
||||
return nil, false, errParseInvalidContextForWildcardInSelectList(err)
|
||||
}
|
||||
starExprFound = true
|
||||
funcs = append(funcs, newStarExpr())
|
||||
default:
|
||||
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("unknown select expression %v", selectExpr))
|
||||
}
|
||||
}
|
||||
|
||||
return funcs, aggregatedExprFound, nil
|
||||
}
|
||||
|
||||
// Select - SQL Select statement.
|
||||
type Select struct {
|
||||
tableName string
|
||||
tableAlias string
|
||||
selectExprs []Expr
|
||||
aggregatedExprFound bool
|
||||
whereExpr Expr
|
||||
}
|
||||
|
||||
// TableAlias - returns table alias name.
|
||||
func (statement *Select) TableAlias() string {
|
||||
return statement.tableAlias
|
||||
}
|
||||
|
||||
// IsSelectAll - returns whether '*' is used in select expression or not.
|
||||
func (statement *Select) IsSelectAll() bool {
|
||||
if len(statement.selectExprs) == 1 {
|
||||
_, ok := statement.selectExprs[0].(*starExpr)
|
||||
return ok
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAggregated - returns whether aggregated functions are used in select expression or not.
|
||||
func (statement *Select) IsAggregated() bool {
|
||||
return statement.aggregatedExprFound
|
||||
}
|
||||
|
||||
// AggregateResult - returns aggregate result as record.
|
||||
func (statement *Select) AggregateResult(output Record) error {
|
||||
if !statement.aggregatedExprFound {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, expr := range statement.selectExprs {
|
||||
value, err := expr.AggregateValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if value == nil {
|
||||
return errInternalError(fmt.Errorf("%v returns <nil> for AggregateValue()", expr))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("_%v", i+1)
|
||||
if _, ok := expr.(*aliasExpr); ok {
|
||||
name = expr.(*aliasExpr).alias
|
||||
}
|
||||
|
||||
if err = output.Set(name, value); err != nil {
|
||||
return errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Eval - evaluates this Select expressions for given record.
|
||||
func (statement *Select) Eval(input, output Record) (Record, error) {
|
||||
if statement.whereExpr != nil {
|
||||
value, err := statement.whereExpr.Eval(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value == nil || value.valueType != Bool {
|
||||
err = fmt.Errorf("WHERE expression %v returns invalid bool value %v", statement.whereExpr, value)
|
||||
return nil, errInternalError(err)
|
||||
}
|
||||
|
||||
if !value.BoolValue() {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Call selectExprs
|
||||
for i, expr := range statement.selectExprs {
|
||||
value, err := expr.Eval(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if statement.aggregatedExprFound {
|
||||
continue
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("_%v", i+1)
|
||||
switch expr.(type) {
|
||||
case *starExpr:
|
||||
return value.recordValue(), nil
|
||||
case *aliasExpr:
|
||||
name = expr.(*aliasExpr).alias
|
||||
case *columnExpr:
|
||||
name = expr.(*columnExpr).name
|
||||
}
|
||||
|
||||
if err = output.Set(name, value); err != nil {
|
||||
return nil, errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// NewSelect - creates new Select by parsing sql.
|
||||
func NewSelect(sql string) (*Select, error) {
|
||||
stmt, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return nil, errUnsupportedSQLStructure(err)
|
||||
}
|
||||
|
||||
selectStmt, ok := stmt.(*sqlparser.Select)
|
||||
if !ok {
|
||||
return nil, errParseUnsupportedSelect(fmt.Errorf("unsupported SQL statement %v", sql))
|
||||
}
|
||||
|
||||
var tableName, tableAlias string
|
||||
for _, fromExpr := range selectStmt.From {
|
||||
tableExpr := fromExpr.(*sqlparser.AliasedTableExpr)
|
||||
tableName = tableExpr.Expr.(sqlparser.TableName).Name.String()
|
||||
tableAlias = tableExpr.As.String()
|
||||
}
|
||||
|
||||
selectExprs, aggregatedExprFound, err := newSelectExprs(selectStmt.SelectExprs, tableAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var whereExpr Expr
|
||||
if selectStmt.Where != nil {
|
||||
whereExpr, err = newExpr(selectStmt.Where.Expr, tableAlias, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Select{
|
||||
tableName: tableName,
|
||||
tableAlias: tableAlias,
|
||||
selectExprs: selectExprs,
|
||||
aggregatedExprFound: aggregatedExprFound,
|
||||
whereExpr: whereExpr,
|
||||
}, nil
|
||||
}
|
||||
202
pkg/s3select/sql/statement.go
Normal file
202
pkg/s3select/sql/statement.go
Normal file
@@ -0,0 +1,202 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errBadLimitSpecified = errors.New("Limit value must be a positive integer")
|
||||
)
|
||||
|
||||
// SelectStatement is the top level parsed and analyzed structure
|
||||
type SelectStatement struct {
|
||||
selectAST *Select
|
||||
|
||||
// Analysis result of the statement
|
||||
selectQProp qProp
|
||||
|
||||
// Result of parsing the limit clause if one is present
|
||||
// (otherwise -1)
|
||||
limitValue int64
|
||||
|
||||
// Count of rows that have been output.
|
||||
outputCount int64
|
||||
}
|
||||
|
||||
// ParseSelectStatement - parses a select query from the given string
|
||||
// and analyzes it.
|
||||
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
|
||||
var selectAST Select
|
||||
err = SQLParser.ParseString(s, &selectAST)
|
||||
if err != nil {
|
||||
err = errQueryParseFailure(err)
|
||||
return
|
||||
}
|
||||
stmt.selectAST = &selectAST
|
||||
|
||||
// Check the parsed limit value
|
||||
stmt.limitValue, err = parseLimit(selectAST.Limit)
|
||||
if err != nil {
|
||||
err = errQueryAnalysisFailure(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze where clause
|
||||
if selectAST.Where != nil {
|
||||
whereQProp := selectAST.Where.analyze(&selectAST)
|
||||
if whereQProp.err != nil {
|
||||
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %v", whereQProp.err))
|
||||
return
|
||||
}
|
||||
|
||||
if whereQProp.isAggregation {
|
||||
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Validate table name
|
||||
tableString := strings.ToLower(selectAST.From.Table.String())
|
||||
if !strings.HasPrefix(tableString, "s3object.") && tableString != "s3object" {
|
||||
err = errBadTableName(errors.New("Table name must be s3object"))
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze main select expression
|
||||
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
|
||||
err = stmt.selectQProp.err
|
||||
if err != nil {
|
||||
fmt.Println("Got Analysis err:", err)
|
||||
err = errQueryAnalysisFailure(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseLimit(v *LitValue) (int64, error) {
|
||||
switch {
|
||||
case v == nil:
|
||||
return -1, nil
|
||||
case v.Number == nil:
|
||||
return -1, errBadLimitSpecified
|
||||
default:
|
||||
r := int64(*v.Number)
|
||||
if r < 0 {
|
||||
return -1, errBadLimitSpecified
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsAggregated returns if the statement involves SQL aggregation
|
||||
func (e *SelectStatement) IsAggregated() bool {
|
||||
return e.selectQProp.isAggregation
|
||||
}
|
||||
|
||||
// AggregateResult - returns the aggregated result after all input
|
||||
// records have been processed. Applies only to aggregation queries.
|
||||
func (e *SelectStatement) AggregateResult(output Record) error {
|
||||
for i, expr := range e.selectAST.Expression.Expressions {
|
||||
v, err := expr.evalNode(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
output.Set(fmt.Sprintf("_%d", i+1), v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AggregateRow - aggregates the input record. Applies only to
|
||||
// aggregation queries.
|
||||
func (e *SelectStatement) AggregateRow(input Record) error {
|
||||
for _, expr := range e.selectAST.Expression.Expressions {
|
||||
err := expr.aggregateRow(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Eval - evaluates the Select statement for the given record. It
|
||||
// applies only to non-aggregation queries.
|
||||
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
|
||||
if whereExpr := e.selectAST.Where; whereExpr != nil {
|
||||
value, err := whereExpr.evalNode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, ok := value.ToBool()
|
||||
if !ok {
|
||||
err = fmt.Errorf("WHERE expression did not return bool")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !b {
|
||||
// Where clause is not satisfied by the row
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
if e.selectAST.Expression.All {
|
||||
// Return the input record for `SELECT * FROM
|
||||
// .. WHERE ..`
|
||||
|
||||
// Update count of records output.
|
||||
if e.limitValue > -1 {
|
||||
e.outputCount++
|
||||
}
|
||||
|
||||
return input, nil
|
||||
}
|
||||
|
||||
for i, expr := range e.selectAST.Expression.Expressions {
|
||||
v, err := expr.evalNode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Pick output column names
|
||||
if expr.As != "" {
|
||||
output.Set(expr.As, v)
|
||||
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
|
||||
output.Set(comp, v)
|
||||
} else {
|
||||
output.Set(fmt.Sprintf("_%d", i+1), v)
|
||||
}
|
||||
}
|
||||
|
||||
// Update count of records output.
|
||||
if e.limitValue > -1 {
|
||||
e.outputCount++
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// LimitReached - returns true if the number of records output has
|
||||
// reached the value of the `LIMIT` clause.
|
||||
func (e *SelectStatement) LimitReached() bool {
|
||||
if e.limitValue == -1 {
|
||||
return false
|
||||
}
|
||||
return e.outputCount >= e.limitValue
|
||||
}
|
||||
188
pkg/s3select/sql/stringfuncs.go
Normal file
188
pkg/s3select/sql/stringfuncs.go
Normal file
@@ -0,0 +1,188 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errMalformedEscapeSequence = errors.New("Malformed escape sequence in LIKE clause")
|
||||
errInvalidTrimArg = errors.New("Trim argument is invalid - this should not happen")
|
||||
errInvalidSubstringIndexLen = errors.New("Substring start index or length falls outside the string")
|
||||
)
|
||||
|
||||
const (
|
||||
percent rune = '%'
|
||||
underscore rune = '_'
|
||||
runeZero rune = 0
|
||||
)
|
||||
|
||||
func evalSQLLike(text, pattern string, escape rune) (match bool, err error) {
|
||||
s := []rune{}
|
||||
prev := runeZero
|
||||
hasLeadingPercent := false
|
||||
patLen := len([]rune(pattern))
|
||||
for i, r := range pattern {
|
||||
if i > 0 && prev == escape {
|
||||
switch r {
|
||||
case percent, escape, underscore:
|
||||
s = append(s, r)
|
||||
prev = r
|
||||
if r == escape {
|
||||
prev = runeZero
|
||||
}
|
||||
default:
|
||||
return false, errMalformedEscapeSequence
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
prev = r
|
||||
|
||||
var ok bool
|
||||
switch r {
|
||||
case percent:
|
||||
if len(s) == 0 {
|
||||
hasLeadingPercent = true
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok = matcher(text, string(s), hasLeadingPercent)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
hasLeadingPercent = true
|
||||
s = []rune{}
|
||||
|
||||
if i == patLen-1 {
|
||||
// Last pattern character is a %, so
|
||||
// we are done.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
case underscore:
|
||||
if len(s) == 0 {
|
||||
text, ok = dropRune(text)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok = matcher(text, string(s), hasLeadingPercent)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
hasLeadingPercent = false
|
||||
|
||||
text, ok = dropRune(text)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
s = []rune{}
|
||||
|
||||
case escape:
|
||||
if i == patLen-1 {
|
||||
return false, errMalformedEscapeSequence
|
||||
}
|
||||
// Otherwise do nothing.
|
||||
|
||||
default:
|
||||
s = append(s, r)
|
||||
}
|
||||
|
||||
}
|
||||
if hasLeadingPercent {
|
||||
return strings.HasSuffix(text, string(s)), nil
|
||||
}
|
||||
return string(s) == text, nil
|
||||
}
|
||||
|
||||
// matcher - Finds `pat` in `text`, and returns the part remainder of
|
||||
// `text`, after the match. If leadingPercent is false, `pat` must be
|
||||
// the prefix of `text`, otherwise it must be a substring.
|
||||
func matcher(text, pat string, leadingPercent bool) (res string, match bool) {
|
||||
if !leadingPercent {
|
||||
res = strings.TrimPrefix(text, pat)
|
||||
if len(text) == len(res) {
|
||||
return "", false
|
||||
}
|
||||
} else {
|
||||
parts := strings.SplitN(text, pat, 2)
|
||||
if len(parts) == 1 {
|
||||
return "", false
|
||||
}
|
||||
res = parts[1]
|
||||
}
|
||||
return res, true
|
||||
}
|
||||
|
||||
func dropRune(text string) (res string, ok bool) {
|
||||
r := []rune(text)
|
||||
if len(r) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return string(r[1:]), true
|
||||
}
|
||||
|
||||
func evalSQLSubstring(s string, startIdx, length int) (res string, err error) {
|
||||
if startIdx <= 0 || startIdx > len(s) {
|
||||
return "", errInvalidSubstringIndexLen
|
||||
}
|
||||
// StartIdx is 1-based in the input
|
||||
startIdx--
|
||||
|
||||
rs := []rune(s)
|
||||
endIdx := len(rs)
|
||||
if length != -1 {
|
||||
if length < 0 || startIdx+length > len(s) {
|
||||
return "", errInvalidSubstringIndexLen
|
||||
}
|
||||
endIdx = startIdx + length
|
||||
}
|
||||
|
||||
return string(rs[startIdx:endIdx]), nil
|
||||
}
|
||||
|
||||
const (
|
||||
trimLeading = "LEADING"
|
||||
trimTrailing = "TRAILING"
|
||||
trimBoth = "BOTH"
|
||||
)
|
||||
|
||||
func evalSQLTrim(where *string, trimChars, text string) (result string, err error) {
|
||||
cutSet := " "
|
||||
if trimChars != "" {
|
||||
cutSet = trimChars
|
||||
}
|
||||
|
||||
trimFunc := strings.Trim
|
||||
switch {
|
||||
case where == nil:
|
||||
case *where == trimBoth:
|
||||
case *where == trimLeading:
|
||||
trimFunc = strings.TrimLeft
|
||||
case *where == trimTrailing:
|
||||
trimFunc = strings.TrimRight
|
||||
default:
|
||||
return "", errInvalidTrimArg
|
||||
}
|
||||
|
||||
return trimFunc(text, cutSet), nil
|
||||
}
|
||||
107
pkg/s3select/sql/stringfuncs_test.go
Normal file
107
pkg/s3select/sql/stringfuncs_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEvalSQLLike(t *testing.T) {
|
||||
dropCases := []struct {
|
||||
input, resultExpected string
|
||||
matchExpected bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"a", "", true},
|
||||
{"ab", "b", true},
|
||||
{"தமிழ்", "மிழ்", true},
|
||||
}
|
||||
|
||||
for i, tc := range dropCases {
|
||||
res, ok := dropRune(tc.input)
|
||||
if res != tc.resultExpected || ok != tc.matchExpected {
|
||||
t.Errorf("DropRune Case %d failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
matcherCases := []struct {
|
||||
iText, iPat string
|
||||
iHasLeadingPercent bool
|
||||
resultExpected string
|
||||
matchExpected bool
|
||||
}{
|
||||
{"abcd", "bcd", false, "", false},
|
||||
{"abcd", "bcd", true, "", true},
|
||||
{"abcd", "abcd", false, "", true},
|
||||
{"abcd", "abcd", true, "", true},
|
||||
{"abcd", "ab", false, "cd", true},
|
||||
{"abcd", "ab", true, "cd", true},
|
||||
{"abcd", "bc", false, "", false},
|
||||
{"abcd", "bc", true, "d", true},
|
||||
}
|
||||
|
||||
for i, tc := range matcherCases {
|
||||
res, ok := matcher(tc.iText, tc.iPat, tc.iHasLeadingPercent)
|
||||
if res != tc.resultExpected || ok != tc.matchExpected {
|
||||
t.Errorf("Matcher Case %d failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
evalCases := []struct {
|
||||
iText, iPat string
|
||||
iEsc rune
|
||||
matchExpected bool
|
||||
errExpected error
|
||||
}{
|
||||
{"abcd", "abc", runeZero, false, nil},
|
||||
{"abcd", "abcd", runeZero, true, nil},
|
||||
{"abcd", "abc_", runeZero, true, nil},
|
||||
{"abcd", "_bdd", runeZero, false, nil},
|
||||
{"abcd", "_b_d", runeZero, true, nil},
|
||||
|
||||
{"abcd", "____", runeZero, true, nil},
|
||||
{"abcd", "____%", runeZero, true, nil},
|
||||
{"abcd", "%____", runeZero, true, nil},
|
||||
{"abcd", "%__", runeZero, true, nil},
|
||||
{"abcd", "____", runeZero, true, nil},
|
||||
|
||||
{"", "_", runeZero, false, nil},
|
||||
{"", "%", runeZero, true, nil},
|
||||
{"abcd", "%%%%%", runeZero, true, nil},
|
||||
{"abcd", "_____", runeZero, false, nil},
|
||||
{"abcd", "%%%%%", runeZero, true, nil},
|
||||
|
||||
{"a%%d", `a\%\%d`, '\\', true, nil},
|
||||
{"a%%d", `a\%d`, '\\', false, nil},
|
||||
{`a%%\d`, `a\%\%\\d`, '\\', true, nil},
|
||||
{`a%%\`, `a\%\%\\`, '\\', true, nil},
|
||||
{`a%__%\`, `a\%\_\_\%\\`, '\\', true, nil},
|
||||
|
||||
{`a%__%\`, `a\%\_\_\%_`, '\\', true, nil},
|
||||
{`a%__%\`, `a\%\_\__`, '\\', false, nil},
|
||||
{`a%__%\`, `a\%\_\_%`, '\\', true, nil},
|
||||
{`a%__%\`, `a?%?_?_?%\`, '?', true, nil},
|
||||
}
|
||||
|
||||
for i, tc := range evalCases {
|
||||
// fmt.Println("Case:", i)
|
||||
res, err := evalSQLLike(tc.iText, tc.iPat, tc.iEsc)
|
||||
if res != tc.matchExpected || err != tc.errExpected {
|
||||
t.Errorf("Eval Case %d failed: %v %v", i, res, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
// Type - value type.
|
||||
type Type string
|
||||
|
||||
const (
|
||||
// Null - represents NULL value type.
|
||||
Null Type = "null"
|
||||
|
||||
// Bool - represents boolean value type.
|
||||
Bool Type = "bool"
|
||||
|
||||
// Int - represents integer value type.
|
||||
Int Type = "int"
|
||||
|
||||
// Float - represents floating point value type.
|
||||
Float Type = "float"
|
||||
|
||||
// String - represents string value type.
|
||||
String Type = "string"
|
||||
|
||||
// Timestamp - represents time value type.
|
||||
Timestamp Type = "timestamp"
|
||||
|
||||
// Array - represents array of values where each value type is one of above.
|
||||
Array Type = "array"
|
||||
|
||||
column Type = "column"
|
||||
record Type = "record"
|
||||
function Type = "function"
|
||||
aggregateFunction Type = "aggregatefunction"
|
||||
arithmeticFunction Type = "arithmeticfunction"
|
||||
comparisonFunction Type = "comparisonfunction"
|
||||
logicalFunction Type = "logicalfunction"
|
||||
|
||||
// Integer Type = "integer" // Same as Int
|
||||
// Decimal Type = "decimal" // Same as Float
|
||||
// Numeric Type = "numeric" // Same as Float
|
||||
)
|
||||
|
||||
func (t Type) isBase() bool {
|
||||
switch t {
|
||||
case Null, Bool, Int, Float, String, Timestamp:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isBaseKind() bool {
|
||||
switch t {
|
||||
case Null, Bool, Int, Float, String, Timestamp, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isNumber() bool {
|
||||
switch t {
|
||||
case Int, Float:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isNumberKind() bool {
|
||||
switch t {
|
||||
case Int, Float, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isIntKind() bool {
|
||||
switch t {
|
||||
case Int, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isBoolKind() bool {
|
||||
switch t {
|
||||
case Bool, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isStringKind() bool {
|
||||
switch t {
|
||||
case String, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
87
pkg/s3select/sql/utils.go
Normal file
87
pkg/s3select/sql/utils.go
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// String functions
|
||||
|
||||
// String - returns the JSONPath representation
|
||||
func (e *JSONPath) String() string {
|
||||
parts := make([]string, len(e.PathExpr)+1)
|
||||
parts[0] = e.BaseKey.String()
|
||||
for i, pe := range e.PathExpr {
|
||||
parts[i+1] = pe.String()
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
func (e *JSONPathElement) String() string {
|
||||
switch {
|
||||
case e.Key != nil:
|
||||
return e.Key.String()
|
||||
case e.Index != nil:
|
||||
return fmt.Sprintf("[%d]", *e.Index)
|
||||
case e.ObjectWildcard:
|
||||
return ".*"
|
||||
case e.ArrayWildcard:
|
||||
return "[*]"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Removes double quotes in quoted identifiers
|
||||
func (i *Identifier) String() string {
|
||||
if i.Unquoted != nil {
|
||||
return *i.Unquoted
|
||||
}
|
||||
return string(*i.Quoted)
|
||||
}
|
||||
|
||||
func (o *ObjectKey) String() string {
|
||||
if o.Lit != nil {
|
||||
return fmt.Sprintf("['%s']", string(*o.Lit))
|
||||
}
|
||||
return fmt.Sprintf(".%s", o.ID.String())
|
||||
}
|
||||
|
||||
// getLastKeypathComponent checks if the given expression is a path
|
||||
// expression, and if so extracts the last dot separated component of
|
||||
// the path. Otherwise it returns false.
|
||||
func getLastKeypathComponent(e *Expression) (string, bool) {
|
||||
if len(e.And) > 1 ||
|
||||
len(e.And[0].Condition) > 1 ||
|
||||
e.And[0].Condition[0].Not != nil ||
|
||||
e.And[0].Condition[0].Operand.ConditionRHS != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
operand := e.And[0].Condition[0].Operand.Operand
|
||||
if operand.Right != nil ||
|
||||
operand.Left.Right != nil ||
|
||||
operand.Left.Left.Negated != nil ||
|
||||
operand.Left.Left.Primary.JPathExpr == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
keypath := operand.Left.Left.Primary.JPathExpr.String()
|
||||
ps := strings.Split(keypath, ".")
|
||||
return ps[len(ps)-1], true
|
||||
}
|
||||
@@ -17,220 +17,710 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
// Value - represents any primitive value of bool, int, float, string and time.
|
||||
var (
|
||||
errArithMismatchedTypes = errors.New("cannot perform arithmetic on mismatched types")
|
||||
errArithInvalidOperator = errors.New("invalid arithmetic operator")
|
||||
errArithDivideByZero = errors.New("cannot divide by 0")
|
||||
|
||||
errCmpMismatchedTypes = errors.New("cannot compare values of different types")
|
||||
errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments")
|
||||
)
|
||||
|
||||
// vType represents the concrete type of a `Value`
|
||||
type vType int
|
||||
|
||||
// Valid values for Type
|
||||
const (
|
||||
typeNull vType = iota + 1
|
||||
typeBool
|
||||
typeString
|
||||
|
||||
// 64-bit signed integer
|
||||
typeInt
|
||||
|
||||
// 64-bit floating point
|
||||
typeFloat
|
||||
|
||||
// This type refers to untyped values, e.g. as read from CSV
|
||||
typeBytes
|
||||
)
|
||||
|
||||
// Value represents a value of restricted type reduced from an
|
||||
// expression represented by an ASTNode. Only one of the fields is
|
||||
// non-nil.
|
||||
//
|
||||
// In cases where we are fetching data from a data source (like csv),
|
||||
// the type may not be determined yet. In these cases, a byte-slice is
|
||||
// used.
|
||||
type Value struct {
|
||||
value interface{}
|
||||
valueType Type
|
||||
value interface{}
|
||||
vType vType
|
||||
}
|
||||
|
||||
// String - represents value as string.
|
||||
func (value *Value) String() string {
|
||||
if value.value == nil {
|
||||
if value.valueType == Null {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
return "<nil>"
|
||||
// GetTypeString returns a string representation for vType
|
||||
func (v *Value) GetTypeString() string {
|
||||
switch v.vType {
|
||||
case typeNull:
|
||||
return "NULL"
|
||||
case typeBool:
|
||||
return "BOOL"
|
||||
case typeString:
|
||||
return "STRING"
|
||||
case typeInt:
|
||||
return "INT"
|
||||
case typeFloat:
|
||||
return "FLOAT"
|
||||
case typeBytes:
|
||||
return "BYTES"
|
||||
}
|
||||
|
||||
switch value.valueType {
|
||||
case String:
|
||||
return fmt.Sprintf("'%v'", value.value)
|
||||
case Array:
|
||||
var valueStrings []string
|
||||
for _, v := range value.value.([]*Value) {
|
||||
valueStrings = append(valueStrings, fmt.Sprintf("%v", v))
|
||||
}
|
||||
return fmt.Sprintf("(%v)", strings.Join(valueStrings, ","))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value.value)
|
||||
return "--"
|
||||
}
|
||||
|
||||
// CSVString - encodes to CSV string.
|
||||
func (value *Value) CSVString() string {
|
||||
if value.valueType == Null {
|
||||
// Repr returns a string representation of value.
|
||||
func (v *Value) Repr() string {
|
||||
switch v.vType {
|
||||
case typeNull:
|
||||
return ":NULL"
|
||||
case typeBool, typeInt, typeFloat:
|
||||
return fmt.Sprintf("%v:%s", v.value, v.GetTypeString())
|
||||
case typeString:
|
||||
return fmt.Sprintf("\"%s\":%s", v.value.(string), v.GetTypeString())
|
||||
case typeBytes:
|
||||
return fmt.Sprintf("\"%s\":BYTES", string(v.value.([]byte)))
|
||||
default:
|
||||
return fmt.Sprintf("%v:INVALID", v.value)
|
||||
}
|
||||
}
|
||||
|
||||
// FromFloat creates a Value from a number
|
||||
func FromFloat(f float64) *Value {
|
||||
return &Value{value: f, vType: typeFloat}
|
||||
}
|
||||
|
||||
// FromInt creates a Value from an int
|
||||
func FromInt(f int64) *Value {
|
||||
return &Value{value: f, vType: typeInt}
|
||||
}
|
||||
|
||||
// FromString creates a Value from a string
|
||||
func FromString(str string) *Value {
|
||||
return &Value{value: str, vType: typeString}
|
||||
}
|
||||
|
||||
// FromBool creates a Value from a bool
|
||||
func FromBool(b bool) *Value {
|
||||
return &Value{value: b, vType: typeBool}
|
||||
}
|
||||
|
||||
// FromNull creates a Value with Null value
|
||||
func FromNull() *Value {
|
||||
return &Value{vType: typeNull}
|
||||
}
|
||||
|
||||
// FromBytes creates a Value from a []byte
|
||||
func FromBytes(b []byte) *Value {
|
||||
return &Value{value: b, vType: typeBytes}
|
||||
}
|
||||
|
||||
// ToFloat works for int and float values
|
||||
func (v *Value) ToFloat() (val float64, ok bool) {
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
val, ok = v.value.(float64)
|
||||
case typeInt:
|
||||
var i int64
|
||||
i, ok = v.value.(int64)
|
||||
val = float64(i)
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToInt converts value to int.
|
||||
func (v *Value) ToInt() (val int64, ok bool) {
|
||||
switch v.vType {
|
||||
case typeInt:
|
||||
val, ok = v.value.(int64)
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToString converts value to string.
|
||||
func (v *Value) ToString() (val string, ok bool) {
|
||||
switch v.vType {
|
||||
case typeString:
|
||||
val, ok = v.value.(string)
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToBool returns the bool value; second return value refers to if the bool
|
||||
// conversion succeeded.
|
||||
func (v *Value) ToBool() (val bool, ok bool) {
|
||||
switch v.vType {
|
||||
case typeBool:
|
||||
return v.value.(bool), true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// ToBytes converts Value to byte-slice.
|
||||
func (v *Value) ToBytes() ([]byte, bool) {
|
||||
switch v.vType {
|
||||
case typeBytes:
|
||||
return v.value.([]byte), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IsNull - checks if value is missing.
|
||||
func (v *Value) IsNull() bool {
|
||||
return v.vType == typeNull
|
||||
}
|
||||
|
||||
func (v *Value) isNumeric() bool {
|
||||
return v.vType == typeInt || v.vType == typeFloat
|
||||
}
|
||||
|
||||
// setters used internally to mutate values
|
||||
|
||||
func (v *Value) setInt(i int64) {
|
||||
v.vType = typeInt
|
||||
v.value = i
|
||||
}
|
||||
|
||||
func (v *Value) setFloat(f float64) {
|
||||
v.vType = typeFloat
|
||||
v.value = f
|
||||
}
|
||||
|
||||
func (v *Value) setString(s string) {
|
||||
v.vType = typeString
|
||||
v.value = s
|
||||
}
|
||||
|
||||
func (v *Value) setBool(b bool) {
|
||||
v.vType = typeBool
|
||||
v.value = b
|
||||
}
|
||||
|
||||
// CSVString - convert to string for CSV serialization
|
||||
func (v *Value) CSVString() string {
|
||||
switch v.vType {
|
||||
case typeNull:
|
||||
return ""
|
||||
case typeBool:
|
||||
return fmt.Sprintf("%v", v.value.(bool))
|
||||
case typeString:
|
||||
return fmt.Sprintf("%s", v.value.(string))
|
||||
case typeInt:
|
||||
return fmt.Sprintf("%v", v.value.(int64))
|
||||
case typeFloat:
|
||||
return fmt.Sprintf("%v", v.value.(float64))
|
||||
case typeBytes:
|
||||
return fmt.Sprintf("%v", string(v.value.([]byte)))
|
||||
default:
|
||||
return "CSV serialization not implemented for this type"
|
||||
}
|
||||
}
|
||||
|
||||
// floatToValue converts a float into int representation if needed.
|
||||
func floatToValue(f float64) *Value {
|
||||
intPart, fracPart := math.Modf(f)
|
||||
if fracPart == 0 {
|
||||
return FromInt(int64(intPart))
|
||||
}
|
||||
return FromFloat(f)
|
||||
}
|
||||
|
||||
// Value comparison functions: we do not expose them outside the
|
||||
// module. Logical operators "<", ">", ">=", "<=" work on strings and
|
||||
// numbers. Equality operators "=", "!=" work on strings,
|
||||
// numbers and booleans.
|
||||
|
||||
// Supported comparison operators
|
||||
const (
|
||||
opLt = "<"
|
||||
opLte = "<="
|
||||
opGt = ">"
|
||||
opGte = ">="
|
||||
opEq = "="
|
||||
opIneq = "!="
|
||||
)
|
||||
|
||||
// When numeric types are compared, type promotions could happen. If
|
||||
// values do not have types (e.g. when reading from CSV), for
|
||||
// comparison operations, automatic type conversion happens by trying
|
||||
// to check if the value is a number (first an integer, then a float),
|
||||
// and falling back to string.
|
||||
func (v *Value) compareOp(op string, a *Value) (res bool, err error) {
|
||||
if !isValidComparisonOperator(op) {
|
||||
return false, errArithInvalidOperator
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value.value)
|
||||
// Check if type conversion/inference is needed - it is needed
|
||||
// if the Value is a byte-slice.
|
||||
err = inferTypesForCmp(v, a)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
isNumeric := v.isNumeric() && a.isNumeric()
|
||||
if isNumeric {
|
||||
intV, ok1i := v.ToInt()
|
||||
intA, ok2i := a.ToInt()
|
||||
if ok1i && ok2i {
|
||||
return intCompare(op, intV, intA), nil
|
||||
}
|
||||
|
||||
// If both values are numeric, then at least one is
|
||||
// float since we got here, so we convert.
|
||||
flV, _ := v.ToFloat()
|
||||
flA, _ := a.ToFloat()
|
||||
return floatCompare(op, flV, flA), nil
|
||||
}
|
||||
|
||||
strV, ok1s := v.ToString()
|
||||
strA, ok2s := a.ToString()
|
||||
if ok1s && ok2s {
|
||||
return stringCompare(op, strV, strA), nil
|
||||
}
|
||||
|
||||
boolV, ok1b := v.ToBool()
|
||||
boolA, ok2b := v.ToBool()
|
||||
if ok1b && ok2b {
|
||||
return boolCompare(op, boolV, boolA)
|
||||
}
|
||||
|
||||
return false, errCmpMismatchedTypes
|
||||
}
|
||||
|
||||
// MarshalJSON - encodes to JSON data.
|
||||
func (value *Value) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(value.value)
|
||||
func inferTypesForCmp(a *Value, b *Value) error {
|
||||
_, okA := a.ToBytes()
|
||||
_, okB := b.ToBytes()
|
||||
switch {
|
||||
case !okA && !okB:
|
||||
// Both Values already have types
|
||||
return nil
|
||||
|
||||
case okA && okB:
|
||||
// Both Values are untyped so try the types in order:
|
||||
// int, float, bool, string
|
||||
|
||||
// Check for numeric inference
|
||||
iA, okAi := a.bytesToInt()
|
||||
iB, okBi := b.bytesToInt()
|
||||
if okAi && okBi {
|
||||
a.setInt(iA)
|
||||
b.setInt(iB)
|
||||
return nil
|
||||
}
|
||||
|
||||
fA, okAf := a.bytesToFloat()
|
||||
fB, okBf := b.bytesToFloat()
|
||||
if okAf && okBf {
|
||||
a.setFloat(fA)
|
||||
b.setFloat(fB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if they int and float combination.
|
||||
if okAi && okBf {
|
||||
a.setInt(iA)
|
||||
b.setFloat(fA)
|
||||
return nil
|
||||
}
|
||||
if okBi && okAf {
|
||||
a.setFloat(fA)
|
||||
b.setInt(iB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not numeric types at this point.
|
||||
|
||||
// Check for bool inference
|
||||
bA, okAb := a.bytesToBool()
|
||||
bB, okBb := b.bytesToBool()
|
||||
if okAb && okBb {
|
||||
a.setBool(bA)
|
||||
b.setBool(bB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback to string
|
||||
sA := a.bytesToString()
|
||||
sB := b.bytesToString()
|
||||
a.setString(sA)
|
||||
b.setString(sB)
|
||||
return nil
|
||||
|
||||
case okA && !okB:
|
||||
// Here a has `a` is untyped, but `b` has a fixed
|
||||
// type.
|
||||
switch b.vType {
|
||||
case typeString:
|
||||
s := a.bytesToString()
|
||||
a.setString(s)
|
||||
|
||||
case typeInt, typeFloat:
|
||||
if iA, ok := a.bytesToInt(); ok {
|
||||
a.setInt(iA)
|
||||
} else if fA, ok := a.bytesToFloat(); ok {
|
||||
a.setFloat(fA)
|
||||
} else {
|
||||
return fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
|
||||
}
|
||||
|
||||
case typeBool:
|
||||
if bA, ok := a.bytesToBool(); ok {
|
||||
a.setBool(bA)
|
||||
} else {
|
||||
return fmt.Errorf("Could not convert %s to a boolean", string(a.value.([]byte)))
|
||||
}
|
||||
|
||||
default:
|
||||
return errCmpMismatchedTypes
|
||||
}
|
||||
return nil
|
||||
|
||||
case !okA && okB:
|
||||
// swap arguments to avoid repeating code
|
||||
return inferTypesForCmp(b, a)
|
||||
|
||||
default:
|
||||
// Does not happen
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NullValue - returns underlying null value. It panics if value is not null type.
|
||||
func (value *Value) NullValue() *struct{} {
|
||||
if value.valueType == Null {
|
||||
// Value arithmetic functions: we do not expose them outside the
|
||||
// module. All arithmetic works only on numeric values with automatic
|
||||
// promotion to the "larger" type that can represent the value. TODO:
|
||||
// Add support for large number arithmetic.
|
||||
|
||||
// Supported arithmetic operators
|
||||
const (
|
||||
opPlus = "+"
|
||||
opMinus = "-"
|
||||
opDivide = "/"
|
||||
opMultiply = "*"
|
||||
opModulo = "%"
|
||||
)
|
||||
|
||||
// For arithmetic operations, if both values are numeric then the
|
||||
// operation shall succeed. If the types are unknown automatic type
|
||||
// conversion to a number is attempted.
|
||||
func (v *Value) arithOp(op string, a *Value) error {
|
||||
err := inferTypeForArithOp(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = inferTypeForArithOp(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !v.isNumeric() || !a.isNumeric() {
|
||||
return errInvalidDataType(errArithMismatchedTypes)
|
||||
}
|
||||
|
||||
if !isValidArithOperator(op) {
|
||||
return errInvalidDataType(errArithMismatchedTypes)
|
||||
}
|
||||
|
||||
intV, ok1i := v.ToInt()
|
||||
intA, ok2i := a.ToInt()
|
||||
switch {
|
||||
case ok1i && ok2i:
|
||||
res, err := intArithOp(op, intV, intA)
|
||||
v.setInt(res)
|
||||
return err
|
||||
|
||||
default:
|
||||
// Convert arguments to float
|
||||
flV, _ := v.ToFloat()
|
||||
flA, _ := a.ToFloat()
|
||||
res, err := floatArithOp(op, flV, flA)
|
||||
v.setFloat(res)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func inferTypeForArithOp(a *Value) error {
|
||||
if _, ok := a.ToBytes(); !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// BoolValue - returns underlying bool value. It panics if value is not Bool type.
|
||||
func (value *Value) BoolValue() bool {
|
||||
if value.valueType == Bool {
|
||||
return value.value.(bool)
|
||||
if i, ok := a.bytesToInt(); ok {
|
||||
a.setInt(i)
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// IntValue - returns underlying int value. It panics if value is not Int type.
|
||||
func (value *Value) IntValue() int64 {
|
||||
if value.valueType == Int {
|
||||
return value.value.(int64)
|
||||
if f, ok := a.bytesToFloat(); ok {
|
||||
a.setFloat(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested int value but found %T type", value.value))
|
||||
err := fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
|
||||
return errInvalidDataType(err)
|
||||
}
|
||||
|
||||
// FloatValue - returns underlying int/float value as float64. It panics if value is not Int/Float type.
|
||||
func (value *Value) FloatValue() float64 {
|
||||
switch value.valueType {
|
||||
case Int:
|
||||
return float64(value.value.(int64))
|
||||
case Float:
|
||||
return value.value.(float64)
|
||||
// All the bytesTo* functions defined below assume the value is a byte-slice.
|
||||
|
||||
// Converts untyped value into int. The bool return implies success -
|
||||
// it returns false only if there is a conversion failure.
|
||||
func (v *Value) bytesToInt() (int64, bool) {
|
||||
bytes, _ := v.ToBytes()
|
||||
i, err := strconv.ParseInt(string(bytes), 10, 64)
|
||||
return i, err == nil
|
||||
}
|
||||
|
||||
// Converts untyped value into float. The bool return implies success
|
||||
// - it returns false only if there is a conversion failure.
|
||||
func (v *Value) bytesToFloat() (float64, bool) {
|
||||
bytes, _ := v.ToBytes()
|
||||
i, err := strconv.ParseFloat(string(bytes), 64)
|
||||
return i, err == nil
|
||||
}
|
||||
|
||||
// Converts untyped value into bool. The second bool return implies
|
||||
// success - it returns false in case of a conversion failure.
|
||||
func (v *Value) bytesToBool() (val bool, ok bool) {
|
||||
bytes, _ := v.ToBytes()
|
||||
ok = true
|
||||
switch strings.ToLower(string(bytes)) {
|
||||
case "t", "true":
|
||||
val = true
|
||||
case "f", "false":
|
||||
val = false
|
||||
default:
|
||||
ok = false
|
||||
}
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// bytesToString - never fails
|
||||
func (v *Value) bytesToString() string {
|
||||
bytes, _ := v.ToBytes()
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// Calculates minimum or maximum of v and a and assigns the result to
|
||||
// v - it works only on numeric arguments, where `v` is already
|
||||
// assumed to be numeric. Attempts conversion to numeric type for `a`
|
||||
// (first int, then float) only if the underlying values do not have a
|
||||
// type.
|
||||
func (v *Value) minmax(a *Value, isMax, isFirstRow bool) error {
|
||||
err := inferTypeForArithOp(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested float value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// StringValue - returns underlying string value. It panics if value is not String type.
|
||||
func (value *Value) StringValue() string {
|
||||
if value.valueType == String {
|
||||
return value.value.(string)
|
||||
if !a.isNumeric() {
|
||||
return errArithMismatchedTypes
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested string value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// TimeValue - returns underlying time value. It panics if value is not Timestamp type.
|
||||
func (value *Value) TimeValue() time.Time {
|
||||
if value.valueType == Timestamp {
|
||||
return value.value.(time.Time)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested time value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// ArrayValue - returns underlying value array. It panics if value is not Array type.
|
||||
func (value *Value) ArrayValue() []*Value {
|
||||
if value.valueType == Array {
|
||||
return value.value.([]*Value)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested array value but found %T type", value.value))
|
||||
}
|
||||
|
||||
func (value *Value) recordValue() Record {
|
||||
if value.valueType == record {
|
||||
return value.value.(Record)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested record value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// Type - returns value type.
|
||||
func (value *Value) Type() Type {
|
||||
return value.valueType
|
||||
}
|
||||
|
||||
// Value - returns underneath value interface.
|
||||
func (value *Value) Value() interface{} {
|
||||
return value.value
|
||||
}
|
||||
|
||||
// NewNull - creates new null value.
|
||||
func NewNull() *Value {
|
||||
return &Value{nil, Null}
|
||||
}
|
||||
|
||||
// NewBool - creates new Bool value of b.
|
||||
func NewBool(b bool) *Value {
|
||||
return &Value{b, Bool}
|
||||
}
|
||||
|
||||
// NewInt - creates new Int value of i.
|
||||
func NewInt(i int64) *Value {
|
||||
return &Value{i, Int}
|
||||
}
|
||||
|
||||
// NewFloat - creates new Float value of f.
|
||||
func NewFloat(f float64) *Value {
|
||||
return &Value{f, Float}
|
||||
}
|
||||
|
||||
// NewString - creates new Sring value of s.
|
||||
func NewString(s string) *Value {
|
||||
return &Value{s, String}
|
||||
}
|
||||
|
||||
// NewTime - creates new Time value of t.
|
||||
func NewTime(t time.Time) *Value {
|
||||
return &Value{t, Timestamp}
|
||||
}
|
||||
|
||||
// NewArray - creates new Array value of values.
|
||||
func NewArray(values []*Value) *Value {
|
||||
return &Value{values, Array}
|
||||
}
|
||||
|
||||
func newRecordValue(r Record) *Value {
|
||||
return &Value{r, record}
|
||||
}
|
||||
|
||||
// NewValue - creates new Value from SQLVal v.
|
||||
func NewValue(v *sqlparser.SQLVal) (*Value, error) {
|
||||
switch v.Type {
|
||||
case sqlparser.StrVal:
|
||||
return NewString(string(v.Val)), nil
|
||||
case sqlparser.IntVal:
|
||||
i64, err := strconv.ParseInt(string(v.Val), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// In case of first row, set v to a.
|
||||
if isFirstRow {
|
||||
intA, okI := a.ToInt()
|
||||
if okI {
|
||||
v.setInt(intA)
|
||||
return nil
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.FloatVal:
|
||||
f64, err := strconv.ParseFloat(string(v.Val), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewFloat(f64), nil
|
||||
case sqlparser.HexNum: // represented as 0xDD
|
||||
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.HexVal: // represented as X'0DD'
|
||||
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.BitVal: // represented as B'00'
|
||||
i64, err := strconv.ParseInt(string(v.Val), 2, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.ValArg:
|
||||
// FIXME: the format is unknown and not sure how to handle it.
|
||||
floatA, _ := a.ToFloat()
|
||||
v.setFloat(floatA)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown SQL value %v; %v ", v, v.Type)
|
||||
intV, ok1i := v.ToInt()
|
||||
intA, ok2i := a.ToInt()
|
||||
if ok1i && ok2i {
|
||||
result := intV
|
||||
if !isMax {
|
||||
if intA < result {
|
||||
result = intA
|
||||
}
|
||||
} else {
|
||||
if intA > result {
|
||||
result = intA
|
||||
}
|
||||
}
|
||||
v.setInt(result)
|
||||
return nil
|
||||
}
|
||||
|
||||
floatV, _ := v.ToFloat()
|
||||
floatA, _ := a.ToFloat()
|
||||
var result float64
|
||||
if !isMax {
|
||||
result = math.Min(floatV, floatA)
|
||||
} else {
|
||||
result = math.Max(floatV, floatA)
|
||||
}
|
||||
v.setFloat(result)
|
||||
return nil
|
||||
}
|
||||
|
||||
// inferTypeAsString is used to convert untyped values to string - it
|
||||
// is called when the caller requires a string context to proceed.
|
||||
func inferTypeAsString(v *Value) {
|
||||
b, ok := v.ToBytes()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
v.setString(string(b))
|
||||
}
|
||||
|
||||
func isValidComparisonOperator(op string) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
case opLte:
|
||||
case opGt:
|
||||
case opGte:
|
||||
case opEq:
|
||||
case opIneq:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func intCompare(op string, left, right int64) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
return left < right
|
||||
case opLte:
|
||||
return left <= right
|
||||
case opGt:
|
||||
return left > right
|
||||
case opGte:
|
||||
return left >= right
|
||||
case opEq:
|
||||
return left == right
|
||||
case opIneq:
|
||||
return left != right
|
||||
}
|
||||
// This case does not happen
|
||||
return false
|
||||
}
|
||||
|
||||
func floatCompare(op string, left, right float64) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
return left < right
|
||||
case opLte:
|
||||
return left <= right
|
||||
case opGt:
|
||||
return left > right
|
||||
case opGte:
|
||||
return left >= right
|
||||
case opEq:
|
||||
return left == right
|
||||
case opIneq:
|
||||
return left != right
|
||||
}
|
||||
// This case does not happen
|
||||
return false
|
||||
}
|
||||
|
||||
func stringCompare(op string, left, right string) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
return left < right
|
||||
case opLte:
|
||||
return left <= right
|
||||
case opGt:
|
||||
return left > right
|
||||
case opGte:
|
||||
return left >= right
|
||||
case opEq:
|
||||
return left == right
|
||||
case opIneq:
|
||||
return left != right
|
||||
}
|
||||
// This case does not happen
|
||||
return false
|
||||
}
|
||||
|
||||
func boolCompare(op string, left, right bool) (bool, error) {
|
||||
switch op {
|
||||
case opEq:
|
||||
return left == right, nil
|
||||
case opIneq:
|
||||
return left != right, nil
|
||||
default:
|
||||
return false, errCmpInvalidBoolOperator
|
||||
}
|
||||
}
|
||||
|
||||
func isValidArithOperator(op string) bool {
|
||||
switch op {
|
||||
case opPlus:
|
||||
case opMinus:
|
||||
case opDivide:
|
||||
case opMultiply:
|
||||
case opModulo:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Overflow errors are ignored.
|
||||
func intArithOp(op string, left, right int64) (int64, error) {
|
||||
switch op {
|
||||
case opPlus:
|
||||
return left + right, nil
|
||||
case opMinus:
|
||||
return left - right, nil
|
||||
case opDivide:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return left / right, nil
|
||||
case opMultiply:
|
||||
return left * right, nil
|
||||
case opModulo:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return left % right, nil
|
||||
}
|
||||
// This does not happen
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Overflow errors are ignored.
|
||||
func floatArithOp(op string, left, right float64) (float64, error) {
|
||||
switch op {
|
||||
case opPlus:
|
||||
return left + right, nil
|
||||
case opMinus:
|
||||
return left - right, nil
|
||||
case opDivide:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return left / right, nil
|
||||
case opMultiply:
|
||||
return left * right, nil
|
||||
case opModulo:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return math.Mod(left, right), nil
|
||||
}
|
||||
// This does not happen
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user