mirror of
				https://github.com/minio/minio.git
				synced 2025-10-29 15:55:00 -04:00 
			
		
		
		
	Fix S3Select SQL column reference handling (#11957)
This change fixes handling of these types of queries:
- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
			
			
This commit is contained in:
		
							parent
							
								
									d5d2fc9850
								
							
						
					
					
						commit
						b2936243f9
					
				| @ -739,6 +739,152 @@ func TestCSVQueries2(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCSVQueries3(t *testing.T) { | ||||
| 	input := `na.me,qty,CAST | ||||
| apple,1,true | ||||
| mango,3,false | ||||
| ` | ||||
| 	var testTable = []struct { | ||||
| 		name       string | ||||
| 		query      string | ||||
| 		requestXML []byte // override request XML | ||||
| 		wantResult string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:  "Select a column containing dot", | ||||
| 			query: `select "na.me" from S3Object s`, | ||||
| 			wantResult: `apple | ||||
| mango`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:       "Select column containing dot with table name prefix", | ||||
| 			query:      `select count(S3Object."na.me") from S3Object`, | ||||
| 			wantResult: `2`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select column containing dot with table alias prefix", | ||||
| 			query: `select s."na.me" from S3Object as s`, | ||||
| 			wantResult: `apple | ||||
| mango`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select column simplest", | ||||
| 			query: `select qty from S3Object`, | ||||
| 			wantResult: `1 | ||||
| 3`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select column with table name prefix", | ||||
| 			query: `select S3Object.qty from S3Object`, | ||||
| 			wantResult: `1 | ||||
| 3`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select column without table alias", | ||||
| 			query: `select qty from S3Object s`, | ||||
| 			wantResult: `1 | ||||
| 3`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select column with table alias", | ||||
| 			query: `select s.qty from S3Object s`, | ||||
| 			wantResult: `1 | ||||
| 3`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select reserved word column", | ||||
| 			query: `select "CAST"  from s3object`, | ||||
| 			wantResult: `true | ||||
| false`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select reserved word column with table alias", | ||||
| 			query: `select S3Object."CAST" from s3object`, | ||||
| 			wantResult: `true | ||||
| false`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select reserved word column with unused table alias", | ||||
| 			query: `select "CAST"  from s3object s`, | ||||
| 			wantResult: `true | ||||
| false`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select reserved word column with table alias", | ||||
| 			query: `select s."CAST"  from s3object s`, | ||||
| 			wantResult: `true | ||||
| false`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:  "Select reserved word column with table alias", | ||||
| 			query: `select NOT CAST(s."CAST" AS Bool)  from s3object s`, | ||||
| 			wantResult: `false | ||||
| true`, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	defRequest := `<?xml version="1.0" encoding="UTF-8"?> | ||||
| <SelectObjectContentRequest> | ||||
|     <Expression>%s</Expression> | ||||
|     <ExpressionType>SQL</ExpressionType> | ||||
|     <InputSerialization> | ||||
|         <CompressionType>NONE</CompressionType> | ||||
|         <CSV> | ||||
|             <FileHeaderInfo>USE</FileHeaderInfo> | ||||
| 	    <QuoteCharacter>"</QuoteCharacter> | ||||
|         </CSV> | ||||
|     </InputSerialization> | ||||
|     <OutputSerialization> | ||||
|         <CSV/> | ||||
|     </OutputSerialization> | ||||
|     <RequestProgress> | ||||
|         <Enabled>FALSE</Enabled> | ||||
|     </RequestProgress> | ||||
| </SelectObjectContentRequest>` | ||||
| 
 | ||||
| 	for _, testCase := range testTable { | ||||
| 		t.Run(testCase.name, func(t *testing.T) { | ||||
| 			testReq := testCase.requestXML | ||||
| 			if len(testReq) == 0 { | ||||
| 				testReq = []byte(fmt.Sprintf(defRequest, testCase.query)) | ||||
| 			} | ||||
| 			s3Select, err := NewS3Select(bytes.NewReader(testReq)) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 
 | ||||
| 			if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) { | ||||
| 				return ioutil.NopCloser(bytes.NewBufferString(input)), nil | ||||
| 			}); err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 
 | ||||
| 			w := &testResponseWriter{} | ||||
| 			s3Select.Evaluate(w) | ||||
| 			s3Select.Close() | ||||
| 			resp := http.Response{ | ||||
| 				StatusCode:    http.StatusOK, | ||||
| 				Body:          ioutil.NopCloser(bytes.NewReader(w.response)), | ||||
| 				ContentLength: int64(len(w.response)), | ||||
| 			} | ||||
| 			res, err := minio.NewSelectResults(&resp, "testbucket") | ||||
| 			if err != nil { | ||||
| 				t.Error(err) | ||||
| 				return | ||||
| 			} | ||||
| 			got, err := ioutil.ReadAll(res) | ||||
| 			if err != nil { | ||||
| 				t.Error(err) | ||||
| 				return | ||||
| 			} | ||||
| 			gotS := strings.TrimSpace(string(got)) | ||||
| 			if gotS != testCase.wantResult { | ||||
| 				t.Errorf("received response does not match with expected reply.\nQuery: %s\n=====\ngot: %s\n=====\nwant: %s\n=====\n", testCase.query, gotS, testCase.wantResult) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCSVInput(t *testing.T) { | ||||
| 	var testTable = []struct { | ||||
| 		requestXML     []byte | ||||
|  | ||||
| @ -63,7 +63,7 @@ func newAggVal(fn FuncName) *aggVal { | ||||
| // current row and stores the result. | ||||
| // | ||||
| // On success, it returns (nil, nil). | ||||
| func (e *FuncExpr) evalAggregationNode(r Record) error { | ||||
| func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error { | ||||
| 	// It is assumed that this function is called only when | ||||
| 	// `e` is an aggregation function. | ||||
| 
 | ||||
| @ -77,13 +77,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error { | ||||
| 			return nil | ||||
| 		} | ||||
| 
 | ||||
| 		val, err = e.Count.ExprArg.evalNode(r) | ||||
| 		val, err = e.Count.ExprArg.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		// Evaluate the (only) argument | ||||
| 		val, err = e.SFunc.ArgsList[0].evalNode(r) | ||||
| 		val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -149,13 +149,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error { | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (e *AliasedExpression) aggregateRow(r Record) error { | ||||
| 	return e.Expression.aggregateRow(r) | ||||
| func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error { | ||||
| 	return e.Expression.aggregateRow(r, tableAlias) | ||||
| } | ||||
| 
 | ||||
| func (e *Expression) aggregateRow(r Record) error { | ||||
| func (e *Expression) aggregateRow(r Record, tableAlias string) error { | ||||
| 	for _, ex := range e.And { | ||||
| 		err := ex.aggregateRow(r) | ||||
| 		err := ex.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -163,9 +163,9 @@ func (e *Expression) aggregateRow(r Record) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *ListExpr) aggregateRow(r Record) error { | ||||
| func (e *ListExpr) aggregateRow(r Record, tableAlias string) error { | ||||
| 	for _, ex := range e.Elements { | ||||
| 		err := ex.aggregateRow(r) | ||||
| 		err := ex.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -173,9 +173,9 @@ func (e *ListExpr) aggregateRow(r Record) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *AndCondition) aggregateRow(r Record) error { | ||||
| func (e *AndCondition) aggregateRow(r Record, tableAlias string) error { | ||||
| 	for _, ex := range e.Condition { | ||||
| 		err := ex.aggregateRow(r) | ||||
| 		err := ex.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -183,15 +183,15 @@ func (e *AndCondition) aggregateRow(r Record) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *Condition) aggregateRow(r Record) error { | ||||
| func (e *Condition) aggregateRow(r Record, tableAlias string) error { | ||||
| 	if e.Operand != nil { | ||||
| 		return e.Operand.aggregateRow(r) | ||||
| 		return e.Operand.aggregateRow(r, tableAlias) | ||||
| 	} | ||||
| 	return e.Not.aggregateRow(r) | ||||
| 	return e.Not.aggregateRow(r, tableAlias) | ||||
| } | ||||
| 
 | ||||
| func (e *ConditionOperand) aggregateRow(r Record) error { | ||||
| 	err := e.Operand.aggregateRow(r) | ||||
| func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error { | ||||
| 	err := e.Operand.aggregateRow(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -202,38 +202,38 @@ func (e *ConditionOperand) aggregateRow(r Record) error { | ||||
| 
 | ||||
| 	switch { | ||||
| 	case e.ConditionRHS.Compare != nil: | ||||
| 		return e.ConditionRHS.Compare.Operand.aggregateRow(r) | ||||
| 		return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias) | ||||
| 	case e.ConditionRHS.Between != nil: | ||||
| 		err = e.ConditionRHS.Between.Start.aggregateRow(r) | ||||
| 		err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return e.ConditionRHS.Between.End.aggregateRow(r) | ||||
| 		return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias) | ||||
| 	case e.ConditionRHS.In != nil: | ||||
| 		elt := e.ConditionRHS.In.ListExpression | ||||
| 		err = elt.aggregateRow(r) | ||||
| 		err = elt.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return nil | ||||
| 	case e.ConditionRHS.Like != nil: | ||||
| 		err = e.ConditionRHS.Like.Pattern.aggregateRow(r) | ||||
| 		err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return e.ConditionRHS.Like.EscapeChar.aggregateRow(r) | ||||
| 		return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias) | ||||
| 	default: | ||||
| 		return errInvalidASTNode | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (e *Operand) aggregateRow(r Record) error { | ||||
| 	err := e.Left.aggregateRow(r) | ||||
| func (e *Operand) aggregateRow(r Record, tableAlias string) error { | ||||
| 	err := e.Left.aggregateRow(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for _, rt := range e.Right { | ||||
| 		err = rt.Right.aggregateRow(r) | ||||
| 		err = rt.Right.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -241,13 +241,13 @@ func (e *Operand) aggregateRow(r Record) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *MultOp) aggregateRow(r Record) error { | ||||
| 	err := e.Left.aggregateRow(r) | ||||
| func (e *MultOp) aggregateRow(r Record, tableAlias string) error { | ||||
| 	err := e.Left.aggregateRow(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for _, rt := range e.Right { | ||||
| 		err = rt.Right.aggregateRow(r) | ||||
| 		err = rt.Right.aggregateRow(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -255,29 +255,29 @@ func (e *MultOp) aggregateRow(r Record) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *UnaryTerm) aggregateRow(r Record) error { | ||||
| func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error { | ||||
| 	if e.Negated != nil { | ||||
| 		return e.Negated.Term.aggregateRow(r) | ||||
| 		return e.Negated.Term.aggregateRow(r, tableAlias) | ||||
| 	} | ||||
| 	return e.Primary.aggregateRow(r) | ||||
| 	return e.Primary.aggregateRow(r, tableAlias) | ||||
| } | ||||
| 
 | ||||
| func (e *PrimaryTerm) aggregateRow(r Record) error { | ||||
| func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error { | ||||
| 	switch { | ||||
| 	case e.ListExpr != nil: | ||||
| 		return e.ListExpr.aggregateRow(r) | ||||
| 		return e.ListExpr.aggregateRow(r, tableAlias) | ||||
| 	case e.SubExpression != nil: | ||||
| 		return e.SubExpression.aggregateRow(r) | ||||
| 		return e.SubExpression.aggregateRow(r, tableAlias) | ||||
| 	case e.FuncCall != nil: | ||||
| 		return e.FuncCall.aggregateRow(r) | ||||
| 		return e.FuncCall.aggregateRow(r, tableAlias) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *FuncExpr) aggregateRow(r Record) error { | ||||
| func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error { | ||||
| 	switch e.getFunctionName() { | ||||
| 	case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount: | ||||
| 		return e.evalAggregationNode(r) | ||||
| 		return e.evalAggregationNode(r, tableAlias) | ||||
| 	default: | ||||
| 		// TODO: traverse arguments and call aggregateRow on | ||||
| 		// them if they could be an ancestor of an | ||||
|  | ||||
| @ -19,6 +19,7 @@ package sql | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Query analysis - The query is analyzed to determine if it involves | ||||
| @ -177,7 +178,7 @@ func (e *PrimaryTerm) analyze(s *Select) (result qProp) { | ||||
| 	case e.JPathExpr != nil: | ||||
| 		// Check if the path expression is valid | ||||
| 		if len(e.JPathExpr.PathExpr) > 0 { | ||||
| 			if e.JPathExpr.BaseKey.String() != s.From.As { | ||||
| 			if e.JPathExpr.BaseKey.String() != s.From.As && strings.ToLower(e.JPathExpr.BaseKey.String()) != baseTableName { | ||||
| 				result = qProp{err: errInvalidKeypath} | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| @ -21,7 +21,6 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/bcicen/jstream" | ||||
| 	"github.com/minio/simdjson-go" | ||||
| @ -47,21 +46,21 @@ var ( | ||||
| // of child nodes. The final result row is returned after all rows are | ||||
| // processed, and the `getAggregate` function is called. | ||||
| 
 | ||||
| func (e *AliasedExpression) evalNode(r Record) (*Value, error) { | ||||
| 	return e.Expression.evalNode(r) | ||||
| func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	return e.Expression.evalNode(r, tableAlias) | ||||
| } | ||||
| 
 | ||||
| func (e *Expression) evalNode(r Record) (*Value, error) { | ||||
| func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	if len(e.And) == 1 { | ||||
| 		// In this case, result is not required to be boolean | ||||
| 		// type. | ||||
| 		return e.And[0].evalNode(r) | ||||
| 		return e.And[0].evalNode(r, tableAlias) | ||||
| 	} | ||||
| 
 | ||||
| 	// Compute OR of conditions | ||||
| 	result := false | ||||
| 	for _, ex := range e.And { | ||||
| 		res, err := ex.evalNode(r) | ||||
| 		res, err := ex.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @ -74,16 +73,16 @@ func (e *Expression) evalNode(r Record) (*Value, error) { | ||||
| 	return FromBool(result), nil | ||||
| } | ||||
| 
 | ||||
| func (e *AndCondition) evalNode(r Record) (*Value, error) { | ||||
| func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	if len(e.Condition) == 1 { | ||||
| 		// In this case, result does not have to be boolean | ||||
| 		return e.Condition[0].evalNode(r) | ||||
| 		return e.Condition[0].evalNode(r, tableAlias) | ||||
| 	} | ||||
| 
 | ||||
| 	// Compute AND of conditions | ||||
| 	result := true | ||||
| 	for _, ex := range e.Condition { | ||||
| 		res, err := ex.evalNode(r) | ||||
| 		res, err := ex.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @ -96,14 +95,14 @@ func (e *AndCondition) evalNode(r Record) (*Value, error) { | ||||
| 	return FromBool(result), nil | ||||
| } | ||||
| 
 | ||||
| func (e *Condition) evalNode(r Record) (*Value, error) { | ||||
| func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	if e.Operand != nil { | ||||
| 		// In this case, result does not have to be boolean | ||||
| 		return e.Operand.evalNode(r) | ||||
| 		return e.Operand.evalNode(r, tableAlias) | ||||
| 	} | ||||
| 
 | ||||
| 	// Compute NOT of condition | ||||
| 	res, err := e.Not.evalNode(r) | ||||
| 	res, err := e.Not.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -114,8 +113,8 @@ func (e *Condition) evalNode(r Record) (*Value, error) { | ||||
| 	return FromBool(!b), nil | ||||
| } | ||||
| 
 | ||||
| func (e *ConditionOperand) evalNode(r Record) (*Value, error) { | ||||
| 	opVal, opErr := e.Operand.evalNode(r) | ||||
| func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	opVal, opErr := e.Operand.evalNode(r, tableAlias) | ||||
| 	if opErr != nil || e.ConditionRHS == nil { | ||||
| 		return opVal, opErr | ||||
| 	} | ||||
| @ -123,7 +122,7 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) { | ||||
| 	// Need to evaluate the ConditionRHS | ||||
| 	switch { | ||||
| 	case e.ConditionRHS.Compare != nil: | ||||
| 		cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r) | ||||
| 		cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias) | ||||
| 		if cmpRErr != nil { | ||||
| 			return nil, cmpRErr | ||||
| 		} | ||||
| @ -132,26 +131,26 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) { | ||||
| 		return FromBool(b), err | ||||
| 
 | ||||
| 	case e.ConditionRHS.Between != nil: | ||||
| 		return e.ConditionRHS.Between.evalBetweenNode(r, opVal) | ||||
| 		return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias) | ||||
| 
 | ||||
| 	case e.ConditionRHS.Like != nil: | ||||
| 		return e.ConditionRHS.Like.evalLikeNode(r, opVal) | ||||
| 		return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias) | ||||
| 
 | ||||
| 	case e.ConditionRHS.In != nil: | ||||
| 		return e.ConditionRHS.In.evalInNode(r, opVal) | ||||
| 		return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias) | ||||
| 
 | ||||
| 	default: | ||||
| 		return nil, errInvalidASTNode | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) { | ||||
| 	stVal, stErr := e.Start.evalNode(r) | ||||
| func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) { | ||||
| 	stVal, stErr := e.Start.evalNode(r, tableAlias) | ||||
| 	if stErr != nil { | ||||
| 		return nil, stErr | ||||
| 	} | ||||
| 
 | ||||
| 	endVal, endErr := e.End.evalNode(r) | ||||
| 	endVal, endErr := e.End.evalNode(r, tableAlias) | ||||
| 	if endErr != nil { | ||||
| 		return nil, endErr | ||||
| 	} | ||||
| @ -174,7 +173,7 @@ func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) { | ||||
| 	return FromBool(result), nil | ||||
| } | ||||
| 
 | ||||
| func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { | ||||
| func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) { | ||||
| 	inferTypeAsString(arg) | ||||
| 
 | ||||
| 	s, ok := arg.ToString() | ||||
| @ -183,7 +182,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { | ||||
| 		return nil, errLikeInvalidInputs(err) | ||||
| 	} | ||||
| 
 | ||||
| 	pattern, err1 := e.Pattern.evalNode(r) | ||||
| 	pattern, err1 := e.Pattern.evalNode(r, tableAlias) | ||||
| 	if err1 != nil { | ||||
| 		return nil, err1 | ||||
| 	} | ||||
| @ -199,7 +198,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { | ||||
| 
 | ||||
| 	escape := runeZero | ||||
| 	if e.EscapeChar != nil { | ||||
| 		escapeVal, err2 := e.EscapeChar.evalNode(r) | ||||
| 		escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias) | ||||
| 		if err2 != nil { | ||||
| 			return nil, err2 | ||||
| 		} | ||||
| @ -230,14 +229,14 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) { | ||||
| 	return FromBool(matchResult), nil | ||||
| } | ||||
| 
 | ||||
| func (e *ListExpr) evalNode(r Record) (*Value, error) { | ||||
| func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	res := make([]Value, len(e.Elements)) | ||||
| 	if len(e.Elements) == 1 { | ||||
| 		// If length 1, treat as single value. | ||||
| 		return e.Elements[0].evalNode(r) | ||||
| 		return e.Elements[0].evalNode(r, tableAlias) | ||||
| 	} | ||||
| 	for i, elt := range e.Elements { | ||||
| 		v, err := elt.evalNode(r) | ||||
| 		v, err := elt.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @ -248,7 +247,7 @@ func (e *ListExpr) evalNode(r Record) (*Value, error) { | ||||
| 
 | ||||
| const floatCmpTolerance = 0.000001 | ||||
| 
 | ||||
| func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { | ||||
| func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) { | ||||
| 	// Compare two values in terms of in-ness. | ||||
| 	var cmp func(a, b Value) bool | ||||
| 	cmp = func(a, b Value) bool { | ||||
| @ -283,7 +282,7 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { | ||||
| 
 | ||||
| 	var rhs Value | ||||
| 	if elt := e.ListExpression; elt != nil { | ||||
| 		eltVal, err := elt.evalNode(r) | ||||
| 		eltVal, err := elt.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @ -304,8 +303,8 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) { | ||||
| 	return FromBool(cmp(rhs, *lhs)), nil | ||||
| } | ||||
| 
 | ||||
| func (e *Operand) evalNode(r Record) (*Value, error) { | ||||
| 	lval, lerr := e.Left.evalNode(r) | ||||
| func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	lval, lerr := e.Left.evalNode(r, tableAlias) | ||||
| 	if lerr != nil || len(e.Right) == 0 { | ||||
| 		return lval, lerr | ||||
| 	} | ||||
| @ -315,7 +314,7 @@ func (e *Operand) evalNode(r Record) (*Value, error) { | ||||
| 	// symbols. | ||||
| 	for _, rightTerm := range e.Right { | ||||
| 		op := rightTerm.Op | ||||
| 		rval, rerr := rightTerm.Right.evalNode(r) | ||||
| 		rval, rerr := rightTerm.Right.evalNode(r, tableAlias) | ||||
| 		if rerr != nil { | ||||
| 			return nil, rerr | ||||
| 		} | ||||
| @ -327,8 +326,8 @@ func (e *Operand) evalNode(r Record) (*Value, error) { | ||||
| 	return lval, nil | ||||
| } | ||||
| 
 | ||||
| func (e *MultOp) evalNode(r Record) (*Value, error) { | ||||
| 	lval, lerr := e.Left.evalNode(r) | ||||
| func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	lval, lerr := e.Left.evalNode(r, tableAlias) | ||||
| 	if lerr != nil || len(e.Right) == 0 { | ||||
| 		return lval, lerr | ||||
| 	} | ||||
| @ -337,7 +336,7 @@ func (e *MultOp) evalNode(r Record) (*Value, error) { | ||||
| 	// AST node is for terms separated by *, / or % symbols. | ||||
| 	for _, rightTerm := range e.Right { | ||||
| 		op := rightTerm.Op | ||||
| 		rval, rerr := rightTerm.Right.evalNode(r) | ||||
| 		rval, rerr := rightTerm.Right.evalNode(r, tableAlias) | ||||
| 		if rerr != nil { | ||||
| 			return nil, rerr | ||||
| 		} | ||||
| @ -350,12 +349,12 @@ func (e *MultOp) evalNode(r Record) (*Value, error) { | ||||
| 	return lval, nil | ||||
| } | ||||
| 
 | ||||
| func (e *UnaryTerm) evalNode(r Record) (*Value, error) { | ||||
| func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	if e.Negated == nil { | ||||
| 		return e.Primary.evalNode(r) | ||||
| 		return e.Primary.evalNode(r, tableAlias) | ||||
| 	} | ||||
| 
 | ||||
| 	v, err := e.Negated.Term.evalNode(r) | ||||
| 	v, err := e.Negated.Term.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -368,19 +367,15 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) { | ||||
| 	return nil, errArithMismatchedTypes | ||||
| } | ||||
| 
 | ||||
| func (e *JSONPath) evalNode(r Record) (*Value, error) { | ||||
| 	// Strip the table name from the keypath. | ||||
| 	keypath := e.String() | ||||
| 	if strings.Contains(keypath, ".") { | ||||
| 		ps := strings.SplitN(keypath, ".", 2) | ||||
| 		if len(ps) == 2 { | ||||
| 			keypath = ps[1] | ||||
| 		} | ||||
| func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) { | ||||
| 	alias := tableAlias | ||||
| 	if tableAlias == "" { | ||||
| 		alias = baseTableName | ||||
| 	} | ||||
| 	pathExpr := e.StripTableAlias(alias) | ||||
| 	_, rawVal := r.Raw() | ||||
| 	switch rowVal := rawVal.(type) { | ||||
| 	case jstream.KVS, simdjson.Object: | ||||
| 		pathExpr := e.PathExpr | ||||
| 		if len(pathExpr) == 0 { | ||||
| 			pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}} | ||||
| 		} | ||||
| @ -392,7 +387,10 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) { | ||||
| 
 | ||||
| 		return jsonToValue(result) | ||||
| 	default: | ||||
| 		return r.Get(keypath) | ||||
| 		if pathExpr[len(pathExpr)-1].Key == nil { | ||||
| 			return nil, errInvalidKeypath | ||||
| 		} | ||||
| 		return r.Get(pathExpr[len(pathExpr)-1].Key.keyString()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| @ -447,28 +445,28 @@ func jsonToValue(result interface{}) (*Value, error) { | ||||
| 	return nil, fmt.Errorf("Unhandled value type: %T", result) | ||||
| } | ||||
| 
 | ||||
| func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) { | ||||
| func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) { | ||||
| 	switch { | ||||
| 	case e.Value != nil: | ||||
| 		return e.Value.evalNode(r) | ||||
| 	case e.JPathExpr != nil: | ||||
| 		return e.JPathExpr.evalNode(r) | ||||
| 		return e.JPathExpr.evalNode(r, tableAlias) | ||||
| 	case e.ListExpr != nil: | ||||
| 		return e.ListExpr.evalNode(r) | ||||
| 		return e.ListExpr.evalNode(r, tableAlias) | ||||
| 	case e.SubExpression != nil: | ||||
| 		return e.SubExpression.evalNode(r) | ||||
| 		return e.SubExpression.evalNode(r, tableAlias) | ||||
| 	case e.FuncCall != nil: | ||||
| 		return e.FuncCall.evalNode(r) | ||||
| 		return e.FuncCall.evalNode(r, tableAlias) | ||||
| 	} | ||||
| 	return nil, errInvalidASTNode | ||||
| } | ||||
| 
 | ||||
| func (e *FuncExpr) evalNode(r Record) (res *Value, err error) { | ||||
| func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) { | ||||
| 	switch e.getFunctionName() { | ||||
| 	case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum: | ||||
| 		return e.getAggregate() | ||||
| 	default: | ||||
| 		return e.evalSQLFnNode(r) | ||||
| 		return e.evalSQLFnNode(r, tableAlias) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -84,35 +84,35 @@ func (e *FuncExpr) getFunctionName() FuncName { | ||||
| 
 | ||||
| // evalSQLFnNode assumes that the FuncExpr is not an aggregation | ||||
| // function. | ||||
| func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) { | ||||
| func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) { | ||||
| 	// Handle functions that have phrase arguments | ||||
| 	switch e.getFunctionName() { | ||||
| 	case sqlFnCast: | ||||
| 		expr := e.Cast.Expr | ||||
| 		res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType)) | ||||
| 		res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias) | ||||
| 		return | ||||
| 
 | ||||
| 	case sqlFnSubstring: | ||||
| 		return handleSQLSubstring(r, e.Substring) | ||||
| 		return handleSQLSubstring(r, e.Substring, tableAlias) | ||||
| 
 | ||||
| 	case sqlFnExtract: | ||||
| 		return handleSQLExtract(r, e.Extract) | ||||
| 		return handleSQLExtract(r, e.Extract, tableAlias) | ||||
| 
 | ||||
| 	case sqlFnTrim: | ||||
| 		return handleSQLTrim(r, e.Trim) | ||||
| 		return handleSQLTrim(r, e.Trim, tableAlias) | ||||
| 
 | ||||
| 	case sqlFnDateAdd: | ||||
| 		return handleDateAdd(r, e.DateAdd) | ||||
| 		return handleDateAdd(r, e.DateAdd, tableAlias) | ||||
| 
 | ||||
| 	case sqlFnDateDiff: | ||||
| 		return handleDateDiff(r, e.DateDiff) | ||||
| 		return handleDateDiff(r, e.DateDiff, tableAlias) | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	// For all simple argument functions, we evaluate the arguments here | ||||
| 	argVals := make([]*Value, len(e.SFunc.ArgsList)) | ||||
| 	for i, arg := range e.SFunc.ArgsList { | ||||
| 		argVals[i], err = arg.evalNode(r) | ||||
| 		argVals[i], err = arg.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @ -219,8 +219,8 @@ func upperCase(v *Value) (*Value, error) { | ||||
| 	return FromString(strings.ToUpper(s)), nil | ||||
| } | ||||
| 
 | ||||
| func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { | ||||
| 	q, err := d.Quantity.evalNode(r) | ||||
| func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) { | ||||
| 	q, err := d.Quantity.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -230,7 +230,7 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { | ||||
| 		return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd) | ||||
| 	} | ||||
| 
 | ||||
| 	ts, err := d.Timestamp.evalNode(r) | ||||
| 	ts, err := d.Timestamp.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -245,8 +245,8 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) { | ||||
| 	return dateAdd(strings.ToUpper(d.DatePart), qty, t) | ||||
| } | ||||
| 
 | ||||
| func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) { | ||||
| 	tval1, err := d.Timestamp1.evalNode(r) | ||||
| func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) { | ||||
| 	tval1, err := d.Timestamp1.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -258,7 +258,7 @@ func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) { | ||||
| 		return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff) | ||||
| 	} | ||||
| 
 | ||||
| 	tval2, err := d.Timestamp2.evalNode(r) | ||||
| 	tval2, err := d.Timestamp2.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -277,12 +277,12 @@ func handleUTCNow() (*Value, error) { | ||||
| 	return FromTimestamp(time.Now().UTC()), nil | ||||
| } | ||||
| 
 | ||||
| func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { | ||||
| func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) { | ||||
| 	// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and | ||||
| 	// SUBSTRING('abc', 2, 1) are supported. | ||||
| 
 | ||||
| 	// Evaluate the string argument | ||||
| 	v1, err := e.Expr.evalNode(r) | ||||
| 	v1, err := e.Expr.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -301,7 +301,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { | ||||
| 	} | ||||
| 
 | ||||
| 	// Evaluate the FROM argument | ||||
| 	v2, err := arg2.evalNode(r) | ||||
| 	v2, err := arg2.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -315,7 +315,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { | ||||
| 	length := -1 | ||||
| 	// Evaluate the optional FOR argument | ||||
| 	if arg3 != nil { | ||||
| 		v3, err := arg3.evalNode(r) | ||||
| 		v3, err := arg3.evalNode(r, tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @ -336,11 +336,11 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) { | ||||
| 	return FromString(res), err | ||||
| } | ||||
| 
 | ||||
| func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { | ||||
| func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) { | ||||
| 	chars := "" | ||||
| 	ok := false | ||||
| 	if e.TrimChars != nil { | ||||
| 		charsV, cerr := e.TrimChars.evalNode(r) | ||||
| 		charsV, cerr := e.TrimChars.evalNode(r, tableAlias) | ||||
| 		if cerr != nil { | ||||
| 			return nil, cerr | ||||
| 		} | ||||
| @ -351,7 +351,7 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	fromV, ferr := e.TrimFrom.evalNode(r) | ||||
| 	fromV, ferr := e.TrimFrom.evalNode(r, tableAlias) | ||||
| 	if ferr != nil { | ||||
| 		return nil, ferr | ||||
| 	} | ||||
| @ -368,8 +368,8 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) { | ||||
| 	return FromString(result), nil | ||||
| } | ||||
| 
 | ||||
| func handleSQLExtract(r Record, e *ExtractFunc) (res *Value, err error) { | ||||
| 	timeVal, verr := e.From.evalNode(r) | ||||
| func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) { | ||||
| 	timeVal, verr := e.From.evalNode(r, tableAlias) | ||||
| 	if verr != nil { | ||||
| 		return nil, verr | ||||
| 	} | ||||
| @ -406,8 +406,8 @@ const ( | ||||
| 	castTimestamp = "TIMESTAMP" | ||||
| ) | ||||
| 
 | ||||
| func (e *Expression) castTo(r Record, castType string) (res *Value, err error) { | ||||
| 	v, err := e.evalNode(r) | ||||
| func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) { | ||||
| 	v, err := e.evalNode(r, tableAlias) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @ -119,7 +119,9 @@ type JSONPath struct { | ||||
| 	PathExpr []*JSONPathElement `parser:"(@@)*"` | ||||
| 
 | ||||
| 	// Cached values: | ||||
| 	pathString string | ||||
| 	pathString         string | ||||
| 	strippedTableAlias string | ||||
| 	strippedPathExpr   []*JSONPathElement | ||||
| } | ||||
| 
 | ||||
| // AliasedExpression is an expression that can be optionally named | ||||
|  | ||||
| @ -46,6 +46,9 @@ type SelectStatement struct { | ||||
| 
 | ||||
| 	// Count of rows that have been output. | ||||
| 	outputCount int64 | ||||
| 
 | ||||
| 	// Table alias | ||||
| 	tableAlias string | ||||
| } | ||||
| 
 | ||||
| // ParseSelectStatement - parses a select query from the given string | ||||
| @ -107,6 +110,9 @@ func ParseSelectStatement(s string) (stmt SelectStatement, err error) { | ||||
| 	if err != nil { | ||||
| 		err = errQueryAnalysisFailure(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Set table alias | ||||
| 	stmt.tableAlias = selectAST.From.As | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -226,7 +232,7 @@ func (e *SelectStatement) IsAggregated() bool { | ||||
| // records have been processed. Applies only to aggregation queries. | ||||
| func (e *SelectStatement) AggregateResult(output Record) error { | ||||
| 	for i, expr := range e.selectAST.Expression.Expressions { | ||||
| 		v, err := expr.evalNode(nil) | ||||
| 		v, err := expr.evalNode(nil, e.tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -246,7 +252,7 @@ func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) { | ||||
| 	if e.selectAST.Where == nil { | ||||
| 		return true, nil | ||||
| 	} | ||||
| 	value, err := e.selectAST.Where.evalNode(input) | ||||
| 	value, err := e.selectAST.Where.evalNode(input, e.tableAlias) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| @ -272,7 +278,7 @@ func (e *SelectStatement) AggregateRow(input Record) error { | ||||
| 	} | ||||
| 
 | ||||
| 	for _, expr := range e.selectAST.Expression.Expressions { | ||||
| 		err := expr.aggregateRow(input) | ||||
| 		err := expr.aggregateRow(input, e.tableAlias) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @ -302,7 +308,7 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) { | ||||
| 	} | ||||
| 
 | ||||
| 	for i, expr := range e.selectAST.Expression.Expressions { | ||||
| 		v, err := expr.evalNode(input) | ||||
| 		v, err := expr.evalNode(input, e.tableAlias) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| @ -36,6 +36,27 @@ func (e *JSONPath) String() string { | ||||
| 	return e.pathString | ||||
| } | ||||
| 
 | ||||
| // StripTableAlias removes a table alias from the path. The result is also | ||||
| // cached for repeated lookups during SQL query evaluation. | ||||
| func (e *JSONPath) StripTableAlias(tableAlias string) []*JSONPathElement { | ||||
| 	if e.strippedTableAlias == tableAlias { | ||||
| 		return e.strippedPathExpr | ||||
| 	} | ||||
| 
 | ||||
| 	hasTableAlias := e.BaseKey.String() == tableAlias || strings.ToLower(e.BaseKey.String()) == baseTableName | ||||
| 	var pathExpr []*JSONPathElement | ||||
| 	if hasTableAlias { | ||||
| 		pathExpr = e.PathExpr | ||||
| 	} else { | ||||
| 		pathExpr = make([]*JSONPathElement, len(e.PathExpr)+1) | ||||
| 		pathExpr[0] = &JSONPathElement{Key: &ObjectKey{ID: e.BaseKey}} | ||||
| 		copy(pathExpr[1:], e.PathExpr) | ||||
| 	} | ||||
| 	e.strippedTableAlias = tableAlias | ||||
| 	e.strippedPathExpr = pathExpr | ||||
| 	return e.strippedPathExpr | ||||
| } | ||||
| 
 | ||||
| func (e *JSONPathElement) String() string { | ||||
| 	switch { | ||||
| 	case e.Key != nil: | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user