Fix S3Select SQL column reference handling (#11957)

This change fixes handling of these types of queries:

- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
This commit is contained in:
Aditya Manthramurthy
2021-04-06 08:49:04 -07:00
committed by GitHub
parent d5d2fc9850
commit b2936243f9
8 changed files with 297 additions and 123 deletions

View File

@@ -21,7 +21,6 @@ import (
"errors"
"fmt"
"math"
"strings"
"github.com/bcicen/jstream"
"github.com/minio/simdjson-go"
@@ -47,21 +46,21 @@ var (
// of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
return e.Expression.evalNode(r)
func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) {
return e.Expression.evalNode(r, tableAlias)
}
func (e *Expression) evalNode(r Record) (*Value, error) {
func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.And) == 1 {
// In this case, result is not required to be boolean
// type.
return e.And[0].evalNode(r)
return e.And[0].evalNode(r, tableAlias)
}
// Compute OR of conditions
result := false
for _, ex := range e.And {
res, err := ex.evalNode(r)
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@@ -74,16 +73,16 @@ func (e *Expression) evalNode(r Record) (*Value, error) {
return FromBool(result), nil
}
func (e *AndCondition) evalNode(r Record) (*Value, error) {
func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.Condition) == 1 {
// In this case, result does not have to be boolean
return e.Condition[0].evalNode(r)
return e.Condition[0].evalNode(r, tableAlias)
}
// Compute AND of conditions
result := true
for _, ex := range e.Condition {
res, err := ex.evalNode(r)
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@@ -96,14 +95,14 @@ func (e *AndCondition) evalNode(r Record) (*Value, error) {
return FromBool(result), nil
}
func (e *Condition) evalNode(r Record) (*Value, error) {
func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Operand != nil {
// In this case, result does not have to be boolean
return e.Operand.evalNode(r)
return e.Operand.evalNode(r, tableAlias)
}
// Compute NOT of condition
res, err := e.Not.evalNode(r)
res, err := e.Not.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@@ -114,8 +113,8 @@ func (e *Condition) evalNode(r Record) (*Value, error) {
return FromBool(!b), nil
}
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r)
func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r, tableAlias)
if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr
}
@@ -123,7 +122,7 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
// Need to evaluate the ConditionRHS
switch {
case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias)
if cmpRErr != nil {
return nil, cmpRErr
}
@@ -132,26 +131,26 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
return FromBool(b), err
case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias)
case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias)
case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal)
return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias)
default:
return nil, errInvalidASTNode
}
}
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
stVal, stErr := e.Start.evalNode(r)
func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) {
stVal, stErr := e.Start.evalNode(r, tableAlias)
if stErr != nil {
return nil, stErr
}
endVal, endErr := e.End.evalNode(r)
endVal, endErr := e.End.evalNode(r, tableAlias)
if endErr != nil {
return nil, endErr
}
@@ -174,7 +173,7 @@ func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
return FromBool(result), nil
}
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) {
inferTypeAsString(arg)
s, ok := arg.ToString()
@@ -183,7 +182,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
return nil, errLikeInvalidInputs(err)
}
pattern, err1 := e.Pattern.evalNode(r)
pattern, err1 := e.Pattern.evalNode(r, tableAlias)
if err1 != nil {
return nil, err1
}
@@ -199,7 +198,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
escape := runeZero
if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r)
escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias)
if err2 != nil {
return nil, err2
}
@@ -230,14 +229,14 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
return FromBool(matchResult), nil
}
func (e *ListExpr) evalNode(r Record) (*Value, error) {
func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) {
res := make([]Value, len(e.Elements))
if len(e.Elements) == 1 {
// If length 1, treat as single value.
return e.Elements[0].evalNode(r)
return e.Elements[0].evalNode(r, tableAlias)
}
for i, elt := range e.Elements {
v, err := elt.evalNode(r)
v, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@@ -248,7 +247,7 @@ func (e *ListExpr) evalNode(r Record) (*Value, error) {
const floatCmpTolerance = 0.000001
func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) {
// Compare two values in terms of in-ness.
var cmp func(a, b Value) bool
cmp = func(a, b Value) bool {
@@ -283,7 +282,7 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
var rhs Value
if elt := e.ListExpression; elt != nil {
eltVal, err := elt.evalNode(r)
eltVal, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@@ -304,8 +303,8 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
return FromBool(cmp(rhs, *lhs)), nil
}
func (e *Operand) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
@@ -315,7 +314,7 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
// symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
@@ -327,8 +326,8 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
return lval, nil
}
func (e *MultOp) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
@@ -337,7 +336,7 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
// AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
@@ -350,12 +349,12 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
return lval, nil
}
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Negated == nil {
return e.Primary.evalNode(r)
return e.Primary.evalNode(r, tableAlias)
}
v, err := e.Negated.Term.evalNode(r)
v, err := e.Negated.Term.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@@ -368,19 +367,15 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
return nil, errArithMismatchedTypes
}
func (e *JSONPath) evalNode(r Record) (*Value, error) {
// Strip the table name from the keypath.
keypath := e.String()
if strings.Contains(keypath, ".") {
ps := strings.SplitN(keypath, ".", 2)
if len(ps) == 2 {
keypath = ps[1]
}
func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) {
alias := tableAlias
if tableAlias == "" {
alias = baseTableName
}
pathExpr := e.StripTableAlias(alias)
_, rawVal := r.Raw()
switch rowVal := rawVal.(type) {
case jstream.KVS, simdjson.Object:
pathExpr := e.PathExpr
if len(pathExpr) == 0 {
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
}
@@ -392,7 +387,10 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) {
return jsonToValue(result)
default:
return r.Get(keypath)
if pathExpr[len(pathExpr)-1].Key == nil {
return nil, errInvalidKeypath
}
return r.Get(pathExpr[len(pathExpr)-1].Key.keyString())
}
}
@@ -447,28 +445,28 @@ func jsonToValue(result interface{}) (*Value, error) {
return nil, fmt.Errorf("Unhandled value type: %T", result)
}
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch {
case e.Value != nil:
return e.Value.evalNode(r)
case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r)
return e.JPathExpr.evalNode(r, tableAlias)
case e.ListExpr != nil:
return e.ListExpr.evalNode(r)
return e.ListExpr.evalNode(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.evalNode(r)
return e.SubExpression.evalNode(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.evalNode(r)
return e.FuncCall.evalNode(r, tableAlias)
}
return nil, errInvalidASTNode
}
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate()
default:
return e.evalSQLFnNode(r)
return e.evalSQLFnNode(r, tableAlias)
}
}