From b2936243f9341cc9453cdd27b3b614767693dab3 Mon Sep 17 00:00:00 2001 From: Aditya Manthramurthy Date: Tue, 6 Apr 2021 08:49:04 -0700 Subject: [PATCH] Fix S3Select SQL column reference handling (#11957) This change fixes handling of these types of queries: - Double quoted column names with special characters: SELECT "column.name" FROM s3object - Double quoted column names with reserved keywords: SELECT "CAST" FROM s3object - Table name as prefix for column names: SELECT S3Object."CAST" FROM s3object --- pkg/s3select/select_test.go | 146 ++++++++++++++++++++++++++++++++ pkg/s3select/sql/aggregation.go | 74 ++++++++-------- pkg/s3select/sql/analysis.go | 3 +- pkg/s3select/sql/evaluate.go | 108 ++++++++++++----------- pkg/s3select/sql/funceval.go | 50 +++++------ pkg/s3select/sql/parser.go | 4 +- pkg/s3select/sql/statement.go | 14 ++- pkg/s3select/sql/utils.go | 21 +++++ 8 files changed, 297 insertions(+), 123 deletions(-) diff --git a/pkg/s3select/select_test.go b/pkg/s3select/select_test.go index 407300f4c..cf784dc36 100644 --- a/pkg/s3select/select_test.go +++ b/pkg/s3select/select_test.go @@ -739,6 +739,152 @@ func TestCSVQueries2(t *testing.T) { } } +func TestCSVQueries3(t *testing.T) { + input := `na.me,qty,CAST +apple,1,true +mango,3,false +` + var testTable = []struct { + name string + query string + requestXML []byte // override request XML + wantResult string + }{ + { + name: "Select a column containing dot", + query: `select "na.me" from S3Object s`, + wantResult: `apple +mango`, + }, + { + name: "Select column containing dot with table name prefix", + query: `select count(S3Object."na.me") from S3Object`, + wantResult: `2`, + }, + { + name: "Select column containing dot with table alias prefix", + query: `select s."na.me" from S3Object as s`, + wantResult: `apple +mango`, + }, + { + name: "Select column simplest", + query: `select qty from S3Object`, + wantResult: `1 +3`, + }, + { + name: "Select column with table name prefix", + query: `select S3Object.qty from S3Object`, + wantResult: `1 +3`, + }, + { + name: "Select column without table alias", + query: `select qty from S3Object s`, + wantResult: `1 +3`, + }, + { + name: "Select column with table alias", + query: `select s.qty from S3Object s`, + wantResult: `1 +3`, + }, + { + name: "Select reserved word column", + query: `select "CAST" from s3object`, + wantResult: `true +false`, + }, + { + name: "Select reserved word column with table alias", + query: `select S3Object."CAST" from s3object`, + wantResult: `true +false`, + }, + { + name: "Select reserved word column with unused table alias", + query: `select "CAST" from s3object s`, + wantResult: `true +false`, + }, + { + name: "Select reserved word column with table alias", + query: `select s."CAST" from s3object s`, + wantResult: `true +false`, + }, + { + name: "Select reserved word column with table alias", + query: `select NOT CAST(s."CAST" AS Bool) from s3object s`, + wantResult: `false +true`, + }, + } + + defRequest := ` + + %s + SQL + + NONE + + USE + " + + + + + + + FALSE + +` + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + testReq := testCase.requestXML + if len(testReq) == 0 { + testReq = []byte(fmt.Sprintf(defRequest, testCase.query)) + } + s3Select, err := NewS3Select(bytes.NewReader(testReq)) + if err != nil { + t.Fatal(err) + } + + if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBufferString(input)), nil + }); err != nil { + t.Fatal(err) + } + + w := &testResponseWriter{} + s3Select.Evaluate(w) + s3Select.Close() + resp := http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(w.response)), + ContentLength: int64(len(w.response)), + } + res, err := minio.NewSelectResults(&resp, "testbucket") + if err != nil { + t.Error(err) + return + } + got, err := ioutil.ReadAll(res) + if err != nil { + t.Error(err) + return + } + gotS := strings.TrimSpace(string(got)) + if gotS != testCase.wantResult { + t.Errorf("received response does not match with expected reply.\nQuery: %s\n=====\ngot: %s\n=====\nwant: %s\n=====\n", testCase.query, gotS, testCase.wantResult) + } + }) + } +} + func TestCSVInput(t *testing.T) { var testTable = []struct { requestXML []byte diff --git a/pkg/s3select/sql/aggregation.go b/pkg/s3select/sql/aggregation.go index 2d8ed02ef..5baf2429b 100644 --- a/pkg/s3select/sql/aggregation.go +++ b/pkg/s3select/sql/aggregation.go @@ -63,7 +63,7 @@ func newAggVal(fn FuncName) *aggVal { // current row and stores the result. // // On success, it returns (nil, nil). -func (e *FuncExpr) evalAggregationNode(r Record) error { +func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error { // It is assumed that this function is called only when // `e` is an aggregation function. @@ -77,13 +77,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error { return nil } - val, err = e.Count.ExprArg.evalNode(r) + val, err = e.Count.ExprArg.evalNode(r, tableAlias) if err != nil { return err } } else { // Evaluate the (only) argument - val, err = e.SFunc.ArgsList[0].evalNode(r) + val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias) if err != nil { return err } @@ -149,13 +149,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error { return err } -func (e *AliasedExpression) aggregateRow(r Record) error { - return e.Expression.aggregateRow(r) +func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error { + return e.Expression.aggregateRow(r, tableAlias) } -func (e *Expression) aggregateRow(r Record) error { +func (e *Expression) aggregateRow(r Record, tableAlias string) error { for _, ex := range e.And { - err := ex.aggregateRow(r) + err := ex.aggregateRow(r, tableAlias) if err != nil { return err } @@ -163,9 +163,9 @@ func (e *Expression) aggregateRow(r Record) error { return nil } -func (e *ListExpr) aggregateRow(r Record) error { +func (e *ListExpr) aggregateRow(r Record, tableAlias string) error { for _, ex := range e.Elements { - err := ex.aggregateRow(r) + err := ex.aggregateRow(r, tableAlias) if err != nil { return err } @@ -173,9 +173,9 @@ func (e *ListExpr) aggregateRow(r Record) error { return nil } -func (e *AndCondition) aggregateRow(r Record) error { +func (e *AndCondition) aggregateRow(r Record, tableAlias string) error { for _, ex := range e.Condition { - err := ex.aggregateRow(r) + err := ex.aggregateRow(r, tableAlias) if err != nil { return err } @@ -183,15 +183,15 @@ func (e *AndCondition) aggregateRow(r Record) error { return nil } -func (e *Condition) aggregateRow(r Record) error { +func (e *Condition) aggregateRow(r Record, tableAlias string) error { if e.Operand != nil { - return e.Operand.aggregateRow(r) + return e.Operand.aggregateRow(r, tableAlias) } - return e.Not.aggregateRow(r) + return e.Not.aggregateRow(r, tableAlias) } -func (e *ConditionOperand) aggregateRow(r Record) error { - err := e.Operand.aggregateRow(r) +func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error { + err := e.Operand.aggregateRow(r, tableAlias) if err != nil { return err } @@ -202,38 +202,38 @@ func (e *ConditionOperand) aggregateRow(r Record) error { switch { case e.ConditionRHS.Compare != nil: - return e.ConditionRHS.Compare.Operand.aggregateRow(r) + return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias) case e.ConditionRHS.Between != nil: - err = e.ConditionRHS.Between.Start.aggregateRow(r) + err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias) if err != nil { return err } - return e.ConditionRHS.Between.End.aggregateRow(r) + return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias) case e.ConditionRHS.In != nil: elt := e.ConditionRHS.In.ListExpression - err = elt.aggregateRow(r) + err = elt.aggregateRow(r, tableAlias) if err != nil { return err } return nil case e.ConditionRHS.Like != nil: - err = e.ConditionRHS.Like.Pattern.aggregateRow(r) + err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias) if err != nil { return err } - return e.ConditionRHS.Like.EscapeChar.aggregateRow(r) + return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias) default: return errInvalidASTNode } } -func (e *Operand) aggregateRow(r Record) error { - err := e.Left.aggregateRow(r) +func (e *Operand) aggregateRow(r Record, tableAlias string) error { + err := e.Left.aggregateRow(r, tableAlias) if err != nil { return err } for _, rt := range e.Right { - err = rt.Right.aggregateRow(r) + err = rt.Right.aggregateRow(r, tableAlias) if err != nil { return err } @@ -241,13 +241,13 @@ func (e *Operand) aggregateRow(r Record) error { return nil } -func (e *MultOp) aggregateRow(r Record) error { - err := e.Left.aggregateRow(r) +func (e *MultOp) aggregateRow(r Record, tableAlias string) error { + err := e.Left.aggregateRow(r, tableAlias) if err != nil { return err } for _, rt := range e.Right { - err = rt.Right.aggregateRow(r) + err = rt.Right.aggregateRow(r, tableAlias) if err != nil { return err } @@ -255,29 +255,29 @@ func (e *MultOp) aggregateRow(r Record) error { return nil } -func (e *UnaryTerm) aggregateRow(r Record) error { +func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error { if e.Negated != nil { - return e.Negated.Term.aggregateRow(r) + return e.Negated.Term.aggregateRow(r, tableAlias) } - return e.Primary.aggregateRow(r) + return e.Primary.aggregateRow(r, tableAlias) } -func (e *PrimaryTerm) aggregateRow(r Record) error { +func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error { switch { case e.ListExpr != nil: - return e.ListExpr.aggregateRow(r) + return e.ListExpr.aggregateRow(r, tableAlias) case e.SubExpression != nil: - return e.SubExpression.aggregateRow(r) + return e.SubExpression.aggregateRow(r, tableAlias) case e.FuncCall != nil: - return e.FuncCall.aggregateRow(r) + return e.FuncCall.aggregateRow(r, tableAlias) } return nil } -func (e *FuncExpr) aggregateRow(r Record) error { +func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error { switch e.getFunctionName() { case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount: - return e.evalAggregationNode(r) + return e.evalAggregationNode(r, tableAlias) default: // TODO: traverse arguments and call aggregateRow on // them if they could be an ancestor of an diff --git a/pkg/s3select/sql/analysis.go b/pkg/s3select/sql/analysis.go index 3e84cad48..eab71b141 100644 --- a/pkg/s3select/sql/analysis.go +++ b/pkg/s3select/sql/analysis.go @@ -19,6 +19,7 @@ package sql import ( "errors" "fmt" + "strings" ) // Query analysis - The query is analyzed to determine if it involves @@ -177,7 +178,7 @@ func (e *PrimaryTerm) analyze(s *Select) (result qProp) { case e.JPathExpr != nil: // Check if the path expression is valid if len(e.JPathExpr.PathExpr) > 0 { - if e.JPathExpr.BaseKey.String() != s.From.As { + if e.JPathExpr.BaseKey.String() != s.From.As && strings.ToLower(e.JPathExpr.BaseKey.String()) != baseTableName { result = qProp{err: errInvalidKeypath} return } diff --git a/pkg/s3select/sql/evaluate.go b/pkg/s3select/sql/evaluate.go index 50451c515..83bdf0ffb 100644 --- a/pkg/s3select/sql/evaluate.go +++ b/pkg/s3select/sql/evaluate.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "math" - "strings" "github.com/bcicen/jstream" "github.com/minio/simdjson-go" @@ -47,21 +46,21 @@ var ( // of child nodes. The final result row is returned after all rows are // processed, and the `getAggregate` function is called. -func (e *AliasedExpression) evalNode(r Record) (*Value, error) { - return e.Expression.evalNode(r) +func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) { + return e.Expression.evalNode(r, tableAlias) } -func (e *Expression) evalNode(r Record) (*Value, error) { +func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) { if len(e.And) == 1 { // In this case, result is not required to be boolean // type. - return e.And[0].evalNode(r) + return e.And[0].evalNode(r, tableAlias) } // Compute OR of conditions result := false for _, ex := range e.And { - res, err := ex.evalNode(r) + res, err := ex.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -74,16 +73,16 @@ func (e *Expression) evalNode(r Record) (*Value, error) { return FromBool(result), nil } -func (e *AndCondition) evalNode(r Record) (*Value, error) { +func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) { if len(e.Condition) == 1 { // In this case, result does not have to be boolean - return e.Condition[0].evalNode(r) + return e.Condition[0].evalNode(r, tableAlias) } // Compute AND of conditions result := true for _, ex := range e.Condition { - res, err := ex.evalNode(r) + res, err := ex.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -96,14 +95,14 @@ func (e *AndCondition) evalNode(r Record) (*Value, error) { return FromBool(result), nil } -func (e *Condition) evalNode(r Record) (*Value, error) { +func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) { if e.Operand != nil { // In this case, result does not have to be boolean - return e.Operand.evalNode(r) + return e.Operand.evalNode(r, tableAlias) } // Compute NOT of condition - res, err := e.Not.evalNode(r) + res, err := e.Not.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -114,8 +113,8 @@ func (e *Condition) evalNode(r Record) (*Value, error) { return FromBool(!b), nil } -func (e *ConditionOperand) evalNode(r Record) (*Value, error) { - opVal, opErr := e.Operand.evalNode(r) +func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) { + opVal, opErr := e.Operand.evalNode(r, tableAlias) if opErr != nil || e.ConditionRHS == nil { return opVal, opErr } @@ -123,7 +122,7 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) { // Need to evaluate the ConditionRHS switch { case e.ConditionRHS.Compare != nil: - cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r) + cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias) if cmpRErr != nil { return nil, cmpRErr } @@ -132,26 +131,26 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) { return FromBool(b), err case e.ConditionRHS.Between != nil: - return e.ConditionRHS.Between.evalBetweenNode(r, opVal) + return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias) case e.ConditionRHS.Like != nil: - return e.ConditionRHS.Like.evalLikeNode(r, opVal) + return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias) case e.ConditionRHS.In != nil: - return e.ConditionRHS.In.evalInNode(r, opVal) + return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias) default: return nil, errInvalidASTNode } } -func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) { - stVal, stErr := e.Start.evalNode(r) +func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) { + stVal, stErr := e.Start.evalNode(r, tableAlias) if stErr != nil { return nil, stErr } - endVal, endErr := e.End.evalNode(r) + endVal, endErr := e.End.evalNode(r, tableAlias) if endErr != nil { return nil, endErr } @@ -174,7 +173,7 @@ func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) { return FromBool(result), nil } -func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { +func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) { inferTypeAsString(arg) s, ok := arg.ToString() @@ -183,7 +182,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { return nil, errLikeInvalidInputs(err) } - pattern, err1 := e.Pattern.evalNode(r) + pattern, err1 := e.Pattern.evalNode(r, tableAlias) if err1 != nil { return nil, err1 } @@ -199,7 +198,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { escape := runeZero if e.EscapeChar != nil { - escapeVal, err2 := e.EscapeChar.evalNode(r) + escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias) if err2 != nil { return nil, err2 } @@ -230,14 +229,14 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { return FromBool(matchResult), nil } -func (e *ListExpr) evalNode(r Record) (*Value, error) { +func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) { res := make([]Value, len(e.Elements)) if len(e.Elements) == 1 { // If length 1, treat as single value. - return e.Elements[0].evalNode(r) + return e.Elements[0].evalNode(r, tableAlias) } for i, elt := range e.Elements { - v, err := elt.evalNode(r) + v, err := elt.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -248,7 +247,7 @@ func (e *ListExpr) evalNode(r Record) (*Value, error) { const floatCmpTolerance = 0.000001 -func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { +func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) { // Compare two values in terms of in-ness. var cmp func(a, b Value) bool cmp = func(a, b Value) bool { @@ -283,7 +282,7 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { var rhs Value if elt := e.ListExpression; elt != nil { - eltVal, err := elt.evalNode(r) + eltVal, err := elt.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -304,8 +303,8 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { return FromBool(cmp(rhs, *lhs)), nil } -func (e *Operand) evalNode(r Record) (*Value, error) { - lval, lerr := e.Left.evalNode(r) +func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) { + lval, lerr := e.Left.evalNode(r, tableAlias) if lerr != nil || len(e.Right) == 0 { return lval, lerr } @@ -315,7 +314,7 @@ func (e *Operand) evalNode(r Record) (*Value, error) { // symbols. for _, rightTerm := range e.Right { op := rightTerm.Op - rval, rerr := rightTerm.Right.evalNode(r) + rval, rerr := rightTerm.Right.evalNode(r, tableAlias) if rerr != nil { return nil, rerr } @@ -327,8 +326,8 @@ func (e *Operand) evalNode(r Record) (*Value, error) { return lval, nil } -func (e *MultOp) evalNode(r Record) (*Value, error) { - lval, lerr := e.Left.evalNode(r) +func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) { + lval, lerr := e.Left.evalNode(r, tableAlias) if lerr != nil || len(e.Right) == 0 { return lval, lerr } @@ -337,7 +336,7 @@ func (e *MultOp) evalNode(r Record) (*Value, error) { // AST node is for terms separated by *, / or % symbols. for _, rightTerm := range e.Right { op := rightTerm.Op - rval, rerr := rightTerm.Right.evalNode(r) + rval, rerr := rightTerm.Right.evalNode(r, tableAlias) if rerr != nil { return nil, rerr } @@ -350,12 +349,12 @@ func (e *MultOp) evalNode(r Record) (*Value, error) { return lval, nil } -func (e *UnaryTerm) evalNode(r Record) (*Value, error) { +func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) { if e.Negated == nil { - return e.Primary.evalNode(r) + return e.Primary.evalNode(r, tableAlias) } - v, err := e.Negated.Term.evalNode(r) + v, err := e.Negated.Term.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -368,19 +367,15 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) { return nil, errArithMismatchedTypes } -func (e *JSONPath) evalNode(r Record) (*Value, error) { - // Strip the table name from the keypath. - keypath := e.String() - if strings.Contains(keypath, ".") { - ps := strings.SplitN(keypath, ".", 2) - if len(ps) == 2 { - keypath = ps[1] - } +func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) { + alias := tableAlias + if tableAlias == "" { + alias = baseTableName } + pathExpr := e.StripTableAlias(alias) _, rawVal := r.Raw() switch rowVal := rawVal.(type) { case jstream.KVS, simdjson.Object: - pathExpr := e.PathExpr if len(pathExpr) == 0 { pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}} } @@ -392,7 +387,10 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) { return jsonToValue(result) default: - return r.Get(keypath) + if pathExpr[len(pathExpr)-1].Key == nil { + return nil, errInvalidKeypath + } + return r.Get(pathExpr[len(pathExpr)-1].Key.keyString()) } } @@ -447,28 +445,28 @@ func jsonToValue(result interface{}) (*Value, error) { return nil, fmt.Errorf("Unhandled value type: %T", result) } -func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) { +func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) { switch { case e.Value != nil: return e.Value.evalNode(r) case e.JPathExpr != nil: - return e.JPathExpr.evalNode(r) + return e.JPathExpr.evalNode(r, tableAlias) case e.ListExpr != nil: - return e.ListExpr.evalNode(r) + return e.ListExpr.evalNode(r, tableAlias) case e.SubExpression != nil: - return e.SubExpression.evalNode(r) + return e.SubExpression.evalNode(r, tableAlias) case e.FuncCall != nil: - return e.FuncCall.evalNode(r) + return e.FuncCall.evalNode(r, tableAlias) } return nil, errInvalidASTNode } -func (e *FuncExpr) evalNode(r Record) (res *Value, err error) { +func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) { switch e.getFunctionName() { case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum: return e.getAggregate() default: - return e.evalSQLFnNode(r) + return e.evalSQLFnNode(r, tableAlias) } } diff --git a/pkg/s3select/sql/funceval.go b/pkg/s3select/sql/funceval.go index 0490535b2..1ceaea9dd 100644 --- a/pkg/s3select/sql/funceval.go +++ b/pkg/s3select/sql/funceval.go @@ -84,35 +84,35 @@ func (e *FuncExpr) getFunctionName() FuncName { // evalSQLFnNode assumes that the FuncExpr is not an aggregation // function. -func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) { +func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) { // Handle functions that have phrase arguments switch e.getFunctionName() { case sqlFnCast: expr := e.Cast.Expr - res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType)) + res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias) return case sqlFnSubstring: - return handleSQLSubstring(r, e.Substring) + return handleSQLSubstring(r, e.Substring, tableAlias) case sqlFnExtract: - return handleSQLExtract(r, e.Extract) + return handleSQLExtract(r, e.Extract, tableAlias) case sqlFnTrim: - return handleSQLTrim(r, e.Trim) + return handleSQLTrim(r, e.Trim, tableAlias) case sqlFnDateAdd: - return handleDateAdd(r, e.DateAdd) + return handleDateAdd(r, e.DateAdd, tableAlias) case sqlFnDateDiff: - return handleDateDiff(r, e.DateDiff) + return handleDateDiff(r, e.DateDiff, tableAlias) } // For all simple argument functions, we evaluate the arguments here argVals := make([]*Value, len(e.SFunc.ArgsList)) for i, arg := range e.SFunc.ArgsList { - argVals[i], err = arg.evalNode(r) + argVals[i], err = arg.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -219,8 +219,8 @@ func upperCase(v *Value) (*Value, error) { return FromString(strings.ToUpper(s)), nil } -func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { - q, err := d.Quantity.evalNode(r) +func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) { + q, err := d.Quantity.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -230,7 +230,7 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd) } - ts, err := d.Timestamp.evalNode(r) + ts, err := d.Timestamp.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -245,8 +245,8 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { return dateAdd(strings.ToUpper(d.DatePart), qty, t) } -func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) { - tval1, err := d.Timestamp1.evalNode(r) +func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) { + tval1, err := d.Timestamp1.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -258,7 +258,7 @@ func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) { return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff) } - tval2, err := d.Timestamp2.evalNode(r) + tval2, err := d.Timestamp2.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -277,12 +277,12 @@ func handleUTCNow() (*Value, error) { return FromTimestamp(time.Now().UTC()), nil } -func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { +func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) { // Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and // SUBSTRING('abc', 2, 1) are supported. // Evaluate the string argument - v1, err := e.Expr.evalNode(r) + v1, err := e.Expr.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -301,7 +301,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { } // Evaluate the FROM argument - v2, err := arg2.evalNode(r) + v2, err := arg2.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -315,7 +315,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { length := -1 // Evaluate the optional FOR argument if arg3 != nil { - v3, err := arg3.evalNode(r) + v3, err := arg3.evalNode(r, tableAlias) if err != nil { return nil, err } @@ -336,11 +336,11 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { return FromString(res), err } -func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { +func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) { chars := "" ok := false if e.TrimChars != nil { - charsV, cerr := e.TrimChars.evalNode(r) + charsV, cerr := e.TrimChars.evalNode(r, tableAlias) if cerr != nil { return nil, cerr } @@ -351,7 +351,7 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { } } - fromV, ferr := e.TrimFrom.evalNode(r) + fromV, ferr := e.TrimFrom.evalNode(r, tableAlias) if ferr != nil { return nil, ferr } @@ -368,8 +368,8 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { return FromString(result), nil } -func handleSQLExtract(r Record, e *ExtractFunc) (res *Value, err error) { - timeVal, verr := e.From.evalNode(r) +func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) { + timeVal, verr := e.From.evalNode(r, tableAlias) if verr != nil { return nil, verr } @@ -406,8 +406,8 @@ const ( castTimestamp = "TIMESTAMP" ) -func (e *Expression) castTo(r Record, castType string) (res *Value, err error) { - v, err := e.evalNode(r) +func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) { + v, err := e.evalNode(r, tableAlias) if err != nil { return nil, err } diff --git a/pkg/s3select/sql/parser.go b/pkg/s3select/sql/parser.go index 9b9aae7cd..a8d4f94ea 100644 --- a/pkg/s3select/sql/parser.go +++ b/pkg/s3select/sql/parser.go @@ -119,7 +119,9 @@ type JSONPath struct { PathExpr []*JSONPathElement `parser:"(@@)*"` // Cached values: - pathString string + pathString string + strippedTableAlias string + strippedPathExpr []*JSONPathElement } // AliasedExpression is an expression that can be optionally named diff --git a/pkg/s3select/sql/statement.go b/pkg/s3select/sql/statement.go index e4fffc6b1..2d3d49e3a 100644 --- a/pkg/s3select/sql/statement.go +++ b/pkg/s3select/sql/statement.go @@ -46,6 +46,9 @@ type SelectStatement struct { // Count of rows that have been output. outputCount int64 + + // Table alias + tableAlias string } // ParseSelectStatement - parses a select query from the given string @@ -107,6 +110,9 @@ func ParseSelectStatement(s string) (stmt SelectStatement, err error) { if err != nil { err = errQueryAnalysisFailure(err) } + + // Set table alias + stmt.tableAlias = selectAST.From.As return } @@ -226,7 +232,7 @@ func (e *SelectStatement) IsAggregated() bool { // records have been processed. Applies only to aggregation queries. func (e *SelectStatement) AggregateResult(output Record) error { for i, expr := range e.selectAST.Expression.Expressions { - v, err := expr.evalNode(nil) + v, err := expr.evalNode(nil, e.tableAlias) if err != nil { return err } @@ -246,7 +252,7 @@ func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) { if e.selectAST.Where == nil { return true, nil } - value, err := e.selectAST.Where.evalNode(input) + value, err := e.selectAST.Where.evalNode(input, e.tableAlias) if err != nil { return false, err } @@ -272,7 +278,7 @@ func (e *SelectStatement) AggregateRow(input Record) error { } for _, expr := range e.selectAST.Expression.Expressions { - err := expr.aggregateRow(input) + err := expr.aggregateRow(input, e.tableAlias) if err != nil { return err } @@ -302,7 +308,7 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) { } for i, expr := range e.selectAST.Expression.Expressions { - v, err := expr.evalNode(input) + v, err := expr.evalNode(input, e.tableAlias) if err != nil { return nil, err } diff --git a/pkg/s3select/sql/utils.go b/pkg/s3select/sql/utils.go index 64ab96aa9..469e91c48 100644 --- a/pkg/s3select/sql/utils.go +++ b/pkg/s3select/sql/utils.go @@ -36,6 +36,27 @@ func (e *JSONPath) String() string { return e.pathString } +// StripTableAlias removes a table alias from the path. The result is also +// cached for repeated lookups during SQL query evaluation. +func (e *JSONPath) StripTableAlias(tableAlias string) []*JSONPathElement { + if e.strippedTableAlias == tableAlias { + return e.strippedPathExpr + } + + hasTableAlias := e.BaseKey.String() == tableAlias || strings.ToLower(e.BaseKey.String()) == baseTableName + var pathExpr []*JSONPathElement + if hasTableAlias { + pathExpr = e.PathExpr + } else { + pathExpr = make([]*JSONPathElement, len(e.PathExpr)+1) + pathExpr[0] = &JSONPathElement{Key: &ObjectKey{ID: e.BaseKey}} + copy(pathExpr[1:], e.PathExpr) + } + e.strippedTableAlias = tableAlias + e.strippedPathExpr = pathExpr + return e.strippedPathExpr +} + func (e *JSONPathElement) String() string { switch { case e.Key != nil: