minio/pkg/s3select/sql/statement.go
Aditya Manthramurthy b2936243f9
Fix S3Select SQL column reference handling (#11957)
This change fixes handling of these types of queries:

- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
2021-04-06 08:49:04 -07:00

345 lines
8.6 KiB
Go

/*
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
	"errors"
	"fmt"
	"math"
	"strings"

	"github.com/bcicen/jstream"
	"github.com/minio/simdjson-go"
)
var (
	// errBadLimitSpecified is returned by parseLimit when the LIMIT clause
	// carries an invalid value (e.g. no integer value, or a negative one).
	// Note: a LIMIT of zero is accepted despite the "positive" wording.
	errBadLimitSpecified = errors.New("Limit value must be a positive integer")
)

const (
	// baseTableName is the only table name accepted in the FROM clause;
	// validateTableName compares against it after lower-casing the name.
	baseTableName = "s3object"
)
// SelectStatement is the top level parsed and analyzed structure
// produced by ParseSelectStatement. Records are evaluated one at a time
// either with Eval (non-aggregation queries) or with AggregateRow followed
// by AggregateResult (aggregation queries).
//
// Eval mutates outputCount without any synchronization.
type SelectStatement struct {
	// Parsed AST of the query.
	selectAST *Select

	// Analysis result of the SELECT expression list (see IsAggregated).
	selectQProp qProp

	// Result of parsing the limit clause if one is present
	// (otherwise -1)
	limitValue int64

	// Count of rows that have been output. Only incremented by Eval when a
	// LIMIT clause is present (limitValue > -1); compared by LimitReached.
	outputCount int64

	// Table alias taken from the FROM clause (From.As); passed to
	// expression evaluation so column references can use it as a prefix.
	tableAlias string
}
// ParseSelectStatement - parses a select query from the given string
// and analyzes it.
//
// Errors are wrapped according to the stage that failed:
//   - SQL parse failures: errQueryParseFailure.
//   - An invalid LIMIT value, WHERE-clause analysis errors (including any
//     aggregation in the WHERE clause) and SELECT-expression analysis
//     errors: errQueryAnalysisFailure.
//   - A table-name error from validateTableName is returned unwrapped.
//
// stmt may be partially populated when err is non-nil (e.g. selectAST is
// set before the LIMIT check, and tableAlias is set even when analysis of
// the SELECT expressions fails).
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
	var selectAST Select
	err = SQLParser.ParseString(s, &selectAST)
	if err != nil {
		err = errQueryParseFailure(err)
		return
	}

	// Check if select is "SELECT s.* from S3Object s"
	//
	// When the projection is a single expression consisting of one
	// condition whose primary operand is a JSON path, and that path's text
	// equals "<FROM alias>.*", the query is treated exactly like
	// "SELECT *". The nil checks walk the parse tree down to that path
	// expression without dereferencing a missing node.
	if !selectAST.Expression.All &&
		len(selectAST.Expression.Expressions) == 1 &&
		len(selectAST.Expression.Expressions[0].Expression.And) == 1 &&
		len(selectAST.Expression.Expressions[0].Expression.And[0].Condition) == 1 &&
		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand != nil &&
		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left != nil &&
		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left != nil &&
		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary != nil &&
		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr != nil {
		if selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr.String() == selectAST.From.As+".*" {
			selectAST.Expression.All = true
		}
	}
	stmt.selectAST = &selectAST

	// Check the parsed limit value (-1 when there is no LIMIT clause).
	stmt.limitValue, err = parseLimit(selectAST.Limit)
	if err != nil {
		err = errQueryAnalysisFailure(err)
		return
	}

	// Analyze where clause. Aggregate functions are rejected there because
	// WHERE is evaluated per input record.
	if selectAST.Where != nil {
		whereQProp := selectAST.Where.analyze(&selectAST)
		if whereQProp.err != nil {
			err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %w", whereQProp.err))
			return
		}
		if whereQProp.isAggregation {
			err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
			return
		}
	}

	// Validate table name: it must be s3object, and any keypath after it
	// must start with an array wildcard.
	err = validateTableName(selectAST.From)
	if err != nil {
		return
	}

	// Analyze main select expression. The analysis result is stored even
	// when it carries an error.
	stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
	err = stmt.selectQProp.err
	if err != nil {
		err = errQueryAnalysisFailure(err)
	}

	// Set table alias; it is passed to evalNode/aggregateRow during
	// evaluation.
	stmt.tableAlias = selectAST.From.As
	return
}
// validateTableName checks the FROM clause of a query.
//
// The table name must equal `s3object` after lower-casing. If a keypath
// follows the table name, its first element must be an array wildcard
// (`[*]`). Violations are reported through errBadTableName; a valid clause
// yields nil.
func validateTableName(from *TableExpression) error {
	tableName := strings.ToLower(from.Table.BaseKey.String())
	if tableName != baseTableName {
		return errBadTableName(errors.New("table name must be `s3object`"))
	}

	keypath := from.Table.PathExpr
	if len(keypath) > 0 && !keypath[0].ArrayWildcard {
		return errBadTableName(errors.New("keypath table name is invalid - please check the service documentation"))
	}
	return nil
}
// parseLimit converts the literal of a LIMIT clause into a row count.
//
// A nil v (no LIMIT clause) yields -1 with a nil error, meaning "no limit".
// A literal without an integer value, a negative value, or a value too
// large for an int64 yields errBadLimitSpecified. Zero is accepted.
func parseLimit(v *LitValue) (int64, error) {
	switch {
	case v == nil:
		return -1, nil
	case v.Int == nil:
		return -1, errBadLimitSpecified
	}

	// v.Int is presumably float-backed (hence the conversion). Per the Go
	// spec, converting an out-of-range float to int64 gives an
	// implementation-dependent result: it wraps negative on some platforms
	// and saturates on others. Range-check (and reject NaN) before
	// converting so that every platform behaves the same.
	f := float64(*v.Int)
	if math.IsNaN(f) || f < 0 || f >= float64(math.MaxInt64) {
		return -1, errBadLimitSpecified
	}
	return int64(f), nil
}
// EvalFrom evaluates the FROM clause against one input record and returns
// the record(s) that the rest of the statement should operate on. It only
// applies to JSON input data format (currently).
//
// Without a keypath in the FROM clause, the input record is returned as-is.
// Otherwise the input must be JSON (format "json", raw value a jstream.KVS
// or simdjson.Object). The keypath, minus its leading element, is evaluated
// against the raw record, and:
//   - an object result replaces the record's contents;
//   - an array result produces one cloned record per element;
//   - any other result is stored under the column name "_1".
func (e *SelectStatement) EvalFrom(format string, input Record) ([]*Record, error) {
	from := e.selectAST.From
	if !from.HasKeypath() {
		return []*Record{&input}, nil
	}

	_, raw := input.Raw()
	if format != "json" {
		return nil, errDataSource(errors.New("path not supported"))
	}

	switch rec := raw.(type) {
	case jstream.KVS:
		result, _, err := jsonpathEval(from.Table.PathExpr[1:], rec)
		if err != nil {
			return nil, err
		}

		var kvs jstream.KVS
		switch val := result.(type) {
		case jstream.KVS:
			kvs = val
		case []interface{}:
			return cloneRecordPerElement(input, val)
		default:
			// Wrap a scalar so it becomes a single-column record.
			kvs = jstream.KVS{jstream.KV{Key: "_1", Value: val}}
		}
		if err = input.Replace(kvs); err != nil {
			return nil, err
		}
		return []*Record{&input}, nil

	case simdjson.Object:
		result, _, err := jsonpathEval(from.Table.PathExpr[1:], rec)
		if err != nil {
			return nil, err
		}

		switch val := result.(type) {
		case simdjson.Object:
			if err = input.Replace(val); err != nil {
				return nil, err
			}
		case []interface{}:
			return cloneRecordPerElement(input, val)
		default:
			// Clear the record and store the scalar as its only column.
			input.Reset()
			if input, err = input.Set("_1", &Value{value: val}); err != nil {
				return nil, err
			}
		}
		return []*Record{&input}, nil
	}

	return nil, errDataSource(errors.New("unexpected non JSON input"))
}

// cloneRecordPerElement returns one clone of input per element of elems,
// with each clone's contents replaced by the corresponding element. The
// first Replace failure aborts and is returned.
func cloneRecordPerElement(input Record, elems []interface{}) ([]*Record, error) {
	out := make([]*Record, len(elems))
	for i := range elems {
		clone := input.Clone(nil)
		if err := clone.Replace(elems[i]); err != nil {
			return nil, err
		}
		out[i] = &clone
	}
	return out, nil
}
// IsAggregated reports whether analysis of the SELECT expression list
// (performed by ParseSelectStatement) flagged the statement as an
// aggregation query. Such statements are evaluated with AggregateRow and
// AggregateResult rather than Eval.
func (e *SelectStatement) IsAggregated() bool {
	aggregated := e.selectQProp.isAggregation
	return aggregated
}
// AggregateResult writes the final aggregated values into output after all
// input records have been fed to AggregateRow. Applies only to aggregation
// queries.
//
// Each SELECT expression becomes one column, named by its AS alias or, when
// there is none, positionally as "_1", "_2", ... The first evaluation or Set
// error is returned.
//
// NOTE(review): the Record returned by output.Set is not propagated back to
// the caller; this relies on Set mutating output in place. Confirm for every
// Record implementation used as an aggregation output.
func (e *SelectStatement) AggregateResult(output Record) error {
	for idx, aliased := range e.selectAST.Expression.Expressions {
		val, err := aliased.evalNode(nil, e.tableAlias)
		if err != nil {
			return err
		}

		column := aliased.As
		if column == "" {
			column = fmt.Sprintf("_%d", idx+1)
		}
		if output, err = output.Set(column, val); err != nil {
			return err
		}
	}
	return nil
}
// isPassingWhereClause evaluates the WHERE clause against input.
//
// A statement without a WHERE clause accepts every record. Otherwise the
// clause must evaluate to a boolean, which is returned. An evaluation error,
// or a non-boolean result, is reported as an error.
func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
	where := e.selectAST.Where
	if where == nil {
		return true, nil
	}

	result, err := where.evalNode(input, e.tableAlias)
	if err != nil {
		return false, err
	}

	if passed, isBool := result.ToBool(); isBool {
		return passed, nil
	}
	return false, errors.New("WHERE expression did not return bool")
}
// AggregateRow folds one input record into the running aggregates of every
// SELECT expression. Records that fail the WHERE clause are skipped. Applies
// only to aggregation queries; the first error encountered is returned.
func (e *SelectStatement) AggregateRow(input Record) error {
	passed, err := e.isPassingWhereClause(input)
	switch {
	case err != nil:
		return err
	case !passed:
		return nil
	}

	for _, aliased := range e.selectAST.Expression.Expressions {
		if err = aliased.aggregateRow(input, e.tableAlias); err != nil {
			return err
		}
	}
	return nil
}
// Eval evaluates the Select statement for the given input record. It
// applies only to non-aggregation queries.
//
// It returns:
//   - (nil, err) if evaluating the WHERE clause or a SELECT expression fails;
//   - (nil, nil) if the record does not satisfy the WHERE clause and must
//     not be output;
//   - otherwise the record to output: for `SELECT *` the result of
//     input.Clone(output), else output with one column set per SELECT
//     expression.
//
// Every record returned for output counts toward the LIMIT clause, when one
// is present; see LimitReached.
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
	ok, err := e.isPassingWhereClause(input)
	if err != nil || !ok {
		// Either error or row did not pass where clause
		return nil, err
	}

	if e.selectAST.Expression.All {
		// Return the input record for `SELECT * FROM .. WHERE ..`
		// Update count of records output.
		if e.limitValue > -1 {
			e.outputCount++
		}
		return input.Clone(output), nil
	}

	for i, expr := range e.selectAST.Expression.Expressions {
		v, err := expr.evalNode(input, e.tableAlias)
		if err != nil {
			return nil, err
		}

		// Pick the output column name: an explicit AS alias first, then
		// the last component of a keypath expression, and finally the
		// 1-based positional name "_N".
		if expr.As != "" {
			output, err = output.Set(expr.As, v)
		} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
			output, err = output.Set(comp, v)
		} else {
			output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
		}
		if err != nil {
			return nil, err
		}
	}

	// Update count of records output.
	if e.limitValue > -1 {
		e.outputCount++
	}
	return output, nil
}
// LimitReached reports whether the number of records output so far has
// reached the value of the `LIMIT` clause. It is always false for a query
// without a LIMIT clause (limitValue == -1).
func (e *SelectStatement) LimitReached() bool {
	return e.limitValue != -1 && e.outputCount >= e.limitValue
}