// Copyright (c) 2015-2021 MinIO, Inc. // // This file is part of MinIO Object Storage stack // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . package sql import ( "errors" "fmt" ) // Aggregation Function name constants const ( aggFnAvg FuncName = "AVG" aggFnCount FuncName = "COUNT" aggFnMax FuncName = "MAX" aggFnMin FuncName = "MIN" aggFnSum FuncName = "SUM" ) var ( errNonNumericArg = func(fnStr FuncName) error { return fmt.Errorf("%s() requires a numeric argument", fnStr) } errInvalidAggregation = errors.New("Invalid aggregation seen") ) type aggVal struct { runningSum *Value runningCount int64 runningMax, runningMin *Value // Stores if at least one record has been seen seen bool } func newAggVal(fn FuncName) *aggVal { switch fn { case aggFnAvg, aggFnSum: return &aggVal{runningSum: FromFloat(0)} case aggFnMin: return &aggVal{runningMin: FromInt(0)} case aggFnMax: return &aggVal{runningMax: FromInt(0)} default: return &aggVal{} } } // evalAggregationNode - performs partial computation using the // current row and stores the result. // // On success, it returns (nil, nil). func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error { // It is assumed that this function is called only when // `e` is an aggregation function. var val *Value var err error funcName := e.getFunctionName() if aggFnCount == funcName { if e.Count.StarArg { // Handle COUNT(*) e.aggregate.runningCount++ return nil } val, err = e.Count.ExprArg.evalNode(r, tableAlias) if err != nil { return err } } else { // Evaluate the (only) argument val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias) if err != nil { return err } } if val.IsNull() { // E.g. the column or field does not exist in the // record - in all such cases the aggregation is not // updated. return nil } argVal := val if funcName != aggFnCount { // All aggregation functions, except COUNT require a // numeric argument. // Here, we diverge from Amazon S3 behavior by // inferring untyped values are numbers. if !argVal.isNumeric() { if i, ok := argVal.bytesToInt(); ok { argVal.setInt(i) } else if f, ok := argVal.bytesToFloat(); ok { argVal.setFloat(f) } else { return errNonNumericArg(funcName) } } } // Mark that we have seen one non-null value. isFirstRow := false if !e.aggregate.seen { e.aggregate.seen = true isFirstRow = true } switch funcName { case aggFnCount: // For all non-null values, the count is incremented. e.aggregate.runningCount++ case aggFnAvg, aggFnSum: e.aggregate.runningCount++ // Convert to float. f, ok := argVal.ToFloat() if !ok { return fmt.Errorf("Could not convert value %v (%s) to a number", argVal.value, argVal.GetTypeString()) } argVal.setFloat(f) err = e.aggregate.runningSum.arithOp(opPlus, argVal) case aggFnMin: err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow) case aggFnMax: err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow) default: err = errInvalidAggregation } return err } func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error { return e.Expression.aggregateRow(r, tableAlias) } func (e *Expression) aggregateRow(r Record, tableAlias string) error { for _, ex := range e.And { err := ex.aggregateRow(r, tableAlias) if err != nil { return err } } return nil } func (e *ListExpr) aggregateRow(r Record, tableAlias string) error { for _, ex := range e.Elements { err := ex.aggregateRow(r, tableAlias) if err != nil { return err } } return nil } func (e *AndCondition) aggregateRow(r Record, tableAlias string) error { for _, ex := range e.Condition { err := ex.aggregateRow(r, tableAlias) if err != nil { return err } } return nil } func (e *Condition) aggregateRow(r Record, tableAlias string) error { if e.Operand != nil { return e.Operand.aggregateRow(r, tableAlias) } return e.Not.aggregateRow(r, tableAlias) } func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error { err := e.Operand.aggregateRow(r, tableAlias) if err != nil { return err } if e.ConditionRHS == nil { return nil } switch { case e.ConditionRHS.Compare != nil: return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias) case e.ConditionRHS.Between != nil: err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias) if err != nil { return err } return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias) case e.ConditionRHS.In != nil: elt := e.ConditionRHS.In.ListExpression err = elt.aggregateRow(r, tableAlias) if err != nil { return err } return nil case e.ConditionRHS.Like != nil: err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias) if err != nil { return err } return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias) default: return errInvalidASTNode } } func (e *Operand) aggregateRow(r Record, tableAlias string) error { err := e.Left.aggregateRow(r, tableAlias) if err != nil { return err } for _, rt := range e.Right { err = rt.Right.aggregateRow(r, tableAlias) if err != nil { return err } } return nil } func (e *MultOp) aggregateRow(r Record, tableAlias string) error { err := e.Left.aggregateRow(r, tableAlias) if err != nil { return err } for _, rt := range e.Right { err = rt.Right.aggregateRow(r, tableAlias) if err != nil { return err } } return nil } func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error { if e.Negated != nil { return e.Negated.Term.aggregateRow(r, tableAlias) } return e.Primary.aggregateRow(r, tableAlias) } func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error { switch { case e.ListExpr != nil: return e.ListExpr.aggregateRow(r, tableAlias) case e.SubExpression != nil: return e.SubExpression.aggregateRow(r, tableAlias) case e.FuncCall != nil: return e.FuncCall.aggregateRow(r, tableAlias) } return nil } func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error { switch e.getFunctionName() { case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount: return e.evalAggregationNode(r, tableAlias) default: // TODO: traverse arguments and call aggregateRow on // them if they could be an ancestor of an // aggregation. } return nil } // getAggregate() implementation for each AST node follows. This is // called after calling aggregateRow() on each input row, to calculate // the final aggregate result. func (e *FuncExpr) getAggregate() (*Value, error) { switch e.getFunctionName() { case aggFnCount: return FromInt(e.aggregate.runningCount), nil case aggFnAvg: if e.aggregate.runningCount == 0 { // No rows were seen by AVG. return FromNull(), nil } err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount)) return e.aggregate.runningSum, err case aggFnMin: if !e.aggregate.seen { // No rows were seen by MIN return FromNull(), nil } return e.aggregate.runningMin, nil case aggFnMax: if !e.aggregate.seen { // No rows were seen by MAX return FromNull(), nil } return e.aggregate.runningMax, nil case aggFnSum: // TODO: check if returning 0 when no rows were seen // by SUM is expected behavior. return e.aggregate.runningSum, nil default: // TODO: } return nil, errInvalidAggregation }