Fix S3Select SQL column reference handling (#11957)

This change fixes handling of these types of queries:

- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
This commit is contained in:
Aditya Manthramurthy
2021-04-06 08:49:04 -07:00
committed by GitHub
parent d5d2fc9850
commit b2936243f9
8 changed files with 297 additions and 123 deletions

View File

@@ -63,7 +63,7 @@ func newAggVal(fn FuncName) *aggVal {
// current row and stores the result.
//
// On success, it returns (nil, nil).
func (e *FuncExpr) evalAggregationNode(r Record) error {
func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
// It is assumed that this function is called only when
// `e` is an aggregation function.
@@ -77,13 +77,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
return nil
}
val, err = e.Count.ExprArg.evalNode(r)
val, err = e.Count.ExprArg.evalNode(r, tableAlias)
if err != nil {
return err
}
} else {
// Evaluate the (only) argument
val, err = e.SFunc.ArgsList[0].evalNode(r)
val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
if err != nil {
return err
}
@@ -149,13 +149,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
return err
}
func (e *AliasedExpression) aggregateRow(r Record) error {
return e.Expression.aggregateRow(r)
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
return e.Expression.aggregateRow(r, tableAlias)
}
func (e *Expression) aggregateRow(r Record) error {
func (e *Expression) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.And {
err := ex.aggregateRow(r)
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@@ -163,9 +163,9 @@ func (e *Expression) aggregateRow(r Record) error {
return nil
}
func (e *ListExpr) aggregateRow(r Record) error {
func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Elements {
err := ex.aggregateRow(r)
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@@ -173,9 +173,9 @@ func (e *ListExpr) aggregateRow(r Record) error {
return nil
}
func (e *AndCondition) aggregateRow(r Record) error {
func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Condition {
err := ex.aggregateRow(r)
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@@ -183,15 +183,15 @@ func (e *AndCondition) aggregateRow(r Record) error {
return nil
}
func (e *Condition) aggregateRow(r Record) error {
func (e *Condition) aggregateRow(r Record, tableAlias string) error {
if e.Operand != nil {
return e.Operand.aggregateRow(r)
return e.Operand.aggregateRow(r, tableAlias)
}
return e.Not.aggregateRow(r)
return e.Not.aggregateRow(r, tableAlias)
}
func (e *ConditionOperand) aggregateRow(r Record) error {
err := e.Operand.aggregateRow(r)
func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
err := e.Operand.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@@ -202,38 +202,38 @@ func (e *ConditionOperand) aggregateRow(r Record) error {
switch {
case e.ConditionRHS.Compare != nil:
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
case e.ConditionRHS.Between != nil:
err = e.ConditionRHS.Between.Start.aggregateRow(r)
err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return e.ConditionRHS.Between.End.aggregateRow(r)
return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
case e.ConditionRHS.In != nil:
elt := e.ConditionRHS.In.ListExpression
err = elt.aggregateRow(r)
err = elt.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return nil
case e.ConditionRHS.Like != nil:
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
default:
return errInvalidASTNode
}
}
func (e *Operand) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
func (e *Operand) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r, tableAlias)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
err = rt.Right.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@@ -241,13 +241,13 @@ func (e *Operand) aggregateRow(r Record) error {
return nil
}
func (e *MultOp) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r, tableAlias)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
err = rt.Right.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@@ -255,29 +255,29 @@ func (e *MultOp) aggregateRow(r Record) error {
return nil
}
func (e *UnaryTerm) aggregateRow(r Record) error {
func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
if e.Negated != nil {
return e.Negated.Term.aggregateRow(r)
return e.Negated.Term.aggregateRow(r, tableAlias)
}
return e.Primary.aggregateRow(r)
return e.Primary.aggregateRow(r, tableAlias)
}
func (e *PrimaryTerm) aggregateRow(r Record) error {
func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
switch {
case e.ListExpr != nil:
return e.ListExpr.aggregateRow(r)
return e.ListExpr.aggregateRow(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.aggregateRow(r)
return e.SubExpression.aggregateRow(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.aggregateRow(r)
return e.FuncCall.aggregateRow(r, tableAlias)
}
return nil
}
func (e *FuncExpr) aggregateRow(r Record) error {
func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
switch e.getFunctionName() {
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
return e.evalAggregationNode(r)
return e.evalAggregationNode(r, tableAlias)
default:
// TODO: traverse arguments and call aggregateRow on
// them if they could be an ancestor of an