mirror of
https://github.com/minio/minio.git
synced 2025-01-12 15:33:22 -05:00
Fix S3Select SQL column reference handling (#11957)
This change fixes handling of these types of queries: - Double quoted column names with special characters: SELECT "column.name" FROM s3object - Double quoted column names with reserved keywords: SELECT "CAST" FROM s3object - Table name as prefix for column names: SELECT S3Object."CAST" FROM s3object
This commit is contained in:
parent
d5d2fc9850
commit
b2936243f9
@ -739,6 +739,152 @@ func TestCSVQueries2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCSVQueries3(t *testing.T) {
|
||||||
|
input := `na.me,qty,CAST
|
||||||
|
apple,1,true
|
||||||
|
mango,3,false
|
||||||
|
`
|
||||||
|
var testTable = []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
requestXML []byte // override request XML
|
||||||
|
wantResult string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Select a column containing dot",
|
||||||
|
query: `select "na.me" from S3Object s`,
|
||||||
|
wantResult: `apple
|
||||||
|
mango`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select column containing dot with table name prefix",
|
||||||
|
query: `select count(S3Object."na.me") from S3Object`,
|
||||||
|
wantResult: `2`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select column containing dot with table alias prefix",
|
||||||
|
query: `select s."na.me" from S3Object as s`,
|
||||||
|
wantResult: `apple
|
||||||
|
mango`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select column simplest",
|
||||||
|
query: `select qty from S3Object`,
|
||||||
|
wantResult: `1
|
||||||
|
3`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select column with table name prefix",
|
||||||
|
query: `select S3Object.qty from S3Object`,
|
||||||
|
wantResult: `1
|
||||||
|
3`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select column without table alias",
|
||||||
|
query: `select qty from S3Object s`,
|
||||||
|
wantResult: `1
|
||||||
|
3`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select column with table alias",
|
||||||
|
query: `select s.qty from S3Object s`,
|
||||||
|
wantResult: `1
|
||||||
|
3`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select reserved word column",
|
||||||
|
query: `select "CAST" from s3object`,
|
||||||
|
wantResult: `true
|
||||||
|
false`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select reserved word column with table alias",
|
||||||
|
query: `select S3Object."CAST" from s3object`,
|
||||||
|
wantResult: `true
|
||||||
|
false`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select reserved word column with unused table alias",
|
||||||
|
query: `select "CAST" from s3object s`,
|
||||||
|
wantResult: `true
|
||||||
|
false`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select reserved word column with table alias",
|
||||||
|
query: `select s."CAST" from s3object s`,
|
||||||
|
wantResult: `true
|
||||||
|
false`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select reserved word column with table alias",
|
||||||
|
query: `select NOT CAST(s."CAST" AS Bool) from s3object s`,
|
||||||
|
wantResult: `false
|
||||||
|
true`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
defRequest := `<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<SelectObjectContentRequest>
|
||||||
|
<Expression>%s</Expression>
|
||||||
|
<ExpressionType>SQL</ExpressionType>
|
||||||
|
<InputSerialization>
|
||||||
|
<CompressionType>NONE</CompressionType>
|
||||||
|
<CSV>
|
||||||
|
<FileHeaderInfo>USE</FileHeaderInfo>
|
||||||
|
<QuoteCharacter>"</QuoteCharacter>
|
||||||
|
</CSV>
|
||||||
|
</InputSerialization>
|
||||||
|
<OutputSerialization>
|
||||||
|
<CSV/>
|
||||||
|
</OutputSerialization>
|
||||||
|
<RequestProgress>
|
||||||
|
<Enabled>FALSE</Enabled>
|
||||||
|
</RequestProgress>
|
||||||
|
</SelectObjectContentRequest>`
|
||||||
|
|
||||||
|
for _, testCase := range testTable {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
testReq := testCase.requestXML
|
||||||
|
if len(testReq) == 0 {
|
||||||
|
testReq = []byte(fmt.Sprintf(defRequest, testCase.query))
|
||||||
|
}
|
||||||
|
s3Select, err := NewS3Select(bytes.NewReader(testReq))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
|
||||||
|
return ioutil.NopCloser(bytes.NewBufferString(input)), nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := &testResponseWriter{}
|
||||||
|
s3Select.Evaluate(w)
|
||||||
|
s3Select.Close()
|
||||||
|
resp := http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(w.response)),
|
||||||
|
ContentLength: int64(len(w.response)),
|
||||||
|
}
|
||||||
|
res, err := minio.NewSelectResults(&resp, "testbucket")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
got, err := ioutil.ReadAll(res)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gotS := strings.TrimSpace(string(got))
|
||||||
|
if gotS != testCase.wantResult {
|
||||||
|
t.Errorf("received response does not match with expected reply.\nQuery: %s\n=====\ngot: %s\n=====\nwant: %s\n=====\n", testCase.query, gotS, testCase.wantResult)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCSVInput(t *testing.T) {
|
func TestCSVInput(t *testing.T) {
|
||||||
var testTable = []struct {
|
var testTable = []struct {
|
||||||
requestXML []byte
|
requestXML []byte
|
||||||
|
@ -63,7 +63,7 @@ func newAggVal(fn FuncName) *aggVal {
|
|||||||
// current row and stores the result.
|
// current row and stores the result.
|
||||||
//
|
//
|
||||||
// On success, it returns (nil, nil).
|
// On success, it returns (nil, nil).
|
||||||
func (e *FuncExpr) evalAggregationNode(r Record) error {
|
func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
|
||||||
// It is assumed that this function is called only when
|
// It is assumed that this function is called only when
|
||||||
// `e` is an aggregation function.
|
// `e` is an aggregation function.
|
||||||
|
|
||||||
@ -77,13 +77,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err = e.Count.ExprArg.evalNode(r)
|
val, err = e.Count.ExprArg.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Evaluate the (only) argument
|
// Evaluate the (only) argument
|
||||||
val, err = e.SFunc.ArgsList[0].evalNode(r)
|
val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -149,13 +149,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AliasedExpression) aggregateRow(r Record) error {
|
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
|
||||||
return e.Expression.aggregateRow(r)
|
return e.Expression.aggregateRow(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Expression) aggregateRow(r Record) error {
|
func (e *Expression) aggregateRow(r Record, tableAlias string) error {
|
||||||
for _, ex := range e.And {
|
for _, ex := range e.And {
|
||||||
err := ex.aggregateRow(r)
|
err := ex.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -163,9 +163,9 @@ func (e *Expression) aggregateRow(r Record) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ListExpr) aggregateRow(r Record) error {
|
func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
|
||||||
for _, ex := range e.Elements {
|
for _, ex := range e.Elements {
|
||||||
err := ex.aggregateRow(r)
|
err := ex.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -173,9 +173,9 @@ func (e *ListExpr) aggregateRow(r Record) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AndCondition) aggregateRow(r Record) error {
|
func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
|
||||||
for _, ex := range e.Condition {
|
for _, ex := range e.Condition {
|
||||||
err := ex.aggregateRow(r)
|
err := ex.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -183,15 +183,15 @@ func (e *AndCondition) aggregateRow(r Record) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Condition) aggregateRow(r Record) error {
|
func (e *Condition) aggregateRow(r Record, tableAlias string) error {
|
||||||
if e.Operand != nil {
|
if e.Operand != nil {
|
||||||
return e.Operand.aggregateRow(r)
|
return e.Operand.aggregateRow(r, tableAlias)
|
||||||
}
|
}
|
||||||
return e.Not.aggregateRow(r)
|
return e.Not.aggregateRow(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConditionOperand) aggregateRow(r Record) error {
|
func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
|
||||||
err := e.Operand.aggregateRow(r)
|
err := e.Operand.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -202,38 +202,38 @@ func (e *ConditionOperand) aggregateRow(r Record) error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case e.ConditionRHS.Compare != nil:
|
case e.ConditionRHS.Compare != nil:
|
||||||
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
|
return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
|
||||||
case e.ConditionRHS.Between != nil:
|
case e.ConditionRHS.Between != nil:
|
||||||
err = e.ConditionRHS.Between.Start.aggregateRow(r)
|
err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return e.ConditionRHS.Between.End.aggregateRow(r)
|
return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
|
||||||
case e.ConditionRHS.In != nil:
|
case e.ConditionRHS.In != nil:
|
||||||
elt := e.ConditionRHS.In.ListExpression
|
elt := e.ConditionRHS.In.ListExpression
|
||||||
err = elt.aggregateRow(r)
|
err = elt.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case e.ConditionRHS.Like != nil:
|
case e.ConditionRHS.Like != nil:
|
||||||
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
|
err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
|
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
|
||||||
default:
|
default:
|
||||||
return errInvalidASTNode
|
return errInvalidASTNode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Operand) aggregateRow(r Record) error {
|
func (e *Operand) aggregateRow(r Record, tableAlias string) error {
|
||||||
err := e.Left.aggregateRow(r)
|
err := e.Left.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, rt := range e.Right {
|
for _, rt := range e.Right {
|
||||||
err = rt.Right.aggregateRow(r)
|
err = rt.Right.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -241,13 +241,13 @@ func (e *Operand) aggregateRow(r Record) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *MultOp) aggregateRow(r Record) error {
|
func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
|
||||||
err := e.Left.aggregateRow(r)
|
err := e.Left.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, rt := range e.Right {
|
for _, rt := range e.Right {
|
||||||
err = rt.Right.aggregateRow(r)
|
err = rt.Right.aggregateRow(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -255,29 +255,29 @@ func (e *MultOp) aggregateRow(r Record) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *UnaryTerm) aggregateRow(r Record) error {
|
func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
|
||||||
if e.Negated != nil {
|
if e.Negated != nil {
|
||||||
return e.Negated.Term.aggregateRow(r)
|
return e.Negated.Term.aggregateRow(r, tableAlias)
|
||||||
}
|
}
|
||||||
return e.Primary.aggregateRow(r)
|
return e.Primary.aggregateRow(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *PrimaryTerm) aggregateRow(r Record) error {
|
func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
|
||||||
switch {
|
switch {
|
||||||
case e.ListExpr != nil:
|
case e.ListExpr != nil:
|
||||||
return e.ListExpr.aggregateRow(r)
|
return e.ListExpr.aggregateRow(r, tableAlias)
|
||||||
case e.SubExpression != nil:
|
case e.SubExpression != nil:
|
||||||
return e.SubExpression.aggregateRow(r)
|
return e.SubExpression.aggregateRow(r, tableAlias)
|
||||||
case e.FuncCall != nil:
|
case e.FuncCall != nil:
|
||||||
return e.FuncCall.aggregateRow(r)
|
return e.FuncCall.aggregateRow(r, tableAlias)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *FuncExpr) aggregateRow(r Record) error {
|
func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
|
||||||
switch e.getFunctionName() {
|
switch e.getFunctionName() {
|
||||||
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
|
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
|
||||||
return e.evalAggregationNode(r)
|
return e.evalAggregationNode(r, tableAlias)
|
||||||
default:
|
default:
|
||||||
// TODO: traverse arguments and call aggregateRow on
|
// TODO: traverse arguments and call aggregateRow on
|
||||||
// them if they could be an ancestor of an
|
// them if they could be an ancestor of an
|
||||||
|
@ -19,6 +19,7 @@ package sql
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Query analysis - The query is analyzed to determine if it involves
|
// Query analysis - The query is analyzed to determine if it involves
|
||||||
@ -177,7 +178,7 @@ func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
|
|||||||
case e.JPathExpr != nil:
|
case e.JPathExpr != nil:
|
||||||
// Check if the path expression is valid
|
// Check if the path expression is valid
|
||||||
if len(e.JPathExpr.PathExpr) > 0 {
|
if len(e.JPathExpr.PathExpr) > 0 {
|
||||||
if e.JPathExpr.BaseKey.String() != s.From.As {
|
if e.JPathExpr.BaseKey.String() != s.From.As && strings.ToLower(e.JPathExpr.BaseKey.String()) != baseTableName {
|
||||||
result = qProp{err: errInvalidKeypath}
|
result = qProp{err: errInvalidKeypath}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/bcicen/jstream"
|
"github.com/bcicen/jstream"
|
||||||
"github.com/minio/simdjson-go"
|
"github.com/minio/simdjson-go"
|
||||||
@ -47,21 +46,21 @@ var (
|
|||||||
// of child nodes. The final result row is returned after all rows are
|
// of child nodes. The final result row is returned after all rows are
|
||||||
// processed, and the `getAggregate` function is called.
|
// processed, and the `getAggregate` function is called.
|
||||||
|
|
||||||
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
|
func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
return e.Expression.evalNode(r)
|
return e.Expression.evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Expression) evalNode(r Record) (*Value, error) {
|
func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
if len(e.And) == 1 {
|
if len(e.And) == 1 {
|
||||||
// In this case, result is not required to be boolean
|
// In this case, result is not required to be boolean
|
||||||
// type.
|
// type.
|
||||||
return e.And[0].evalNode(r)
|
return e.And[0].evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute OR of conditions
|
// Compute OR of conditions
|
||||||
result := false
|
result := false
|
||||||
for _, ex := range e.And {
|
for _, ex := range e.And {
|
||||||
res, err := ex.evalNode(r)
|
res, err := ex.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -74,16 +73,16 @@ func (e *Expression) evalNode(r Record) (*Value, error) {
|
|||||||
return FromBool(result), nil
|
return FromBool(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AndCondition) evalNode(r Record) (*Value, error) {
|
func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
if len(e.Condition) == 1 {
|
if len(e.Condition) == 1 {
|
||||||
// In this case, result does not have to be boolean
|
// In this case, result does not have to be boolean
|
||||||
return e.Condition[0].evalNode(r)
|
return e.Condition[0].evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute AND of conditions
|
// Compute AND of conditions
|
||||||
result := true
|
result := true
|
||||||
for _, ex := range e.Condition {
|
for _, ex := range e.Condition {
|
||||||
res, err := ex.evalNode(r)
|
res, err := ex.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -96,14 +95,14 @@ func (e *AndCondition) evalNode(r Record) (*Value, error) {
|
|||||||
return FromBool(result), nil
|
return FromBool(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Condition) evalNode(r Record) (*Value, error) {
|
func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
if e.Operand != nil {
|
if e.Operand != nil {
|
||||||
// In this case, result does not have to be boolean
|
// In this case, result does not have to be boolean
|
||||||
return e.Operand.evalNode(r)
|
return e.Operand.evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute NOT of condition
|
// Compute NOT of condition
|
||||||
res, err := e.Not.evalNode(r)
|
res, err := e.Not.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -114,8 +113,8 @@ func (e *Condition) evalNode(r Record) (*Value, error) {
|
|||||||
return FromBool(!b), nil
|
return FromBool(!b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
|
func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
opVal, opErr := e.Operand.evalNode(r)
|
opVal, opErr := e.Operand.evalNode(r, tableAlias)
|
||||||
if opErr != nil || e.ConditionRHS == nil {
|
if opErr != nil || e.ConditionRHS == nil {
|
||||||
return opVal, opErr
|
return opVal, opErr
|
||||||
}
|
}
|
||||||
@ -123,7 +122,7 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
|
|||||||
// Need to evaluate the ConditionRHS
|
// Need to evaluate the ConditionRHS
|
||||||
switch {
|
switch {
|
||||||
case e.ConditionRHS.Compare != nil:
|
case e.ConditionRHS.Compare != nil:
|
||||||
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
|
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias)
|
||||||
if cmpRErr != nil {
|
if cmpRErr != nil {
|
||||||
return nil, cmpRErr
|
return nil, cmpRErr
|
||||||
}
|
}
|
||||||
@ -132,26 +131,26 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
|
|||||||
return FromBool(b), err
|
return FromBool(b), err
|
||||||
|
|
||||||
case e.ConditionRHS.Between != nil:
|
case e.ConditionRHS.Between != nil:
|
||||||
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
|
return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias)
|
||||||
|
|
||||||
case e.ConditionRHS.Like != nil:
|
case e.ConditionRHS.Like != nil:
|
||||||
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
|
return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias)
|
||||||
|
|
||||||
case e.ConditionRHS.In != nil:
|
case e.ConditionRHS.In != nil:
|
||||||
return e.ConditionRHS.In.evalInNode(r, opVal)
|
return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, errInvalidASTNode
|
return nil, errInvalidASTNode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
|
func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) {
|
||||||
stVal, stErr := e.Start.evalNode(r)
|
stVal, stErr := e.Start.evalNode(r, tableAlias)
|
||||||
if stErr != nil {
|
if stErr != nil {
|
||||||
return nil, stErr
|
return nil, stErr
|
||||||
}
|
}
|
||||||
|
|
||||||
endVal, endErr := e.End.evalNode(r)
|
endVal, endErr := e.End.evalNode(r, tableAlias)
|
||||||
if endErr != nil {
|
if endErr != nil {
|
||||||
return nil, endErr
|
return nil, endErr
|
||||||
}
|
}
|
||||||
@ -174,7 +173,7 @@ func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
|
|||||||
return FromBool(result), nil
|
return FromBool(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
|
func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) {
|
||||||
inferTypeAsString(arg)
|
inferTypeAsString(arg)
|
||||||
|
|
||||||
s, ok := arg.ToString()
|
s, ok := arg.ToString()
|
||||||
@ -183,7 +182,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
|
|||||||
return nil, errLikeInvalidInputs(err)
|
return nil, errLikeInvalidInputs(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pattern, err1 := e.Pattern.evalNode(r)
|
pattern, err1 := e.Pattern.evalNode(r, tableAlias)
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return nil, err1
|
return nil, err1
|
||||||
}
|
}
|
||||||
@ -199,7 +198,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
|
|||||||
|
|
||||||
escape := runeZero
|
escape := runeZero
|
||||||
if e.EscapeChar != nil {
|
if e.EscapeChar != nil {
|
||||||
escapeVal, err2 := e.EscapeChar.evalNode(r)
|
escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
return nil, err2
|
return nil, err2
|
||||||
}
|
}
|
||||||
@ -230,14 +229,14 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
|
|||||||
return FromBool(matchResult), nil
|
return FromBool(matchResult), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ListExpr) evalNode(r Record) (*Value, error) {
|
func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
res := make([]Value, len(e.Elements))
|
res := make([]Value, len(e.Elements))
|
||||||
if len(e.Elements) == 1 {
|
if len(e.Elements) == 1 {
|
||||||
// If length 1, treat as single value.
|
// If length 1, treat as single value.
|
||||||
return e.Elements[0].evalNode(r)
|
return e.Elements[0].evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
for i, elt := range e.Elements {
|
for i, elt := range e.Elements {
|
||||||
v, err := elt.evalNode(r)
|
v, err := elt.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -248,7 +247,7 @@ func (e *ListExpr) evalNode(r Record) (*Value, error) {
|
|||||||
|
|
||||||
const floatCmpTolerance = 0.000001
|
const floatCmpTolerance = 0.000001
|
||||||
|
|
||||||
func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
|
func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) {
|
||||||
// Compare two values in terms of in-ness.
|
// Compare two values in terms of in-ness.
|
||||||
var cmp func(a, b Value) bool
|
var cmp func(a, b Value) bool
|
||||||
cmp = func(a, b Value) bool {
|
cmp = func(a, b Value) bool {
|
||||||
@ -283,7 +282,7 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
|
|||||||
|
|
||||||
var rhs Value
|
var rhs Value
|
||||||
if elt := e.ListExpression; elt != nil {
|
if elt := e.ListExpression; elt != nil {
|
||||||
eltVal, err := elt.evalNode(r)
|
eltVal, err := elt.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -304,8 +303,8 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
|
|||||||
return FromBool(cmp(rhs, *lhs)), nil
|
return FromBool(cmp(rhs, *lhs)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Operand) evalNode(r Record) (*Value, error) {
|
func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
lval, lerr := e.Left.evalNode(r)
|
lval, lerr := e.Left.evalNode(r, tableAlias)
|
||||||
if lerr != nil || len(e.Right) == 0 {
|
if lerr != nil || len(e.Right) == 0 {
|
||||||
return lval, lerr
|
return lval, lerr
|
||||||
}
|
}
|
||||||
@ -315,7 +314,7 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
|
|||||||
// symbols.
|
// symbols.
|
||||||
for _, rightTerm := range e.Right {
|
for _, rightTerm := range e.Right {
|
||||||
op := rightTerm.Op
|
op := rightTerm.Op
|
||||||
rval, rerr := rightTerm.Right.evalNode(r)
|
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, rerr
|
return nil, rerr
|
||||||
}
|
}
|
||||||
@ -327,8 +326,8 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
|
|||||||
return lval, nil
|
return lval, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *MultOp) evalNode(r Record) (*Value, error) {
|
func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
lval, lerr := e.Left.evalNode(r)
|
lval, lerr := e.Left.evalNode(r, tableAlias)
|
||||||
if lerr != nil || len(e.Right) == 0 {
|
if lerr != nil || len(e.Right) == 0 {
|
||||||
return lval, lerr
|
return lval, lerr
|
||||||
}
|
}
|
||||||
@ -337,7 +336,7 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
|
|||||||
// AST node is for terms separated by *, / or % symbols.
|
// AST node is for terms separated by *, / or % symbols.
|
||||||
for _, rightTerm := range e.Right {
|
for _, rightTerm := range e.Right {
|
||||||
op := rightTerm.Op
|
op := rightTerm.Op
|
||||||
rval, rerr := rightTerm.Right.evalNode(r)
|
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, rerr
|
return nil, rerr
|
||||||
}
|
}
|
||||||
@ -350,12 +349,12 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
|
|||||||
return lval, nil
|
return lval, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
|
func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
if e.Negated == nil {
|
if e.Negated == nil {
|
||||||
return e.Primary.evalNode(r)
|
return e.Primary.evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := e.Negated.Term.evalNode(r)
|
v, err := e.Negated.Term.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -368,19 +367,15 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
|
|||||||
return nil, errArithMismatchedTypes
|
return nil, errArithMismatchedTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *JSONPath) evalNode(r Record) (*Value, error) {
|
func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) {
|
||||||
// Strip the table name from the keypath.
|
alias := tableAlias
|
||||||
keypath := e.String()
|
if tableAlias == "" {
|
||||||
if strings.Contains(keypath, ".") {
|
alias = baseTableName
|
||||||
ps := strings.SplitN(keypath, ".", 2)
|
|
||||||
if len(ps) == 2 {
|
|
||||||
keypath = ps[1]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
pathExpr := e.StripTableAlias(alias)
|
||||||
_, rawVal := r.Raw()
|
_, rawVal := r.Raw()
|
||||||
switch rowVal := rawVal.(type) {
|
switch rowVal := rawVal.(type) {
|
||||||
case jstream.KVS, simdjson.Object:
|
case jstream.KVS, simdjson.Object:
|
||||||
pathExpr := e.PathExpr
|
|
||||||
if len(pathExpr) == 0 {
|
if len(pathExpr) == 0 {
|
||||||
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
|
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
|
||||||
}
|
}
|
||||||
@ -392,7 +387,10 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) {
|
|||||||
|
|
||||||
return jsonToValue(result)
|
return jsonToValue(result)
|
||||||
default:
|
default:
|
||||||
return r.Get(keypath)
|
if pathExpr[len(pathExpr)-1].Key == nil {
|
||||||
|
return nil, errInvalidKeypath
|
||||||
|
}
|
||||||
|
return r.Get(pathExpr[len(pathExpr)-1].Key.keyString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -447,28 +445,28 @@ func jsonToValue(result interface{}) (*Value, error) {
|
|||||||
return nil, fmt.Errorf("Unhandled value type: %T", result)
|
return nil, fmt.Errorf("Unhandled value type: %T", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
|
func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) {
|
||||||
switch {
|
switch {
|
||||||
case e.Value != nil:
|
case e.Value != nil:
|
||||||
return e.Value.evalNode(r)
|
return e.Value.evalNode(r)
|
||||||
case e.JPathExpr != nil:
|
case e.JPathExpr != nil:
|
||||||
return e.JPathExpr.evalNode(r)
|
return e.JPathExpr.evalNode(r, tableAlias)
|
||||||
case e.ListExpr != nil:
|
case e.ListExpr != nil:
|
||||||
return e.ListExpr.evalNode(r)
|
return e.ListExpr.evalNode(r, tableAlias)
|
||||||
case e.SubExpression != nil:
|
case e.SubExpression != nil:
|
||||||
return e.SubExpression.evalNode(r)
|
return e.SubExpression.evalNode(r, tableAlias)
|
||||||
case e.FuncCall != nil:
|
case e.FuncCall != nil:
|
||||||
return e.FuncCall.evalNode(r)
|
return e.FuncCall.evalNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
return nil, errInvalidASTNode
|
return nil, errInvalidASTNode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
|
func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) {
|
||||||
switch e.getFunctionName() {
|
switch e.getFunctionName() {
|
||||||
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
|
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
|
||||||
return e.getAggregate()
|
return e.getAggregate()
|
||||||
default:
|
default:
|
||||||
return e.evalSQLFnNode(r)
|
return e.evalSQLFnNode(r, tableAlias)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,35 +84,35 @@ func (e *FuncExpr) getFunctionName() FuncName {
|
|||||||
|
|
||||||
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
|
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
|
||||||
// function.
|
// function.
|
||||||
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
|
func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) {
|
||||||
// Handle functions that have phrase arguments
|
// Handle functions that have phrase arguments
|
||||||
switch e.getFunctionName() {
|
switch e.getFunctionName() {
|
||||||
case sqlFnCast:
|
case sqlFnCast:
|
||||||
expr := e.Cast.Expr
|
expr := e.Cast.Expr
|
||||||
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType))
|
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias)
|
||||||
return
|
return
|
||||||
|
|
||||||
case sqlFnSubstring:
|
case sqlFnSubstring:
|
||||||
return handleSQLSubstring(r, e.Substring)
|
return handleSQLSubstring(r, e.Substring, tableAlias)
|
||||||
|
|
||||||
case sqlFnExtract:
|
case sqlFnExtract:
|
||||||
return handleSQLExtract(r, e.Extract)
|
return handleSQLExtract(r, e.Extract, tableAlias)
|
||||||
|
|
||||||
case sqlFnTrim:
|
case sqlFnTrim:
|
||||||
return handleSQLTrim(r, e.Trim)
|
return handleSQLTrim(r, e.Trim, tableAlias)
|
||||||
|
|
||||||
case sqlFnDateAdd:
|
case sqlFnDateAdd:
|
||||||
return handleDateAdd(r, e.DateAdd)
|
return handleDateAdd(r, e.DateAdd, tableAlias)
|
||||||
|
|
||||||
case sqlFnDateDiff:
|
case sqlFnDateDiff:
|
||||||
return handleDateDiff(r, e.DateDiff)
|
return handleDateDiff(r, e.DateDiff, tableAlias)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all simple argument functions, we evaluate the arguments here
|
// For all simple argument functions, we evaluate the arguments here
|
||||||
argVals := make([]*Value, len(e.SFunc.ArgsList))
|
argVals := make([]*Value, len(e.SFunc.ArgsList))
|
||||||
for i, arg := range e.SFunc.ArgsList {
|
for i, arg := range e.SFunc.ArgsList {
|
||||||
argVals[i], err = arg.evalNode(r)
|
argVals[i], err = arg.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -219,8 +219,8 @@ func upperCase(v *Value) (*Value, error) {
|
|||||||
return FromString(strings.ToUpper(s)), nil
|
return FromString(strings.ToUpper(s)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
|
func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) {
|
||||||
q, err := d.Quantity.evalNode(r)
|
q, err := d.Quantity.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -230,7 +230,7 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
|
|||||||
return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
|
return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
|
||||||
}
|
}
|
||||||
|
|
||||||
ts, err := d.Timestamp.evalNode(r)
|
ts, err := d.Timestamp.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -245,8 +245,8 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
|
|||||||
return dateAdd(strings.ToUpper(d.DatePart), qty, t)
|
return dateAdd(strings.ToUpper(d.DatePart), qty, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) {
|
func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) {
|
||||||
tval1, err := d.Timestamp1.evalNode(r)
|
tval1, err := d.Timestamp1.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -258,7 +258,7 @@ func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) {
|
|||||||
return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
|
return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
|
||||||
}
|
}
|
||||||
|
|
||||||
tval2, err := d.Timestamp2.evalNode(r)
|
tval2, err := d.Timestamp2.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -277,12 +277,12 @@ func handleUTCNow() (*Value, error) {
|
|||||||
return FromTimestamp(time.Now().UTC()), nil
|
return FromTimestamp(time.Now().UTC()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
|
func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) {
|
||||||
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
|
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
|
||||||
// SUBSTRING('abc', 2, 1) are supported.
|
// SUBSTRING('abc', 2, 1) are supported.
|
||||||
|
|
||||||
// Evaluate the string argument
|
// Evaluate the string argument
|
||||||
v1, err := e.Expr.evalNode(r)
|
v1, err := e.Expr.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -301,7 +301,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Evaluate the FROM argument
|
// Evaluate the FROM argument
|
||||||
v2, err := arg2.evalNode(r)
|
v2, err := arg2.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -315,7 +315,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
|
|||||||
length := -1
|
length := -1
|
||||||
// Evaluate the optional FOR argument
|
// Evaluate the optional FOR argument
|
||||||
if arg3 != nil {
|
if arg3 != nil {
|
||||||
v3, err := arg3.evalNode(r)
|
v3, err := arg3.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -336,11 +336,11 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
|
|||||||
return FromString(res), err
|
return FromString(res), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
|
func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) {
|
||||||
chars := ""
|
chars := ""
|
||||||
ok := false
|
ok := false
|
||||||
if e.TrimChars != nil {
|
if e.TrimChars != nil {
|
||||||
charsV, cerr := e.TrimChars.evalNode(r)
|
charsV, cerr := e.TrimChars.evalNode(r, tableAlias)
|
||||||
if cerr != nil {
|
if cerr != nil {
|
||||||
return nil, cerr
|
return nil, cerr
|
||||||
}
|
}
|
||||||
@ -351,7 +351,7 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fromV, ferr := e.TrimFrom.evalNode(r)
|
fromV, ferr := e.TrimFrom.evalNode(r, tableAlias)
|
||||||
if ferr != nil {
|
if ferr != nil {
|
||||||
return nil, ferr
|
return nil, ferr
|
||||||
}
|
}
|
||||||
@ -368,8 +368,8 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
|
|||||||
return FromString(result), nil
|
return FromString(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSQLExtract(r Record, e *ExtractFunc) (res *Value, err error) {
|
func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) {
|
||||||
timeVal, verr := e.From.evalNode(r)
|
timeVal, verr := e.From.evalNode(r, tableAlias)
|
||||||
if verr != nil {
|
if verr != nil {
|
||||||
return nil, verr
|
return nil, verr
|
||||||
}
|
}
|
||||||
@ -406,8 +406,8 @@ const (
|
|||||||
castTimestamp = "TIMESTAMP"
|
castTimestamp = "TIMESTAMP"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
|
func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) {
|
||||||
v, err := e.evalNode(r)
|
v, err := e.evalNode(r, tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -120,6 +120,8 @@ type JSONPath struct {
|
|||||||
|
|
||||||
// Cached values:
|
// Cached values:
|
||||||
pathString string
|
pathString string
|
||||||
|
strippedTableAlias string
|
||||||
|
strippedPathExpr []*JSONPathElement
|
||||||
}
|
}
|
||||||
|
|
||||||
// AliasedExpression is an expression that can be optionally named
|
// AliasedExpression is an expression that can be optionally named
|
||||||
|
@ -46,6 +46,9 @@ type SelectStatement struct {
|
|||||||
|
|
||||||
// Count of rows that have been output.
|
// Count of rows that have been output.
|
||||||
outputCount int64
|
outputCount int64
|
||||||
|
|
||||||
|
// Table alias
|
||||||
|
tableAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseSelectStatement - parses a select query from the given string
|
// ParseSelectStatement - parses a select query from the given string
|
||||||
@ -107,6 +110,9 @@ func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
err = errQueryAnalysisFailure(err)
|
err = errQueryAnalysisFailure(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set table alias
|
||||||
|
stmt.tableAlias = selectAST.From.As
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,7 +232,7 @@ func (e *SelectStatement) IsAggregated() bool {
|
|||||||
// records have been processed. Applies only to aggregation queries.
|
// records have been processed. Applies only to aggregation queries.
|
||||||
func (e *SelectStatement) AggregateResult(output Record) error {
|
func (e *SelectStatement) AggregateResult(output Record) error {
|
||||||
for i, expr := range e.selectAST.Expression.Expressions {
|
for i, expr := range e.selectAST.Expression.Expressions {
|
||||||
v, err := expr.evalNode(nil)
|
v, err := expr.evalNode(nil, e.tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -246,7 +252,7 @@ func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
|
|||||||
if e.selectAST.Where == nil {
|
if e.selectAST.Where == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
value, err := e.selectAST.Where.evalNode(input)
|
value, err := e.selectAST.Where.evalNode(input, e.tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -272,7 +278,7 @@ func (e *SelectStatement) AggregateRow(input Record) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, expr := range e.selectAST.Expression.Expressions {
|
for _, expr := range e.selectAST.Expression.Expressions {
|
||||||
err := expr.aggregateRow(input)
|
err := expr.aggregateRow(input, e.tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -302,7 +308,7 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, expr := range e.selectAST.Expression.Expressions {
|
for i, expr := range e.selectAST.Expression.Expressions {
|
||||||
v, err := expr.evalNode(input)
|
v, err := expr.evalNode(input, e.tableAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,27 @@ func (e *JSONPath) String() string {
|
|||||||
return e.pathString
|
return e.pathString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StripTableAlias removes a table alias from the path. The result is also
|
||||||
|
// cached for repeated lookups during SQL query evaluation.
|
||||||
|
func (e *JSONPath) StripTableAlias(tableAlias string) []*JSONPathElement {
|
||||||
|
if e.strippedTableAlias == tableAlias {
|
||||||
|
return e.strippedPathExpr
|
||||||
|
}
|
||||||
|
|
||||||
|
hasTableAlias := e.BaseKey.String() == tableAlias || strings.ToLower(e.BaseKey.String()) == baseTableName
|
||||||
|
var pathExpr []*JSONPathElement
|
||||||
|
if hasTableAlias {
|
||||||
|
pathExpr = e.PathExpr
|
||||||
|
} else {
|
||||||
|
pathExpr = make([]*JSONPathElement, len(e.PathExpr)+1)
|
||||||
|
pathExpr[0] = &JSONPathElement{Key: &ObjectKey{ID: e.BaseKey}}
|
||||||
|
copy(pathExpr[1:], e.PathExpr)
|
||||||
|
}
|
||||||
|
e.strippedTableAlias = tableAlias
|
||||||
|
e.strippedPathExpr = pathExpr
|
||||||
|
return e.strippedPathExpr
|
||||||
|
}
|
||||||
|
|
||||||
func (e *JSONPathElement) String() string {
|
func (e *JSONPathElement) String() string {
|
||||||
switch {
|
switch {
|
||||||
case e.Key != nil:
|
case e.Key != nil:
|
||||||
|
Loading…
Reference in New Issue
Block a user