minio/pkg/s3select/sql/statement.go
Ryan Tam bd56f80250 Fix ignored alias for aggregate result in S3 Select (#7849)
The SQL parser as it stands right now ignores alias for aggregate
result, e.g. `SELECT COUNT(*) AS thing FROM s3object` doesn't actually
return record like `{"thing": 42}`, it returns a record like `{"_1": 42}`.
Column alias for aggregate result is supported in AWS's S3 Select, so
this commit fixes that by respecting the `expr.As` in the expression.

Also improve test for S3 select

On top of testing a simple `SELECT` query, we want to test a few more
"advanced" queries (e.g. aggregation).

Convert existing tests into table driven tests[1], and add the new test
cases with "advanced" queries into them.

[1] - https://github.com/golang/go/wiki/TableDrivenTests
2019-07-03 16:34:54 -07:00

273 lines
6.3 KiB
Go

/*
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
"strings"
"github.com/bcicen/jstream"
)
var (
errBadLimitSpecified = errors.New("Limit value must be a positive integer")
)
const (
baseTableName = "s3object"
)
// SelectStatement is the top level parsed and analyzed structure
type SelectStatement struct {
selectAST *Select
// Analysis result of the statement
selectQProp qProp
// Result of parsing the limit clause if one is present
// (otherwise -1)
limitValue int64
// Count of rows that have been output.
outputCount int64
}
// ParseSelectStatement - parses a select query from the given string
// and analyzes it.
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
var selectAST Select
err = SQLParser.ParseString(s, &selectAST)
if err != nil {
err = errQueryParseFailure(err)
return
}
stmt.selectAST = &selectAST
// Check the parsed limit value
stmt.limitValue, err = parseLimit(selectAST.Limit)
if err != nil {
err = errQueryAnalysisFailure(err)
return
}
// Analyze where clause
if selectAST.Where != nil {
whereQProp := selectAST.Where.analyze(&selectAST)
if whereQProp.err != nil {
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %v", whereQProp.err))
return
}
if whereQProp.isAggregation {
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
return
}
}
// Validate table name
err = validateTableName(selectAST.From)
if err != nil {
return
}
// Analyze main select expression
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
err = stmt.selectQProp.err
if err != nil {
err = errQueryAnalysisFailure(err)
}
return
}
func validateTableName(from *TableExpression) error {
if strings.ToLower(from.Table.BaseKey.String()) != baseTableName {
return errBadTableName(errors.New("table name must be `s3object`"))
}
if len(from.Table.PathExpr) > 0 {
if !from.Table.PathExpr[0].ArrayWildcard {
return errBadTableName(errors.New("keypath table name is invalid - please check the service documentation"))
}
}
return nil
}
func parseLimit(v *LitValue) (int64, error) {
switch {
case v == nil:
return -1, nil
case v.Number == nil:
return -1, errBadLimitSpecified
default:
r := int64(*v.Number)
if r < 0 {
return -1, errBadLimitSpecified
}
return r, nil
}
}
// EvalFrom evaluates the From clause on the input record. It only
// applies to JSON input data format (currently).
func (e *SelectStatement) EvalFrom(format string, input Record) (Record, error) {
if e.selectAST.From.HasKeypath() {
if format == "json" {
objFmt, rawVal := input.Raw()
if objFmt != SelectFmtJSON {
return nil, errDataSource(errors.New("unexpected non JSON input"))
}
jsonRec := rawVal.(jstream.KVS)
txedRec, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], jsonRec)
if err != nil {
return nil, err
}
var kvs jstream.KVS
switch v := txedRec.(type) {
case jstream.KVS:
kvs = v
default:
kvs = jstream.KVS{jstream.KV{Key: "_1", Value: v}}
}
if err = input.Replace(kvs); err != nil {
return nil, err
}
return input, nil
}
return nil, errDataSource(errors.New("path not supported"))
}
return input, nil
}
// IsAggregated returns if the statement involves SQL aggregation
func (e *SelectStatement) IsAggregated() bool {
return e.selectQProp.isAggregation
}
// AggregateResult - returns the aggregated result after all input
// records have been processed. Applies only to aggregation queries.
func (e *SelectStatement) AggregateResult(output Record) error {
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(nil)
if err != nil {
return err
}
if expr.As != "" {
output.Set(expr.As, v)
} else {
output.Set(fmt.Sprintf("_%d", i+1), v)
}
}
return nil
}
func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
if e.selectAST.Where == nil {
return true, nil
}
value, err := e.selectAST.Where.evalNode(input)
if err != nil {
return false, err
}
b, ok := value.ToBool()
if !ok {
err = fmt.Errorf("WHERE expression did not return bool")
return false, err
}
return b, nil
}
// AggregateRow - aggregates the input record. Applies only to
// aggregation queries.
func (e *SelectStatement) AggregateRow(input Record) error {
ok, err := e.isPassingWhereClause(input)
if err != nil {
return err
}
if !ok {
return nil
}
for _, expr := range e.selectAST.Expression.Expressions {
err := expr.aggregateRow(input)
if err != nil {
return err
}
}
return nil
}
// Eval - evaluates the Select statement for the given record. It
// applies only to non-aggregation queries.
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
ok, err := e.isPassingWhereClause(input)
if err != nil || !ok {
// Either error or row did not pass where clause
return nil, err
}
if e.selectAST.Expression.All {
// Return the input record for `SELECT * FROM
// .. WHERE ..`
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return input, nil
}
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input)
if err != nil {
return nil, err
}
// Pick output column names
if expr.As != "" {
output.Set(expr.As, v)
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
output.Set(comp, v)
} else {
output.Set(fmt.Sprintf("_%d", i+1), v)
}
}
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return output, nil
}
// LimitReached - returns true if the number of records output has
// reached the value of the `LIMIT` clause.
func (e *SelectStatement) LimitReached() bool {
if e.limitValue == -1 {
return false
}
return e.outputCount >= e.limitValue
}