minio/internal/s3select/sql/aggregation.go
Aditya Manthramurthy 496027b589
Fix precedence bug in S3Select SQL IN clauses (#18708)
Fixes a precedence issue in SQL Select where `a in b and c = 3` was parsed as `a in (b and c = 3)`.

Fixes #18682
2023-12-22 23:19:11 -08:00

330 lines
7.9 KiB
Go

// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"fmt"
)
// Aggregation Function name constants. These are the function names that
// FuncExpr.aggregateRow dispatches to evalAggregationNode.
const (
	aggFnAvg   FuncName = "AVG"
	aggFnCount FuncName = "COUNT"
	aggFnMax   FuncName = "MAX"
	aggFnMin   FuncName = "MIN"
	aggFnSum   FuncName = "SUM"
)

var (
	// errNonNumericArg builds the error reported when an aggregation
	// other than COUNT receives a value that cannot be interpreted as a
	// number.
	errNonNumericArg = func(fnStr FuncName) error {
		return fmt.Errorf("%s() requires a numeric argument", fnStr)
	}
	// errInvalidAggregation is returned when an aggregation routine is
	// reached with a function name it does not handle.
	errInvalidAggregation = errors.New("Invalid aggregation seen")
)
// aggVal is the running state of one aggregate function while input rows
// are being consumed. Which fields are meaningful depends on the function
// (see newAggVal).
type aggVal struct {
	// runningSum accumulates the total for AVG and SUM.
	runningSum *Value
	// runningCount is incremented by COUNT, and by AVG/SUM for each
	// value they add.
	runningCount int64
	// runningMax holds the current maximum for MAX.
	runningMax *Value
	// runningMin holds the current minimum for MIN.
	runningMin *Value
	// seen is set once at least one non-null value has been processed.
	seen bool
}
// newAggVal returns fresh running state for the aggregate function fn.
// Only the accumulator that fn uses is initialized: a float zero sum for
// AVG/SUM, or an int zero for MIN/MAX. Any other name yields an all-zero
// aggVal.
func newAggVal(fn FuncName) *aggVal {
	agg := &aggVal{}
	switch fn {
	case aggFnAvg, aggFnSum:
		agg.runningSum = FromFloat(0)
	case aggFnMin:
		agg.runningMin = FromInt(0)
	case aggFnMax:
		agg.runningMax = FromInt(0)
	}
	return agg
}
// evalAggregationNode performs the partial computation of the aggregate
// function e for the current row and stores the result in e.aggregate.
//
// It is assumed that this function is called only when e is an
// aggregation function (AVG, COUNT, MAX, MIN or SUM). A null argument
// value leaves the aggregate unchanged. Every function except COUNT
// requires a numeric argument; untyped values are converted to an int
// or float when possible.
//
// It returns nil on success. It returns an error if evaluating the
// argument fails, if a non-COUNT function gets a non-numeric value, if
// the function has no argument, or if it is not a supported aggregation.
func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
	funcName := e.getFunctionName()

	var val *Value
	var err error
	if funcName == aggFnCount {
		if e.Count.StarArg {
			// Handle COUNT(*): every row is counted, null or not.
			e.aggregate.runningCount++
			return nil
		}
		val, err = e.Count.ExprArg.evalNode(r, tableAlias)
	} else {
		// Evaluate the (only) argument. Guard the index so a call
		// without arguments reports an error instead of panicking.
		if len(e.SFunc.ArgsList) == 0 {
			return errInvalidAggregation
		}
		val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
	}
	if err != nil {
		return err
	}

	if val.IsNull() {
		// E.g. the column or field does not exist in the record - in
		// all such cases the aggregation is not updated.
		return nil
	}

	if funcName != aggFnCount && !val.isNumeric() {
		// All aggregation functions except COUNT require a numeric
		// argument. Here, we diverge from Amazon S3 behavior by
		// inferring untyped values are numbers.
		if i, ok := val.bytesToInt(); ok {
			val.setInt(i)
		} else if f, ok := val.bytesToFloat(); ok {
			val.setFloat(f)
		} else {
			return errNonNumericArg(funcName)
		}
	}

	// Mark that we have seen one non-null value. isFirstRow is passed to
	// minmax so MIN/MAX can treat the first value specially.
	isFirstRow := !e.aggregate.seen
	e.aggregate.seen = true

	switch funcName {
	case aggFnCount:
		// For all non-null values, the count is incremented.
		e.aggregate.runningCount++
		return nil
	case aggFnAvg, aggFnSum:
		e.aggregate.runningCount++
		// Accumulate the sum as a float.
		f, ok := val.ToFloat()
		if !ok {
			return fmt.Errorf("Could not convert value %v (%s) to a number", val.value, val.GetTypeString())
		}
		val.setFloat(f)
		return e.aggregate.runningSum.arithOp(opPlus, val)
	case aggFnMin:
		return e.aggregate.runningMin.minmax(val, false, isFirstRow)
	case aggFnMax:
		return e.aggregate.runningMax.minmax(val, true, isFirstRow)
	default:
		return errInvalidAggregation
	}
}
// aggregateRow forwards the row to the wrapped expression; the alias
// itself takes no part in aggregation.
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
	inner := e.Expression
	return inner.aggregateRow(r, tableAlias)
}
// aggregateRow feeds the row to each element of e.And in order, stopping
// at and returning the first error.
func (e *Expression) aggregateRow(r Record, tableAlias string) error {
	for i := range e.And {
		if err := e.And[i].aggregateRow(r, tableAlias); err != nil {
			return err
		}
	}
	return nil
}
// aggregateRow feeds the row to every list element in order, returning
// the first error encountered.
func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
	for idx := 0; idx < len(e.Elements); idx++ {
		if err := e.Elements[idx].aggregateRow(r, tableAlias); err != nil {
			return err
		}
	}
	return nil
}
// aggregateRow feeds the row to each condition in e.Condition. Traversal
// stops at the first failing condition, and that error is returned.
func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
	var err error
	for _, cond := range e.Condition {
		if err = cond.aggregateRow(r, tableAlias); err != nil {
			break
		}
	}
	return err
}
// aggregateRow descends into the condition's operand when present, and
// otherwise into the condition held in e.Not.
func (e *Condition) aggregateRow(r Record, tableAlias string) error {
	switch {
	case e.Operand != nil:
		return e.Operand.aggregateRow(r, tableAlias)
	default:
		return e.Not.aggregateRow(r, tableAlias)
	}
}
// aggregateRow feeds the row to the left operand and then to every
// operand on the condition's right-hand side: a comparison, BETWEEN
// bounds, an IN list, or a LIKE pattern and escape. It returns the first
// error encountered, or errInvalidASTNode for an unrecognized RHS.
func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
	if err := e.Operand.aggregateRow(r, tableAlias); err != nil {
		return err
	}

	rhs := e.ConditionRHS
	if rhs == nil {
		return nil
	}

	switch {
	case rhs.Compare != nil:
		return rhs.Compare.Operand.aggregateRow(r, tableAlias)
	case rhs.Between != nil:
		if err := rhs.Between.Start.aggregateRow(r, tableAlias); err != nil {
			return err
		}
		return rhs.Between.End.aggregateRow(r, tableAlias)
	case rhs.In != nil:
		// Only a literal list is traversed; any other IN form is
		// treated as holding nothing to aggregate.
		if rhs.In.ListExpr != nil {
			return rhs.In.ListExpr.aggregateRow(r, tableAlias)
		}
		return nil
	case rhs.Like != nil:
		if err := rhs.Like.Pattern.aggregateRow(r, tableAlias); err != nil {
			return err
		}
		// EscapeChar is presumably nil when the query has no ESCAPE
		// clause (confirm against the parser). (*Operand).aggregateRow
		// dereferences its receiver, so a nil escape must be skipped
		// rather than traversed, which would panic.
		if rhs.Like.EscapeChar == nil {
			return nil
		}
		return rhs.Like.EscapeChar.aggregateRow(r, tableAlias)
	default:
		return errInvalidASTNode
	}
}
// aggregateRow feeds the row to the leftmost term, then to the Right
// term of each entry in e.Right, stopping at the first error.
func (e *Operand) aggregateRow(r Record, tableAlias string) error {
	if err := e.Left.aggregateRow(r, tableAlias); err != nil {
		return err
	}
	for i := range e.Right {
		if err := e.Right[i].Right.aggregateRow(r, tableAlias); err != nil {
			return err
		}
	}
	return nil
}
// aggregateRow feeds the row to the left term, then to the Right term of
// each entry in e.Right. The loop ends as soon as any call fails, and
// that error is returned.
func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
	err := e.Left.aggregateRow(r, tableAlias)
	for i := 0; err == nil && i < len(e.Right); i++ {
		err = e.Right[i].Right.aggregateRow(r, tableAlias)
	}
	return err
}
// aggregateRow descends into the primary term, or into the inner term
// when the term is negated.
func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
	if e.Negated == nil {
		return e.Primary.aggregateRow(r, tableAlias)
	}
	return e.Negated.Term.aggregateRow(r, tableAlias)
}
// aggregateRow descends into whichever of ListExpr, SubExpression or
// FuncCall is set, checked in that order. Any other kind of primary term
// has nothing to aggregate, so nil is returned.
func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
	if e.ListExpr != nil {
		return e.ListExpr.aggregateRow(r, tableAlias)
	}
	if e.SubExpression != nil {
		return e.SubExpression.aggregateRow(r, tableAlias)
	}
	if e.FuncCall != nil {
		return e.FuncCall.aggregateRow(r, tableAlias)
	}
	return nil
}
// aggregateRow updates the running aggregate when e is itself one of the
// aggregation functions. Arguments of other functions are not traversed,
// so an aggregate nested inside such a call is not updated here.
func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
	name := e.getFunctionName()
	isAggregate := name == aggFnAvg || name == aggFnSum ||
		name == aggFnMax || name == aggFnMin || name == aggFnCount
	if isAggregate {
		return e.evalAggregationNode(r, tableAlias)
	}
	// TODO: traverse arguments and call aggregateRow on them if they
	// could be an ancestor of an aggregation.
	return nil
}
// getAggregate returns the final result of the aggregate function e. It
// is meant to be called after aggregateRow() has been invoked on every
// input row.
//
// Results by function:
//   - COUNT: the running count.
//   - AVG: NULL when nothing was counted; otherwise the mean. The running
//     sum is divided in place and returned.
//   - MIN, MAX: NULL when no non-null value was seen.
//   - SUM: the running sum as-is.
//
// Any other function name yields errInvalidAggregation.
func (e *FuncExpr) getAggregate() (*Value, error) {
	agg := e.aggregate
	fn := e.getFunctionName()

	switch fn {
	case aggFnCount:
		return FromInt(agg.runningCount), nil
	case aggFnSum:
		// TODO: check if returning 0 when no rows were seen
		// by SUM is expected behavior.
		return agg.runningSum, nil
	case aggFnAvg:
		if agg.runningCount == 0 {
			// No rows were seen by AVG.
			return FromNull(), nil
		}
		err := agg.runningSum.arithOp(opDivide, FromInt(agg.runningCount))
		return agg.runningSum, err
	case aggFnMin, aggFnMax:
		if !agg.seen {
			// No rows were seen by MIN/MAX.
			return FromNull(), nil
		}
		if fn == aggFnMin {
			return agg.runningMin, nil
		}
		return agg.runningMax, nil
	}
	return nil, errInvalidAggregation
}