rename all remaining packages to internal/ (#12418)

This is to ensure that there are no projects
that try to import `minio/minio/pkg` into
their own repo. Any such common packages should
go to `https://github.com/minio/pkg`
This commit is contained in:
Harshavardhana
2021-06-01 14:59:40 -07:00
committed by GitHub
parent bf87c4b1e4
commit 1f262daf6f
540 changed files with 757 additions and 778 deletions

View File

@@ -0,0 +1,331 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"fmt"
)
// Aggregation Function name constants
const (
aggFnAvg FuncName = "AVG"
aggFnCount FuncName = "COUNT"
aggFnMax FuncName = "MAX"
aggFnMin FuncName = "MIN"
aggFnSum FuncName = "SUM"
)
var (
errNonNumericArg = func(fnStr FuncName) error {
return fmt.Errorf("%s() requires a numeric argument", fnStr)
}
errInvalidAggregation = errors.New("Invalid aggregation seen")
)
type aggVal struct {
runningSum *Value
runningCount int64
runningMax, runningMin *Value
// Stores if at least one record has been seen
seen bool
}
func newAggVal(fn FuncName) *aggVal {
switch fn {
case aggFnAvg, aggFnSum:
return &aggVal{runningSum: FromFloat(0)}
case aggFnMin:
return &aggVal{runningMin: FromInt(0)}
case aggFnMax:
return &aggVal{runningMax: FromInt(0)}
default:
return &aggVal{}
}
}
// evalAggregationNode - performs partial computation using the
// current row and stores the result.
//
// On success, it returns (nil, nil).
func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
// It is assumed that this function is called only when
// `e` is an aggregation function.
var val *Value
var err error
funcName := e.getFunctionName()
if aggFnCount == funcName {
if e.Count.StarArg {
// Handle COUNT(*)
e.aggregate.runningCount++
return nil
}
val, err = e.Count.ExprArg.evalNode(r, tableAlias)
if err != nil {
return err
}
} else {
// Evaluate the (only) argument
val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
if err != nil {
return err
}
}
if val.IsNull() {
// E.g. the column or field does not exist in the
// record - in all such cases the aggregation is not
// updated.
return nil
}
argVal := val
if funcName != aggFnCount {
// All aggregation functions, except COUNT require a
// numeric argument.
// Here, we diverge from Amazon S3 behavior by
// inferring untyped values are numbers.
if !argVal.isNumeric() {
if i, ok := argVal.bytesToInt(); ok {
argVal.setInt(i)
} else if f, ok := argVal.bytesToFloat(); ok {
argVal.setFloat(f)
} else {
return errNonNumericArg(funcName)
}
}
}
// Mark that we have seen one non-null value.
isFirstRow := false
if !e.aggregate.seen {
e.aggregate.seen = true
isFirstRow = true
}
switch funcName {
case aggFnCount:
// For all non-null values, the count is incremented.
e.aggregate.runningCount++
case aggFnAvg, aggFnSum:
e.aggregate.runningCount++
// Convert to float.
f, ok := argVal.ToFloat()
if !ok {
return fmt.Errorf("Could not convert value %v (%s) to a number", argVal.value, argVal.GetTypeString())
}
argVal.setFloat(f)
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
case aggFnMin:
err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
case aggFnMax:
err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
default:
err = errInvalidAggregation
}
return err
}
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
return e.Expression.aggregateRow(r, tableAlias)
}
func (e *Expression) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.And {
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
}
return nil
}
func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Elements {
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
}
return nil
}
func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Condition {
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
}
return nil
}
func (e *Condition) aggregateRow(r Record, tableAlias string) error {
if e.Operand != nil {
return e.Operand.aggregateRow(r, tableAlias)
}
return e.Not.aggregateRow(r, tableAlias)
}
func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
err := e.Operand.aggregateRow(r, tableAlias)
if err != nil {
return err
}
if e.ConditionRHS == nil {
return nil
}
switch {
case e.ConditionRHS.Compare != nil:
return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
case e.ConditionRHS.Between != nil:
err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
case e.ConditionRHS.In != nil:
elt := e.ConditionRHS.In.ListExpression
err = elt.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return nil
case e.ConditionRHS.Like != nil:
err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
default:
return errInvalidASTNode
}
}
func (e *Operand) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r, tableAlias)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r, tableAlias)
if err != nil {
return err
}
}
return nil
}
func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r, tableAlias)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r, tableAlias)
if err != nil {
return err
}
}
return nil
}
func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
if e.Negated != nil {
return e.Negated.Term.aggregateRow(r, tableAlias)
}
return e.Primary.aggregateRow(r, tableAlias)
}
func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
switch {
case e.ListExpr != nil:
return e.ListExpr.aggregateRow(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.aggregateRow(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.aggregateRow(r, tableAlias)
}
return nil
}
func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
switch e.getFunctionName() {
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
return e.evalAggregationNode(r, tableAlias)
default:
// TODO: traverse arguments and call aggregateRow on
// them if they could be an ancestor of an
// aggregation.
}
return nil
}
// getAggregate() implementation for each AST node follows. This is
// called after calling aggregateRow() on each input row, to calculate
// the final aggregate result.
func (e *FuncExpr) getAggregate() (*Value, error) {
switch e.getFunctionName() {
case aggFnCount:
return FromInt(e.aggregate.runningCount), nil
case aggFnAvg:
if e.aggregate.runningCount == 0 {
// No rows were seen by AVG.
return FromNull(), nil
}
err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
return e.aggregate.runningSum, err
case aggFnMin:
if !e.aggregate.seen {
// No rows were seen by MIN
return FromNull(), nil
}
return e.aggregate.runningMin, nil
case aggFnMax:
if !e.aggregate.seen {
// No rows were seen by MAX
return FromNull(), nil
}
return e.aggregate.runningMax, nil
case aggFnSum:
// TODO: check if returning 0 when no rows were seen
// by SUM is expected behavior.
return e.aggregate.runningSum, nil
default:
// TODO:
}
return nil, errInvalidAggregation
}

View File

@@ -0,0 +1,324 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"fmt"
"strings"
)
// Query analysis - The query is analyzed to determine if it involves
// aggregation.
//
// Aggregation functions - An expression that involves aggregation of
// rows in some manner. Requires all input rows to be processed,
// before a result is returned.
//
// Row function - An expression that depends on a value in the
// row. They have an output for each input row.
//
// Some types of a queries are not valid. For example, an aggregation
// function combined with a row function is meaningless ("AVG(s.Age) +
// s.Salary"). Analysis determines if such a scenario exists so an
// error can be returned.
var (
// Fatal error for query processing.
errNestedAggregation = errors.New("Cannot nest aggregations")
errFunctionNotImplemented = errors.New("Function is not yet implemented")
errUnexpectedInvalidNode = errors.New("Unexpected node value")
errInvalidKeypath = errors.New("A provided keypath is invalid")
)
// qProp contains analysis info about an SQL term.
type qProp struct {
isAggregation, isRowFunc bool
err error
}
// `combine` combines a pair of `qProp`s, so that errors are
// propagated correctly, and checks that an aggregation is not being
// combined with a row-function term.
func (p *qProp) combine(q qProp) {
switch {
case p.err != nil:
// Do nothing
case q.err != nil:
p.err = q.err
default:
p.isAggregation = p.isAggregation || q.isAggregation
p.isRowFunc = p.isRowFunc || q.isRowFunc
if p.isAggregation && p.isRowFunc {
p.err = errNestedAggregation
}
}
}
func (e *SelectExpression) analyze(s *Select) (result qProp) {
if e.All {
return qProp{isRowFunc: true}
}
for _, ex := range e.Expressions {
result.combine(ex.analyze(s))
}
return
}
func (e *AliasedExpression) analyze(s *Select) qProp {
return e.Expression.analyze(s)
}
func (e *Expression) analyze(s *Select) (result qProp) {
for _, ac := range e.And {
result.combine(ac.analyze(s))
}
return
}
func (e *AndCondition) analyze(s *Select) (result qProp) {
for _, ac := range e.Condition {
result.combine(ac.analyze(s))
}
return
}
func (e *Condition) analyze(s *Select) (result qProp) {
if e.Operand != nil {
result = e.Operand.analyze(s)
} else {
result = e.Not.analyze(s)
}
return
}
func (e *ListExpr) analyze(s *Select) (result qProp) {
for _, ac := range e.Elements {
result.combine(ac.analyze(s))
}
return
}
func (e *ConditionOperand) analyze(s *Select) (result qProp) {
if e.ConditionRHS == nil {
result = e.Operand.analyze(s)
} else {
result.combine(e.Operand.analyze(s))
result.combine(e.ConditionRHS.analyze(s))
}
return
}
func (e *ConditionRHS) analyze(s *Select) (result qProp) {
switch {
case e.Compare != nil:
result = e.Compare.Operand.analyze(s)
case e.Between != nil:
result.combine(e.Between.Start.analyze(s))
result.combine(e.Between.End.analyze(s))
case e.In != nil:
result.combine(e.In.ListExpression.analyze(s))
case e.Like != nil:
result.combine(e.Like.Pattern.analyze(s))
if e.Like.EscapeChar != nil {
result.combine(e.Like.EscapeChar.analyze(s))
}
default:
result = qProp{err: errUnexpectedInvalidNode}
}
return
}
func (e *Operand) analyze(s *Select) (result qProp) {
result.combine(e.Left.analyze(s))
for _, r := range e.Right {
result.combine(r.Right.analyze(s))
}
return
}
func (e *MultOp) analyze(s *Select) (result qProp) {
result.combine(e.Left.analyze(s))
for _, r := range e.Right {
result.combine(r.Right.analyze(s))
}
return
}
func (e *UnaryTerm) analyze(s *Select) (result qProp) {
if e.Negated != nil {
result = e.Negated.Term.analyze(s)
} else {
result = e.Primary.analyze(s)
}
return
}
func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
switch {
case e.Value != nil:
result = qProp{}
case e.JPathExpr != nil:
// Check if the path expression is valid
if len(e.JPathExpr.PathExpr) > 0 {
if e.JPathExpr.BaseKey.String() != s.From.As && strings.ToLower(e.JPathExpr.BaseKey.String()) != baseTableName {
result = qProp{err: errInvalidKeypath}
return
}
}
result = qProp{isRowFunc: true}
case e.ListExpr != nil:
result = e.ListExpr.analyze(s)
case e.SubExpression != nil:
result = e.SubExpression.analyze(s)
case e.FuncCall != nil:
result = e.FuncCall.analyze(s)
default:
result = qProp{err: errUnexpectedInvalidNode}
}
return
}
func (e *FuncExpr) analyze(s *Select) (result qProp) {
funcName := e.getFunctionName()
switch funcName {
case sqlFnCast:
return e.Cast.Expr.analyze(s)
case sqlFnExtract:
return e.Extract.From.analyze(s)
case sqlFnDateAdd:
result.combine(e.DateAdd.Quantity.analyze(s))
result.combine(e.DateAdd.Timestamp.analyze(s))
return result
case sqlFnDateDiff:
result.combine(e.DateDiff.Timestamp1.analyze(s))
result.combine(e.DateDiff.Timestamp2.analyze(s))
return result
// Handle aggregation function calls
case aggFnAvg, aggFnMax, aggFnMin, aggFnSum, aggFnCount:
// Initialize accumulator
e.aggregate = newAggVal(funcName)
var exprA qProp
if funcName == aggFnCount {
if e.Count.StarArg {
return qProp{isAggregation: true}
}
exprA = e.Count.ExprArg.analyze(s)
} else {
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s takes exactly one argument", funcName)}
}
exprA = e.SFunc.ArgsList[0].analyze(s)
}
if exprA.err != nil {
return exprA
}
if exprA.isAggregation {
return qProp{err: errNestedAggregation}
}
return qProp{isAggregation: true}
case sqlFnCoalesce:
if len(e.SFunc.ArgsList) == 0 {
return qProp{err: fmt.Errorf("%s needs at least one argument", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnNullIf:
if len(e.SFunc.ArgsList) != 2 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnCharLength, sqlFnCharacterLength:
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnLower, sqlFnUpper:
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnTrim:
if e.Trim.TrimChars != nil {
result.combine(e.Trim.TrimChars.analyze(s))
}
if e.Trim.TrimFrom != nil {
result.combine(e.Trim.TrimFrom.analyze(s))
}
return result
case sqlFnSubstring:
errVal := fmt.Errorf("Invalid argument(s) to %s", string(funcName))
result.combine(e.Substring.Expr.analyze(s))
switch {
case e.Substring.From != nil:
result.combine(e.Substring.From.analyze(s))
if e.Substring.For != nil {
result.combine(e.Substring.Expr.analyze(s))
}
case e.Substring.Arg2 != nil:
result.combine(e.Substring.Arg2.analyze(s))
if e.Substring.Arg3 != nil {
result.combine(e.Substring.Arg3.analyze(s))
}
default:
result.err = errVal
}
return result
case sqlFnUTCNow:
if len(e.SFunc.ArgsList) != 0 {
result.err = fmt.Errorf("%s() takes no arguments", string(funcName))
}
return result
}
// TODO: implement other functions
return qProp{err: errFunctionNotImplemented}
}

View File

@@ -0,0 +1,110 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import "fmt"
type s3Error struct {
code string
message string
statusCode int
cause error
}
func (err *s3Error) Cause() error {
return err.cause
}
func (err *s3Error) ErrorCode() string {
return err.code
}
func (err *s3Error) ErrorMessage() string {
return err.message
}
func (err *s3Error) HTTPStatusCode() int {
return err.statusCode
}
func (err *s3Error) Error() string {
return err.message
}
func errInvalidDataType(err error) *s3Error {
return &s3Error{
code: "InvalidDataType",
message: "The SQL expression contains an invalid data type.",
statusCode: 400,
cause: err,
}
}
func errIncorrectSQLFunctionArgumentType(err error) *s3Error {
return &s3Error{
code: "IncorrectSqlFunctionArgumentType",
message: "Incorrect type of arguments in function call.",
statusCode: 400,
cause: err,
}
}
func errLikeInvalidInputs(err error) *s3Error {
return &s3Error{
code: "LikeInvalidInputs",
message: "Invalid argument given to the LIKE clause in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errQueryParseFailure(err error) *s3Error {
return &s3Error{
code: "ParseSelectFailure",
message: err.Error(),
statusCode: 400,
cause: err,
}
}
func errQueryAnalysisFailure(err error) *s3Error {
return &s3Error{
code: "InvalidQuery",
message: err.Error(),
statusCode: 400,
cause: err,
}
}
func errBadTableName(err error) *s3Error {
return &s3Error{
code: "BadTableName",
message: fmt.Sprintf("The table name is not supported: %v", err),
statusCode: 400,
cause: err,
}
}
func errDataSource(err error) *s3Error {
return &s3Error{
code: "DataSourcePathUnsupported",
message: fmt.Sprintf("Data source: %v", err),
statusCode: 400,
cause: err,
}
}

View File

@@ -0,0 +1,491 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"encoding/json"
"errors"
"fmt"
"math"
"github.com/bcicen/jstream"
"github.com/minio/simdjson-go"
)
var (
errInvalidASTNode = errors.New("invalid AST Node")
errExpectedBool = errors.New("expected bool")
errLikeNonStrArg = errors.New("LIKE clause requires string arguments")
errLikeInvalidEscape = errors.New("LIKE clause has invalid ESCAPE character")
errNotImplemented = errors.New("not implemented")
)
// AST Node Evaluation functions
//
// During evaluation, the query is known to be valid, as analysis is
// complete. The only errors possible are due to value type
// mismatches, etc.
//
// If an aggregation node is present as a descendant (when
// e.prop.isAggregation is true), we call evalNode on all child nodes,
// check for errors, but do not perform any combining of the results
// of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) {
return e.Expression.evalNode(r, tableAlias)
}
func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.And) == 1 {
// In this case, result is not required to be boolean
// type.
return e.And[0].evalNode(r, tableAlias)
}
// Compute OR of conditions
result := false
for _, ex := range e.And {
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result || b
}
return FromBool(result), nil
}
func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.Condition) == 1 {
// In this case, result does not have to be boolean
return e.Condition[0].evalNode(r, tableAlias)
}
// Compute AND of conditions
result := true
for _, ex := range e.Condition {
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result && b
}
return FromBool(result), nil
}
func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Operand != nil {
// In this case, result does not have to be boolean
return e.Operand.evalNode(r, tableAlias)
}
// Compute NOT of condition
res, err := e.Not.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
return FromBool(!b), nil
}
func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r, tableAlias)
if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr
}
// Need to evaluate the ConditionRHS
switch {
case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias)
if cmpRErr != nil {
return nil, cmpRErr
}
b, err := opVal.compareOp(e.ConditionRHS.Compare.Operator, cmpRight)
return FromBool(b), err
case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias)
case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias)
case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias)
default:
return nil, errInvalidASTNode
}
}
func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) {
stVal, stErr := e.Start.evalNode(r, tableAlias)
if stErr != nil {
return nil, stErr
}
endVal, endErr := e.End.evalNode(r, tableAlias)
if endErr != nil {
return nil, endErr
}
part1, err1 := stVal.compareOp(opLte, arg)
if err1 != nil {
return nil, err1
}
part2, err2 := arg.compareOp(opLte, endVal)
if err2 != nil {
return nil, err2
}
result := part1 && part2
if e.Not {
result = !result
}
return FromBool(result), nil
}
func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) {
inferTypeAsString(arg)
s, ok := arg.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
pattern, err1 := e.Pattern.evalNode(r, tableAlias)
if err1 != nil {
return nil, err1
}
// Infer pattern as string (in case it is untyped)
inferTypeAsString(pattern)
patternStr, ok := pattern.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
escape := runeZero
if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias)
if err2 != nil {
return nil, err2
}
inferTypeAsString(escapeVal)
escapeStr, ok := escapeVal.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
if len([]rune(escapeStr)) > 1 {
err := errLikeInvalidEscape
return nil, errLikeInvalidInputs(err)
}
}
matchResult, err := evalSQLLike(s, patternStr, escape)
if err != nil {
return nil, err
}
if e.Not {
matchResult = !matchResult
}
return FromBool(matchResult), nil
}
func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) {
res := make([]Value, len(e.Elements))
if len(e.Elements) == 1 {
// If length 1, treat as single value.
return e.Elements[0].evalNode(r, tableAlias)
}
for i, elt := range e.Elements {
v, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
res[i] = *v
}
return FromArray(res), nil
}
const floatCmpTolerance = 0.000001
func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) {
// Compare two values in terms of in-ness.
var cmp func(a, b Value) bool
cmp = func(a, b Value) bool {
// Convert if needed.
inferTypesForCmp(&a, &b)
if a.Equals(b) {
return true
}
// If elements, compare each.
aA, aOK := a.ToArray()
bA, bOK := b.ToArray()
if aOK && bOK {
if len(aA) != len(bA) {
return false
}
for i := range aA {
if !cmp(aA[i], bA[i]) {
return false
}
}
return true
}
// Try as numbers
aF, aOK := a.ToFloat()
bF, bOK := b.ToFloat()
diff := math.Abs(aF - bF)
return aOK && bOK && diff < floatCmpTolerance
}
var rhs Value
if elt := e.ListExpression; elt != nil {
eltVal, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
rhs = *eltVal
}
// If RHS is array compare each element.
if arr, ok := rhs.ToArray(); ok {
for _, element := range arr {
// If we have an array we are on the wrong level.
if cmp(element, *lhs) {
return FromBool(true), nil
}
}
return FromBool(false), nil
}
return FromBool(cmp(rhs, *lhs)), nil
}
func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process remaining child nodes - result must be
// numeric. This AST node is for terms separated by + or -
// symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process other child nodes - result must be numeric. This
// AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Negated == nil {
return e.Primary.evalNode(r, tableAlias)
}
v, err := e.Negated.Term.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
inferTypeForArithOp(v)
v.negate()
if v.isNumeric() {
return v, nil
}
return nil, errArithMismatchedTypes
}
func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) {
alias := tableAlias
if tableAlias == "" {
alias = baseTableName
}
pathExpr := e.StripTableAlias(alias)
_, rawVal := r.Raw()
switch rowVal := rawVal.(type) {
case jstream.KVS, simdjson.Object:
if len(pathExpr) == 0 {
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
}
result, _, err := jsonpathEval(pathExpr, rowVal)
if err != nil {
return nil, err
}
return jsonToValue(result)
default:
if pathExpr[len(pathExpr)-1].Key == nil {
return nil, errInvalidKeypath
}
return r.Get(pathExpr[len(pathExpr)-1].Key.keyString())
}
}
// jsonToValue will convert the json value to an internal value.
func jsonToValue(result interface{}) (*Value, error) {
switch rval := result.(type) {
case string:
return FromString(rval), nil
case float64:
return FromFloat(rval), nil
case int64:
return FromInt(rval), nil
case uint64:
if rval <= math.MaxInt64 {
return FromInt(int64(rval)), nil
}
return FromFloat(float64(rval)), nil
case bool:
return FromBool(rval), nil
case jstream.KVS:
bs, err := json.Marshal(result)
if err != nil {
return nil, err
}
return FromBytes(bs), nil
case []interface{}:
dst := make([]Value, len(rval))
for i := range rval {
v, err := jsonToValue(rval[i])
if err != nil {
return nil, err
}
dst[i] = *v
}
return FromArray(dst), nil
case simdjson.Object:
o := rval
elems, err := o.Parse(nil)
if err != nil {
return nil, err
}
bs, err := elems.MarshalJSON()
if err != nil {
return nil, err
}
return FromBytes(bs), nil
case []Value:
return FromArray(rval), nil
case nil:
return FromNull(), nil
}
return nil, fmt.Errorf("Unhandled value type: %T", result)
}
func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch {
case e.Value != nil:
return e.Value.evalNode(r)
case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r, tableAlias)
case e.ListExpr != nil:
return e.ListExpr.evalNode(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.evalNode(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.evalNode(r, tableAlias)
}
return nil, errInvalidASTNode
}
func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate()
default:
return e.evalSQLFnNode(r, tableAlias)
}
}
// evalNode on a literal value is independent of the node being an
// aggregation or a row function - it always returns a value.
func (e *LitValue) evalNode(_ Record) (res *Value, err error) {
switch {
case e.Int != nil:
if *e.Int < math.MaxInt64 && *e.Int > math.MinInt64 {
return FromInt(int64(*e.Int)), nil
}
return FromFloat(*e.Int), nil
case e.Float != nil:
return FromFloat(*e.Float), nil
case e.String != nil:
return FromString(string(*e.String)), nil
case e.Boolean != nil:
return FromBool(bool(*e.Boolean)), nil
}
return FromNull(), nil
}

View File

@@ -0,0 +1,569 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
)
// FuncName - SQL function name.
type FuncName string
// SQL Function name constants
const (
// Conditionals
sqlFnCoalesce FuncName = "COALESCE"
sqlFnNullIf FuncName = "NULLIF"
// Conversion
sqlFnCast FuncName = "CAST"
// Date and time
sqlFnDateAdd FuncName = "DATE_ADD"
sqlFnDateDiff FuncName = "DATE_DIFF"
sqlFnExtract FuncName = "EXTRACT"
sqlFnToString FuncName = "TO_STRING"
sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
sqlFnUTCNow FuncName = "UTCNOW"
// String
sqlFnCharLength FuncName = "CHAR_LENGTH"
sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
sqlFnLower FuncName = "LOWER"
sqlFnSubstring FuncName = "SUBSTRING"
sqlFnTrim FuncName = "TRIM"
sqlFnUpper FuncName = "UPPER"
)
var (
errUnimplementedCast = errors.New("This cast not yet implemented")
errNonStringTrimArg = errors.New("TRIM() received a non-string argument")
errNonTimestampArg = errors.New("Expected a timestamp argument")
)
func (e *FuncExpr) getFunctionName() FuncName {
switch {
case e.SFunc != nil:
return FuncName(strings.ToUpper(e.SFunc.FunctionName))
case e.Count != nil:
return aggFnCount
case e.Cast != nil:
return sqlFnCast
case e.Substring != nil:
return sqlFnSubstring
case e.Extract != nil:
return sqlFnExtract
case e.Trim != nil:
return sqlFnTrim
case e.DateAdd != nil:
return sqlFnDateAdd
case e.DateDiff != nil:
return sqlFnDateDiff
default:
return ""
}
}
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
// function.
func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) {
// Handle functions that have phrase arguments
switch e.getFunctionName() {
case sqlFnCast:
expr := e.Cast.Expr
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias)
return
case sqlFnSubstring:
return handleSQLSubstring(r, e.Substring, tableAlias)
case sqlFnExtract:
return handleSQLExtract(r, e.Extract, tableAlias)
case sqlFnTrim:
return handleSQLTrim(r, e.Trim, tableAlias)
case sqlFnDateAdd:
return handleDateAdd(r, e.DateAdd, tableAlias)
case sqlFnDateDiff:
return handleDateDiff(r, e.DateDiff, tableAlias)
}
// For all simple argument functions, we evaluate the arguments here
argVals := make([]*Value, len(e.SFunc.ArgsList))
for i, arg := range e.SFunc.ArgsList {
argVals[i], err = arg.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
}
switch e.getFunctionName() {
case sqlFnCoalesce:
return coalesce(argVals)
case sqlFnNullIf:
return nullif(argVals[0], argVals[1])
case sqlFnCharLength, sqlFnCharacterLength:
return charlen(argVals[0])
case sqlFnLower:
return lowerCase(argVals[0])
case sqlFnUpper:
return upperCase(argVals[0])
case sqlFnUTCNow:
return handleUTCNow()
case sqlFnToString, sqlFnToTimestamp:
// TODO: implement
fallthrough
default:
return nil, errNotImplemented
}
}
func coalesce(args []*Value) (res *Value, err error) {
for _, arg := range args {
if arg.IsNull() {
continue
}
return arg, nil
}
return FromNull(), nil
}
func nullif(v1, v2 *Value) (res *Value, err error) {
// Handle Null cases
if v1.IsNull() || v2.IsNull() {
return v1, nil
}
err = inferTypesForCmp(v1, v2)
if err != nil {
return nil, err
}
atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
bothNumeric := v1.isNumeric() && v2.isNumeric()
if atleastOneNumeric || !bothNumeric {
return v1, nil
}
if v1.SameTypeAs(*v2) {
return v1, nil
}
cmpResult, cmpErr := v1.compareOp(opEq, v2)
if cmpErr != nil {
return nil, cmpErr
}
if cmpResult {
return FromNull(), nil
}
return v1, nil
}
func charlen(v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromInt(int64(len([]rune(s)))), nil
}
func lowerCase(v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s expects a string argument", sqlFnLower)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromString(strings.ToLower(s)), nil
}
func upperCase(v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromString(strings.ToUpper(s)), nil
}
func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) {
q, err := d.Quantity.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
inferTypeForArithOp(q)
qty, ok := q.ToFloat()
if !ok {
return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
}
ts, err := d.Timestamp.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
if err = inferTypeAsTimestamp(ts); err != nil {
return nil, err
}
t, ok := ts.ToTimestamp()
if !ok {
return nil, fmt.Errorf("%s() expects a timestamp argument", sqlFnDateAdd)
}
return dateAdd(strings.ToUpper(d.DatePart), qty, t)
}
func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) {
tval1, err := d.Timestamp1.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
if err = inferTypeAsTimestamp(tval1); err != nil {
return nil, err
}
ts1, ok := tval1.ToTimestamp()
if !ok {
return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
}
tval2, err := d.Timestamp2.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
if err = inferTypeAsTimestamp(tval2); err != nil {
return nil, err
}
ts2, ok := tval2.ToTimestamp()
if !ok {
return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
}
return dateDiff(strings.ToUpper(d.DatePart), ts1, ts2)
}
func handleUTCNow() (*Value, error) {
return FromTimestamp(time.Now().UTC()), nil
}
func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) {
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
// SUBSTRING('abc', 2, 1) are supported.
// Evaluate the string argument
v1, err := e.Expr.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
inferTypeAsString(v1)
s, ok := v1.ToString()
if !ok {
err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
// Assemble other arguments
arg2, arg3 := e.From, e.For
// Check if the second form of substring is being used
if e.From == nil {
arg2, arg3 = e.Arg2, e.Arg3
}
// Evaluate the FROM argument
v2, err := arg2.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
inferTypeForArithOp(v2)
startIdx, ok := v2.ToInt()
if !ok {
err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
length := -1
// Evaluate the optional FOR argument
if arg3 != nil {
v3, err := arg3.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
inferTypeForArithOp(v3)
lenInt, ok := v3.ToInt()
if !ok {
err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
length = int(lenInt)
if length < 0 {
err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
res, err := evalSQLSubstring(s, int(startIdx), length)
return FromString(res), err
}
func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) {
chars := ""
ok := false
if e.TrimChars != nil {
charsV, cerr := e.TrimChars.evalNode(r, tableAlias)
if cerr != nil {
return nil, cerr
}
inferTypeAsString(charsV)
chars, ok = charsV.ToString()
if !ok {
return nil, errNonStringTrimArg
}
}
fromV, ferr := e.TrimFrom.evalNode(r, tableAlias)
if ferr != nil {
return nil, ferr
}
inferTypeAsString(fromV)
from, ok := fromV.ToString()
if !ok {
return nil, errNonStringTrimArg
}
result, terr := evalSQLTrim(e.TrimWhere, chars, from)
if terr != nil {
return nil, terr
}
return FromString(result), nil
}
func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) {
timeVal, verr := e.From.evalNode(r, tableAlias)
if verr != nil {
return nil, verr
}
if err = inferTypeAsTimestamp(timeVal); err != nil {
return nil, err
}
t, ok := timeVal.ToTimestamp()
if !ok {
return nil, errNonTimestampArg
}
return extract(strings.ToUpper(e.Timeword), t)
}
func errUnsupportedCast(fromType, toType string) error {
return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
}
func errCastFailure(msg string) error {
return fmt.Errorf("Error casting: %s", msg)
}
// Allowed cast types
const (
castBool = "BOOL"
castInt = "INT"
castInteger = "INTEGER"
castString = "STRING"
castFloat = "FLOAT"
castDecimal = "DECIMAL"
castNumeric = "NUMERIC"
castTimestamp = "TIMESTAMP"
)
func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) {
v, err := e.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
switch castType {
case castInt, castInteger:
i, err := intCast(v)
return FromInt(i), err
case castFloat:
f, err := floatCast(v)
return FromFloat(f), err
case castString:
s, err := stringCast(v)
return FromString(s), err
case castTimestamp:
t, err := timestampCast(v)
return FromTimestamp(t), err
case castBool:
b, err := boolCast(v)
return FromBool(b), err
case castDecimal, castNumeric:
fallthrough
default:
return nil, errUnimplementedCast
}
}
func intCast(v *Value) (int64, error) {
// This conversion truncates floating point numbers to
// integer.
strToInt := func(s string) (int64, bool) {
i, errI := strconv.ParseInt(s, 10, 64)
if errI == nil {
return i, true
}
f, errF := strconv.ParseFloat(s, 64)
if errF == nil {
return int64(f), true
}
return 0, false
}
switch x := v.value.(type) {
case float64:
// Truncate fractional part
return int64(x), nil
case int64:
return x, nil
case string:
// Parse as number, truncate floating point if
// needed.
// String might contain trimming spaces, which
// needs to be trimmed.
res, ok := strToInt(strings.TrimSpace(x))
if !ok {
return 0, errCastFailure("could not parse as int")
}
return res, nil
case []byte:
// Parse as number, truncate floating point if
// needed.
// String might contain trimming spaces, which
// needs to be trimmed.
res, ok := strToInt(strings.TrimSpace(string(x)))
if !ok {
return 0, errCastFailure("could not parse as int")
}
return res, nil
default:
return 0, errUnsupportedCast(v.GetTypeString(), castInt)
}
}
func floatCast(v *Value) (float64, error) {
switch x := v.value.(type) {
case float64:
return x, nil
case int64:
return float64(x), nil
case string:
f, err := strconv.ParseFloat(strings.TrimSpace(x), 64)
if err != nil {
return 0, errCastFailure("could not parse as float")
}
return f, nil
case []byte:
f, err := strconv.ParseFloat(strings.TrimSpace(string(x)), 64)
if err != nil {
return 0, errCastFailure("could not parse as float")
}
return f, nil
default:
return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
}
}
func stringCast(v *Value) (string, error) {
switch x := v.value.(type) {
case float64:
return fmt.Sprintf("%v", x), nil
case int64:
return fmt.Sprintf("%v", x), nil
case string:
return x, nil
case []byte:
return string(x), nil
case bool:
return fmt.Sprintf("%v", x), nil
case nil:
// FIXME: verify this case is correct
return "NULL", nil
}
// This does not happen
return "", errCastFailure(fmt.Sprintf("cannot cast %v to string type", v.GetTypeString()))
}
func timestampCast(v *Value) (t time.Time, _ error) {
switch x := v.value.(type) {
case string:
return parseSQLTimestamp(x)
case []byte:
return parseSQLTimestamp(string(x))
case time.Time:
return x, nil
default:
return t, errCastFailure(fmt.Sprintf("cannot cast %v to Timestamp type", v.GetTypeString()))
}
}
func boolCast(v *Value) (b bool, _ error) {
sToB := func(s string) (bool, error) {
switch s {
case "true":
return true, nil
case "false":
return false, nil
default:
return false, errCastFailure("cannot cast to Bool")
}
}
switch x := v.value.(type) {
case bool:
return x, nil
case string:
return sToB(strings.ToLower(x))
case []byte:
return sToB(strings.ToLower(string(x)))
default:
return false, errCastFailure("cannot cast %v to Bool")
}
}

View File

@@ -0,0 +1,84 @@
{
"title": "Murder on the Orient Express",
"authorInfo": {
"name": "Agatha Christie",
"yearRange": [1890, 1976],
"penName": "Mary Westmacott"
},
"genre": "Crime novel",
"publicationHistory": [
{
"year": 1934,
"publisher": "Collins Crime Club (London)",
"type": "Hardcover",
"pages": 256
},
{
"year": 1934,
"publisher": "Dodd Mead and Company (New York)",
"type": "Hardcover",
"pages": 302
},
{
"year": 2011,
"publisher": "Harper Collins",
"type": "Paperback",
"pages": 265
}
]
}
{
"title": "The Robots of Dawn",
"authorInfo": {
"name": "Isaac Asimov",
"yearRange": [1920, 1992],
"penName": "Paul French"
},
"genre": "Science fiction",
"publicationHistory": [
{
"year": 1983,
"publisher": "Phantasia Press",
"type": "Hardcover",
"pages": 336
},
{
"year": 1984,
"publisher": "Granada",
"type": "Hardcover",
"pages": 419
},
{
"year": 2018,
"publisher": "Harper Voyager",
"type": "Paperback",
"pages": 432
}
]
}
{
"title": "Pigs Have Wings",
"authorInfo": {
"name": "P. G. Wodehouse",
"yearRange": [1881, 1975]
},
"genre": "Comic novel",
"publicationHistory": [
{
"year": 1952,
"publisher": "Doubleday & Company",
"type": "Hardcover"
},
{
"year": 2000,
"publisher": "Harry N. Abrams",
"type": "Hardcover"
},
{
"year": 2019,
"publisher": "Ulverscroft Collections",
"type": "Paperback",
"pages": 294
}
]
}

View File

@@ -0,0 +1,129 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"github.com/bcicen/jstream"
"github.com/minio/simdjson-go"
)
var (
errKeyLookup = errors.New("Cannot look up key in non-object value")
errIndexLookup = errors.New("Cannot look up array index in non-array value")
errWildcardObjectLookup = errors.New("Object wildcard used on non-object value")
errWildcardArrayLookup = errors.New("Array wildcard used on non-array value")
errWilcardObjectUsageInvalid = errors.New("Invalid usage of object wildcard")
)
// jsonpathEval evaluates a JSON path and returns the value at the path.
// If the value should be considered flat (from wildcards) any array returned should be considered individual values.
func jsonpathEval(p []*JSONPathElement, v interface{}) (r interface{}, flat bool, err error) {
// fmt.Printf("JPATHexpr: %v jsonobj: %v\n\n", p, v)
if len(p) == 0 || v == nil {
return v, false, nil
}
switch {
case p[0].Key != nil:
key := p[0].Key.keyString()
switch kvs := v.(type) {
case jstream.KVS:
for _, kv := range kvs {
if kv.Key == key {
return jsonpathEval(p[1:], kv.Value)
}
}
// Key not found - return nil result
return nil, false, nil
case simdjson.Object:
elem := kvs.FindKey(key, nil)
if elem == nil {
// Key not found - return nil result
return nil, false, nil
}
val, err := IterToValue(elem.Iter)
if err != nil {
return nil, false, err
}
return jsonpathEval(p[1:], val)
default:
return nil, false, errKeyLookup
}
case p[0].Index != nil:
idx := *p[0].Index
arr, ok := v.([]interface{})
if !ok {
return nil, false, errIndexLookup
}
if idx >= len(arr) {
return nil, false, nil
}
return jsonpathEval(p[1:], arr[idx])
case p[0].ObjectWildcard:
switch kvs := v.(type) {
case jstream.KVS:
if len(p[1:]) > 0 {
return nil, false, errWilcardObjectUsageInvalid
}
return kvs, false, nil
case simdjson.Object:
if len(p[1:]) > 0 {
return nil, false, errWilcardObjectUsageInvalid
}
return kvs, false, nil
default:
return nil, false, errWildcardObjectLookup
}
case p[0].ArrayWildcard:
arr, ok := v.([]interface{})
if !ok {
return nil, false, errWildcardArrayLookup
}
// Lookup remainder of path in each array element and
// make result array.
var result []interface{}
for _, a := range arr {
rval, flatten, err := jsonpathEval(p[1:], a)
if err != nil {
return nil, false, err
}
if flatten {
// Flatten if array.
if arr, ok := rval.([]interface{}); ok {
result = append(result, arr...)
continue
}
}
result = append(result, rval)
}
return result, true, nil
}
panic("cannot reach here")
}

View File

@@ -0,0 +1,97 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/alecthomas/participle"
"github.com/bcicen/jstream"
)
func getJSONStructs(b []byte) ([]interface{}, error) {
dec := jstream.NewDecoder(bytes.NewBuffer(b), 0).ObjectAsKVS()
var result []interface{}
for parsedVal := range dec.Stream() {
result = append(result, parsedVal.Value)
}
if err := dec.Err(); err != nil {
return nil, err
}
return result, nil
}
func TestJsonpathEval(t *testing.T) {
f, err := os.Open(filepath.Join("jsondata", "books.json"))
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(f)
if err != nil {
t.Fatal(err)
}
p := participle.MustBuild(
&JSONPath{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
cases := []struct {
str string
res []interface{}
}{
{"s.title", []interface{}{"Murder on the Orient Express", "The Robots of Dawn", "Pigs Have Wings"}},
{"s.authorInfo.yearRange", []interface{}{[]interface{}{1890.0, 1976.0}, []interface{}{1920.0, 1992.0}, []interface{}{1881.0, 1975.0}}},
{"s.authorInfo.name", []interface{}{"Agatha Christie", "Isaac Asimov", "P. G. Wodehouse"}},
{"s.authorInfo.yearRange[0]", []interface{}{1890.0, 1920.0, 1881.0}},
{"s.publicationHistory[0].pages", []interface{}{256.0, 336.0, nil}},
}
for i, tc := range cases {
jp := JSONPath{}
err := p.ParseString(tc.str, &jp)
// fmt.Println(jp)
if err != nil {
t.Fatalf("parse failed!: %d %v %s", i, err, tc)
}
// Read only the first json object from the file
recs, err := getJSONStructs(b)
if err != nil || len(recs) != 3 {
t.Fatalf("%v or length was not 3", err)
}
for j, rec := range recs {
// fmt.Println(rec)
r, _, err := jsonpathEval(jp.PathExpr, rec)
if err != nil {
t.Errorf("Error: %d %d %v", i, j, err)
}
if !reflect.DeepEqual(r, tc.res[j]) {
fmt.Printf("%#v (%v) != %v (%v)\n", r, reflect.TypeOf(r), tc.res[j], reflect.TypeOf(tc.res[j]))
t.Errorf("case: %d %d failed", i, j)
}
}
}
}

View File

@@ -0,0 +1,371 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"strings"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
)
// Types with custom Capture interface for parsing
// Boolean is a type for a parsed Boolean literal
type Boolean bool
// Capture interface used by participle
func (b *Boolean) Capture(values []string) error {
*b = strings.ToLower(values[0]) == "true"
return nil
}
// LiteralString is a type for parsed SQL string literals
type LiteralString string
// Capture interface used by participle
func (ls *LiteralString) Capture(values []string) error {
// Remove enclosing single quote
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*ls = LiteralString(strings.Replace(r, "''", "'", -1))
return nil
}
// LiteralList is a type for parsed SQL lists literals
type LiteralList []string
// Capture interface used by participle
func (ls *LiteralList) Capture(values []string) error {
// Remove enclosing parenthesis.
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*ls = LiteralList(strings.Split(r, ","))
return nil
}
// ObjectKey is a type for parsed strings occurring in key paths
type ObjectKey struct {
Lit *LiteralString `parser:" \"[\" @LitString \"]\""`
ID *Identifier `parser:"| \".\" @@"`
}
// QuotedIdentifier is a type for parsed strings that are double
// quoted.
type QuotedIdentifier string
// Capture interface used by participle
func (qi *QuotedIdentifier) Capture(values []string) error {
// Remove enclosing quotes
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*qi = QuotedIdentifier(strings.Replace(r, `""`, `"`, -1))
return nil
}
// Types representing AST of SQL statement. Only SELECT is supported.
// Select is the top level AST node type
type Select struct {
Expression *SelectExpression `parser:"\"SELECT\" @@"`
From *TableExpression `parser:"\"FROM\" @@"`
Where *Expression `parser:"( \"WHERE\" @@ )?"`
Limit *LitValue `parser:"( \"LIMIT\" @@ )?"`
}
// SelectExpression represents the items requested in the select
// statement
type SelectExpression struct {
All bool `parser:" @\"*\""`
Expressions []*AliasedExpression `parser:"| @@ { \",\" @@ }"`
}
// TableExpression represents the FROM clause
type TableExpression struct {
Table *JSONPath `parser:"@@"`
As string `parser:"( \"AS\"? @Ident )?"`
}
// JSONPathElement represents a keypath component
type JSONPathElement struct {
Key *ObjectKey `parser:" @@"` // ['name'] and .name forms
Index *int `parser:"| \"[\" @Int \"]\""` // [3] form
ObjectWildcard bool `parser:"| @\".*\""` // .* form
ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form
}
// JSONPath represents a keypath.
// Instances should be treated idempotent and not change once created.
type JSONPath struct {
BaseKey *Identifier `parser:" @@"`
PathExpr []*JSONPathElement `parser:"(@@)*"`
// Cached values:
pathString string
strippedTableAlias string
strippedPathExpr []*JSONPathElement
}
// AliasedExpression is an expression that can be optionally named
type AliasedExpression struct {
Expression *Expression `parser:"@@"`
As string `parser:"[ \"AS\" @Ident ]"`
}
// Grammar for Expression
//
// Expression → AndCondition ("OR" AndCondition)*
// AndCondition → Condition ("AND" Condition)*
// Condition → "NOT" Condition | ConditionExpression
// ConditionExpression → ValueExpression ("=" | "<>" | "<=" | ">=" | "<" | ">") ValueExpression
// | ValueExpression "LIKE" ValueExpression ("ESCAPE" LitString)?
// | ValueExpression ("NOT"? "BETWEEN" ValueExpression "AND" ValueExpression)
// | ValueExpression "IN" "(" Expression ("," Expression)* ")"
// | ValueExpression
// ValueExpression → Operand
//
// Operand grammar follows below
// Expression represents a logical disjunction of clauses
type Expression struct {
And []*AndCondition `parser:"@@ ( \"OR\" @@ )*"`
}
// ListExpr represents a literal list with elements as expressions.
type ListExpr struct {
Elements []*Expression `parser:"\"(\" @@ ( \",\" @@ )* \")\" | \"[\" @@ ( \",\" @@ )* \"]\""`
}
// AndCondition represents logical conjunction of clauses
type AndCondition struct {
Condition []*Condition `parser:"@@ ( \"AND\" @@ )*"`
}
// Condition represents a negation or a condition operand
type Condition struct {
Operand *ConditionOperand `parser:" @@"`
Not *Condition `parser:"| \"NOT\" @@"`
}
// ConditionOperand is a operand followed by an an optional operation
// expression
type ConditionOperand struct {
Operand *Operand `parser:"@@"`
ConditionRHS *ConditionRHS `parser:"@@?"`
}
// ConditionRHS represents the right-hand-side of Compare, Between, In
// or Like expressions.
type ConditionRHS struct {
Compare *Compare `parser:" @@"`
Between *Between `parser:"| @@"`
In *In `parser:"| \"IN\" @@"`
Like *Like `parser:"| @@"`
}
// Compare represents the RHS of a comparison expression
type Compare struct {
Operator string `parser:"@( \"<>\" | \"<=\" | \">=\" | \"=\" | \"<\" | \">\" | \"!=\" )"`
Operand *Operand `parser:" @@"`
}
// Like represents the RHS of a LIKE expression
type Like struct {
Not bool `parser:" @\"NOT\"? "`
Pattern *Operand `parser:" \"LIKE\" @@ "`
EscapeChar *Operand `parser:" (\"ESCAPE\" @@)? "`
}
// Between represents the RHS of a BETWEEN expression
type Between struct {
Not bool `parser:" @\"NOT\"? "`
Start *Operand `parser:" \"BETWEEN\" @@ "`
End *Operand `parser:" \"AND\" @@ "`
}
// In represents the RHS of an IN expression
type In struct {
ListExpression *Expression `parser:"@@ "`
}
// Grammar for Operand:
//
// operand → multOp ( ("-" | "+") multOp )*
// multOp → unary ( ("/" | "*" | "%") unary )*
// unary → "-" unary | primary
// primary → Value | Variable | "(" expression ")"
//
// An Operand is a single term followed by an optional sequence of
// terms separated by +/-
type Operand struct {
Left *MultOp `parser:"@@"`
Right []*OpFactor `parser:"(@@)*"`
}
// OpFactor represents the right-side of a +/- operation.
type OpFactor struct {
Op string `parser:"@(\"+\" | \"-\")"`
Right *MultOp `parser:"@@"`
}
// MultOp represents a single term followed by an optional sequence of
// terms separated by *, / or % operators.
type MultOp struct {
Left *UnaryTerm `parser:"@@"`
Right []*OpUnaryTerm `parser:"(@@)*"`
}
// OpUnaryTerm represents the right side of *, / or % binary operations.
type OpUnaryTerm struct {
Op string `parser:"@(\"*\" | \"/\" | \"%\")"`
Right *UnaryTerm `parser:"@@"`
}
// UnaryTerm represents a single negated term or a primary term
type UnaryTerm struct {
Negated *NegatedTerm `parser:" @@"`
Primary *PrimaryTerm `parser:"| @@"`
}
// NegatedTerm has a leading minus sign.
type NegatedTerm struct {
Term *PrimaryTerm `parser:"\"-\" @@"`
}
// PrimaryTerm represents a Value, Path expression, a Sub-expression
// or a function call.
type PrimaryTerm struct {
Value *LitValue `parser:" @@"`
JPathExpr *JSONPath `parser:"| @@"`
ListExpr *ListExpr `parser:"| @@"`
SubExpression *Expression `parser:"| \"(\" @@ \")\""`
// Include function expressions here.
FuncCall *FuncExpr `parser:"| @@"`
}
// FuncExpr represents a function call
type FuncExpr struct {
SFunc *SimpleArgFunc `parser:" @@"`
Count *CountFunc `parser:"| @@"`
Cast *CastFunc `parser:"| @@"`
Substring *SubstringFunc `parser:"| @@"`
Extract *ExtractFunc `parser:"| @@"`
Trim *TrimFunc `parser:"| @@"`
DateAdd *DateAddFunc `parser:"| @@"`
DateDiff *DateDiffFunc `parser:"| @@"`
// Used during evaluation for aggregation funcs
aggregate *aggVal
}
// SimpleArgFunc represents functions with simple expression
// arguments.
type SimpleArgFunc struct {
FunctionName string `parser:" @(\"AVG\" | \"MAX\" | \"MIN\" | \"SUM\" | \"COALESCE\" | \"NULLIF\" | \"TO_STRING\" | \"TO_TIMESTAMP\" | \"UTCNOW\" | \"CHAR_LENGTH\" | \"CHARACTER_LENGTH\" | \"LOWER\" | \"UPPER\") "`
ArgsList []*Expression `parser:"\"(\" (@@ (\",\" @@)*)?\")\""`
}
// CountFunc represents the COUNT sql function
type CountFunc struct {
StarArg bool `parser:" \"COUNT\" \"(\" ( @\"*\"?"`
ExprArg *Expression `parser:" @@? )! \")\""`
}
// CastFunc represents CAST sql function
type CastFunc struct {
Expr *Expression `parser:" \"CAST\" \"(\" @@ "`
CastType string `parser:" \"AS\" @(\"BOOL\" | \"INT\" | \"INTEGER\" | \"STRING\" | \"FLOAT\" | \"DECIMAL\" | \"NUMERIC\" | \"TIMESTAMP\") \")\" "`
}
// SubstringFunc represents SUBSTRING sql function
type SubstringFunc struct {
Expr *PrimaryTerm `parser:" \"SUBSTRING\" \"(\" @@ "`
From *Operand `parser:" ( \"FROM\" @@ "`
For *Operand `parser:" (\"FOR\" @@)? \")\" "`
Arg2 *Operand `parser:" | \",\" @@ "`
Arg3 *Operand `parser:" (\",\" @@)? \")\" )"`
}
// ExtractFunc represents EXTRACT sql function
type ExtractFunc struct {
Timeword string `parser:" \"EXTRACT\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword | \"TIMEZONE_HOUR\":Timeword | \"TIMEZONE_MINUTE\":Timeword ) "`
From *PrimaryTerm `parser:" \"FROM\" @@ \")\" "`
}
// TrimFunc represents TRIM sql function
type TrimFunc struct {
TrimWhere *string `parser:" \"TRIM\" \"(\" ( @( \"LEADING\" | \"TRAILING\" | \"BOTH\" ) "`
TrimChars *PrimaryTerm `parser:" @@? "`
TrimFrom *PrimaryTerm `parser:" \"FROM\" )? @@ \")\" "`
}
// DateAddFunc represents the DATE_ADD function
type DateAddFunc struct {
DatePart string `parser:" \"DATE_ADD\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword ) \",\""`
Quantity *Operand `parser:" @@ \",\""`
Timestamp *PrimaryTerm `parser:" @@ \")\""`
}
// DateDiffFunc represents the DATE_DIFF function
type DateDiffFunc struct {
DatePart string `parser:" \"DATE_DIFF\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword ) \",\" "`
Timestamp1 *PrimaryTerm `parser:" @@ \",\" "`
Timestamp2 *PrimaryTerm `parser:" @@ \")\" "`
}
// LitValue represents a literal value parsed from the sql
type LitValue struct {
Float *float64 `parser:"( @Float"`
Int *float64 `parser:" | @Int"` // To avoid value out of range, use float64 instead
String *LiteralString `parser:" | @LitString"`
Boolean *Boolean `parser:" | @(\"TRUE\" | \"FALSE\")"`
Null bool `parser:" | @\"NULL\")"`
}
// Identifier represents a parsed identifier
type Identifier struct {
Unquoted *string `parser:" @Ident"`
Quoted *QuotedIdentifier `parser:"| @QuotIdent"`
}
var (
sqlLexer = lexer.Must(lexer.Regexp(`(\s+)` +
`|(?P<Timeword>(?i)\b(?:YEAR|MONTH|DAY|HOUR|MINUTE|SECOND|TIMEZONE_HOUR|TIMEZONE_MINUTE)\b)` +
`|(?P<Keyword>(?i)\b(?:SELECT|FROM|TOP|DISTINCT|ALL|WHERE|GROUP|BY|HAVING|UNION|MINUS|EXCEPT|INTERSECT|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|ANY|SOME|BETWEEN|AND|OR|LIKE|ESCAPE|AS|IN|BOOL|INT|INTEGER|STRING|FLOAT|DECIMAL|NUMERIC|TIMESTAMP|AVG|COUNT|MAX|MIN|SUM|COALESCE|NULLIF|CAST|DATE_ADD|DATE_DIFF|EXTRACT|TO_STRING|TO_TIMESTAMP|UTCNOW|CHAR_LENGTH|CHARACTER_LENGTH|LOWER|SUBSTRING|TRIM|UPPER|LEADING|TRAILING|BOTH|FOR)\b)` +
`|(?P<Ident>[a-zA-Z_][a-zA-Z0-9_]*)` +
`|(?P<QuotIdent>"([^"]*("")?)*")` +
`|(?P<Float>\d*\.\d+([eE][-+]?\d+)?)` +
`|(?P<Int>\d+)` +
`|(?P<LitString>'([^']*('')?)*')` +
`|(?P<Operators><>|!=|<=|>=|\.\*|\[\*\]|[-+*/%,.()=<>\[\]])`,
))
// SQLParser is used to parse SQL statements
SQLParser = participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
)

View File

@@ -0,0 +1,386 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"bytes"
"testing"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
)
func TestJSONPathElement(t *testing.T) {
p := participle.MustBuild(
&JSONPathElement{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
j := JSONPathElement{}
cases := []string{
// Key
"['name']", ".name", `."name"`,
// Index
"[2]", "[0]", "[100]",
// Object wilcard
".*",
// array wildcard
"[*]",
}
for i, tc := range cases {
err := p.ParseString(tc, &j)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestJSONPath(t *testing.T) {
p := participle.MustBuild(
&JSONPath{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
j := JSONPath{}
cases := []string{
"S3Object",
"S3Object.id",
"S3Object.book.title",
"S3Object.id[1]",
"S3Object.id['abc']",
"S3Object.id['ab']",
"S3Object.words.*.id",
"S3Object.words.name[*].val",
"S3Object.words.name[*].val[*]",
"S3Object.words.name[*].val.*",
}
for i, tc := range cases {
err := p.ParseString(tc, &j)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestIdentifierParsing(t *testing.T) {
p := participle.MustBuild(
&Identifier{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
id := Identifier{}
validCases := []string{
"a",
"_a",
"abc_a",
"a2",
`"abc"`,
`"abc\a""ac"`,
}
for i, tc := range validCases {
err := p.ParseString(tc, &id)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(id, repr.Indent(" "), repr.OmitEmpty(true))
}
invalidCases := []string{
"+a",
"-a",
"1a",
`"ab`,
`abc"`,
`aa""a`,
`"a"a"`,
}
for i, tc := range invalidCases {
err := p.ParseString(tc, &id)
if err == nil {
t.Fatalf("%d: %v", i, err)
}
// fmt.Println(tc, err)
}
}
func TestLiteralStringParsing(t *testing.T) {
var k ObjectKey
p := participle.MustBuild(
&ObjectKey{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
validCases := []string{
"['abc']",
"['ab''c']",
"['a''b''c']",
"['abc-x_1##@(*&(#*))/\\']",
}
for i, tc := range validCases {
err := p.ParseString(tc, &k)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
if string(*k.Lit) == "" {
t.Fatalf("Incorrect parse %#v", k)
}
// repr.Println(k, repr.Indent(" "), repr.OmitEmpty(true))
}
invalidCases := []string{
"['abc'']",
"['-abc'sc']",
"[abc']",
"['ac]",
}
for i, tc := range invalidCases {
err := p.ParseString(tc, &k)
if err == nil {
t.Fatalf("%d: %v", i, err)
}
// fmt.Println(tc, err)
}
}
func TestFunctionParsing(t *testing.T) {
var fex FuncExpr
p := participle.MustBuild(
&FuncExpr{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
validCases := []string{
"count(*)",
"sum(2 + s.id)",
"sum(t)",
"avg(s.id[1])",
"coalesce(s.id[1], 2, 2 + 3)",
"cast(s as string)",
"cast(s AS INT)",
"cast(s as DECIMAL)",
"extract(YEAR from '2018-01-09')",
"extract(month from '2018-01-09')",
"extract(hour from '2018-01-09')",
"extract(day from '2018-01-09')",
"substring('abcd' from 2 for 2)",
"substring('abcd' from 2)",
"substring('abcd' , 2 , 2)",
"substring('abcd' , 22 )",
"trim(' aab ')",
"trim(leading from ' aab ')",
"trim(trailing from ' aab ')",
"trim(both from ' aab ')",
"trim(both '12' from ' aab ')",
"trim(leading '12' from ' aab ')",
"trim(trailing '12' from ' aab ')",
"count(23)",
}
for i, tc := range validCases {
err := p.ParseString(tc, &fex)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(fex, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSqlLexer(t *testing.T) {
// s := bytes.NewBuffer([]byte("s.['name'].*.[*].abc.[\"abc\"]"))
s := bytes.NewBuffer([]byte("S3Object.words.*.id"))
// s := bytes.NewBuffer([]byte("COUNT(Id)"))
lex, err := sqlLexer.Lex(s)
if err != nil {
t.Fatal(err)
}
tokens, err := lexer.ConsumeAll(lex)
if err != nil {
t.Fatal(err)
}
// for i, t := range tokens {
// fmt.Printf("%d: %#v\n", i, t)
// }
if len(tokens) != 7 {
t.Fatalf("Expected 7 got %d", len(tokens))
}
}
func TestSelectWhere(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object",
"select a, b from s3object s",
"select a, b from s3object as s",
"select a, b from s3object as s where a = 1",
"select a, b from s3object s where a = 1",
"select a, b from s3object where a = 1",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestLikeClause(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
`select * from s3object where Name like 'abcd'`,
`select Name like 'abc' from s3object`,
`select * from s3object where Name not like 'abc'`,
`select * from s3object where Name like 'abc' escape 't'`,
`select * from s3object where Name like 'a\%' escape '?'`,
`select * from s3object where Name not like 'abc\' escape '?'`,
`select * from s3object where Name like 'a\%' escape LOWER('?')`,
`select * from s3object where Name not like LOWER('Bc\') escape '?'`,
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Errorf("%d: %v", i, err)
}
}
}
func TestBetweenClause(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
`select * from s3object where Id between 1 and 2`,
`select * from s3object where Id between 1 and 2 and name = 'Ab'`,
`select * from s3object where Id not between 1 and 2`,
`select * from s3object where Id not between 1 and 2 and name = 'Bc'`,
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Errorf("%d: %v", i, err)
}
}
}
func TestFromClauseJSONPath(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object",
"select * from s3object[*].name",
"select * from s3object[*].books[*]",
"select * from s3object[*].books[*].name",
"select * from s3object where name > 2",
"select * from s3object[*].name where name > 2",
"select * from s3object[*].books[*] where name > 2",
"select * from s3object[*].books[*].name where name > 2",
"select * from s3object[*].books[*] s",
"select * from s3object[*].books[*].name as s",
"select * from s3object s where name > 2",
"select * from s3object[*].name as s where name > 2",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSelectParsing(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object where name > 2 or value > 1 or word > 2",
"select s.word.id + 2 from s3object s",
"select 1-2-3 from s3object s limit 1",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSqlLexerArithOps(t *testing.T) {
s := bytes.NewBuffer([]byte("year from select month hour distinct"))
lex, err := sqlLexer.Lex(s)
if err != nil {
t.Fatal(err)
}
tokens, err := lexer.ConsumeAll(lex)
if err != nil {
t.Fatal(err)
}
if len(tokens) != 7 {
t.Errorf("Expected 7 got %d", len(tokens))
}
// for i, t := range tokens {
// fmt.Printf("%d: %#v\n", i, t)
// }
}

View File

@@ -0,0 +1,142 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"fmt"
"io"
"github.com/minio/simdjson-go"
)
// SelectObjectFormat specifies the format of the underlying data
type SelectObjectFormat int
const (
// SelectFmtUnknown - unknown format (default value)
SelectFmtUnknown SelectObjectFormat = iota
// SelectFmtCSV - CSV format
SelectFmtCSV
// SelectFmtJSON - JSON format
SelectFmtJSON
// SelectFmtSIMDJSON - SIMD JSON format
SelectFmtSIMDJSON
// SelectFmtParquet - Parquet format
SelectFmtParquet
)
// WriteCSVOpts - encapsulates options for Select CSV output
type WriteCSVOpts struct {
FieldDelimiter rune
Quote rune
QuoteEscape rune
AlwaysQuote bool
}
// Record - is a type containing columns and their values.
type Record interface {
Get(name string) (*Value, error)
// Set a value.
// Can return a different record type.
Set(name string, value *Value) (Record, error)
WriteCSV(writer io.Writer, opts WriteCSVOpts) error
WriteJSON(writer io.Writer) error
// Clone the record and if possible use the destination provided.
Clone(dst Record) Record
Reset()
// Returns underlying representation
Raw() (SelectObjectFormat, interface{})
// Replaces the underlying data
Replace(k interface{}) error
}
// IterToValue converts a simdjson Iter to its underlying value.
// Objects are returned as simdjson.Object
// Arrays are returned as []interface{} with parsed values.
func IterToValue(iter simdjson.Iter) (interface{}, error) {
switch iter.Type() {
case simdjson.TypeString:
v, err := iter.String()
if err != nil {
return nil, err
}
return v, nil
case simdjson.TypeFloat:
v, err := iter.Float()
if err != nil {
return nil, err
}
return v, nil
case simdjson.TypeInt:
v, err := iter.Int()
if err != nil {
return nil, err
}
return v, nil
case simdjson.TypeUint:
v, err := iter.Int()
if err != nil {
// Can't fit into int, convert to float.
v, err := iter.Float()
return v, err
}
return v, nil
case simdjson.TypeBool:
v, err := iter.Bool()
if err != nil {
return nil, err
}
return v, nil
case simdjson.TypeObject:
obj, err := iter.Object(nil)
if err != nil {
return nil, err
}
return *obj, err
case simdjson.TypeArray:
arr, err := iter.Array(nil)
if err != nil {
return nil, err
}
iter := arr.Iter()
var dst []interface{}
var next simdjson.Iter
for {
typ, err := iter.AdvanceIter(&next)
if err != nil {
return nil, err
}
if typ == simdjson.TypeNone {
break
}
v, err := IterToValue(next)
if err != nil {
return nil, err
}
dst = append(dst, v)
}
return dst, err
case simdjson.TypeNull:
return nil, nil
}
return nil, fmt.Errorf("IterToValue: unknown JSON type: %s", iter.Type().String())
}

View File

@@ -0,0 +1,345 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"fmt"
"strings"
"github.com/bcicen/jstream"
"github.com/minio/simdjson-go"
)
var (
errBadLimitSpecified = errors.New("Limit value must be a positive integer")
)
const (
baseTableName = "s3object"
)
// SelectStatement is the top level parsed and analyzed structure
type SelectStatement struct {
selectAST *Select
// Analysis result of the statement
selectQProp qProp
// Result of parsing the limit clause if one is present
// (otherwise -1)
limitValue int64
// Count of rows that have been output.
outputCount int64
// Table alias
tableAlias string
}
// ParseSelectStatement - parses a select query from the given string
// and analyzes it.
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
var selectAST Select
err = SQLParser.ParseString(s, &selectAST)
if err != nil {
err = errQueryParseFailure(err)
return
}
// Check if select is "SELECT s.* from S3Object s"
if !selectAST.Expression.All &&
len(selectAST.Expression.Expressions) == 1 &&
len(selectAST.Expression.Expressions[0].Expression.And) == 1 &&
len(selectAST.Expression.Expressions[0].Expression.And[0].Condition) == 1 &&
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand != nil &&
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left != nil &&
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left != nil &&
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary != nil &&
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr != nil {
if selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr.String() == selectAST.From.As+".*" {
selectAST.Expression.All = true
}
}
stmt.selectAST = &selectAST
// Check the parsed limit value
stmt.limitValue, err = parseLimit(selectAST.Limit)
if err != nil {
err = errQueryAnalysisFailure(err)
return
}
// Analyze where clause
if selectAST.Where != nil {
whereQProp := selectAST.Where.analyze(&selectAST)
if whereQProp.err != nil {
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %w", whereQProp.err))
return
}
if whereQProp.isAggregation {
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
return
}
}
// Validate table name
err = validateTableName(selectAST.From)
if err != nil {
return
}
// Analyze main select expression
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
err = stmt.selectQProp.err
if err != nil {
err = errQueryAnalysisFailure(err)
}
// Set table alias
stmt.tableAlias = selectAST.From.As
return
}
func validateTableName(from *TableExpression) error {
if strings.ToLower(from.Table.BaseKey.String()) != baseTableName {
return errBadTableName(errors.New("table name must be `s3object`"))
}
if len(from.Table.PathExpr) > 0 {
if !from.Table.PathExpr[0].ArrayWildcard {
return errBadTableName(errors.New("keypath table name is invalid - please check the service documentation"))
}
}
return nil
}
func parseLimit(v *LitValue) (int64, error) {
switch {
case v == nil:
return -1, nil
case v.Int == nil:
return -1, errBadLimitSpecified
default:
r := int64(*v.Int)
if r < 0 {
return -1, errBadLimitSpecified
}
return r, nil
}
}
// EvalFrom evaluates the From clause on the input record. It only
// applies to JSON input data format (currently).
func (e *SelectStatement) EvalFrom(format string, input Record) ([]*Record, error) {
if !e.selectAST.From.HasKeypath() {
return []*Record{&input}, nil
}
_, rawVal := input.Raw()
if format != "json" {
return nil, errDataSource(errors.New("path not supported"))
}
switch rec := rawVal.(type) {
case jstream.KVS:
txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
if err != nil {
return nil, err
}
var kvs jstream.KVS
switch v := txedRec.(type) {
case jstream.KVS:
kvs = v
case []interface{}:
recs := make([]*Record, len(v))
for i, val := range v {
tmpRec := input.Clone(nil)
if err = tmpRec.Replace(val); err != nil {
return nil, err
}
recs[i] = &tmpRec
}
return recs, nil
default:
kvs = jstream.KVS{jstream.KV{Key: "_1", Value: v}}
}
if err = input.Replace(kvs); err != nil {
return nil, err
}
return []*Record{&input}, nil
case simdjson.Object:
txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
if err != nil {
return nil, err
}
switch v := txedRec.(type) {
case simdjson.Object:
err := input.Replace(v)
if err != nil {
return nil, err
}
case []interface{}:
recs := make([]*Record, len(v))
for i, val := range v {
tmpRec := input.Clone(nil)
if err = tmpRec.Replace(val); err != nil {
return nil, err
}
recs[i] = &tmpRec
}
return recs, nil
default:
input.Reset()
input, err = input.Set("_1", &Value{value: v})
if err != nil {
return nil, err
}
}
return []*Record{&input}, nil
}
return nil, errDataSource(errors.New("unexpected non JSON input"))
}
// IsAggregated returns if the statement involves SQL aggregation
func (e *SelectStatement) IsAggregated() bool {
return e.selectQProp.isAggregation
}
// AggregateResult - returns the aggregated result after all input
// records have been processed. Applies only to aggregation queries.
func (e *SelectStatement) AggregateResult(output Record) error {
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(nil, e.tableAlias)
if err != nil {
return err
}
if expr.As != "" {
output, err = output.Set(expr.As, v)
} else {
output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
}
if err != nil {
return err
}
}
return nil
}
func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
if e.selectAST.Where == nil {
return true, nil
}
value, err := e.selectAST.Where.evalNode(input, e.tableAlias)
if err != nil {
return false, err
}
b, ok := value.ToBool()
if !ok {
err = fmt.Errorf("WHERE expression did not return bool")
return false, err
}
return b, nil
}
// AggregateRow - aggregates the input record. Applies only to
// aggregation queries.
func (e *SelectStatement) AggregateRow(input Record) error {
ok, err := e.isPassingWhereClause(input)
if err != nil {
return err
}
if !ok {
return nil
}
for _, expr := range e.selectAST.Expression.Expressions {
err := expr.aggregateRow(input, e.tableAlias)
if err != nil {
return err
}
}
return nil
}
// Eval - evaluates the Select statement for the given record. It
// applies only to non-aggregation queries.
// The function returns whether the statement passed the WHERE clause and should be outputted.
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
ok, err := e.isPassingWhereClause(input)
if err != nil || !ok {
// Either error or row did not pass where clause
return nil, err
}
if e.selectAST.Expression.All {
// Return the input record for `SELECT * FROM
// .. WHERE ..`
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return input.Clone(output), nil
}
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input, e.tableAlias)
if err != nil {
return nil, err
}
// Pick output column names
if expr.As != "" {
output, err = output.Set(expr.As, v)
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
output, err = output.Set(comp, v)
} else {
output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
}
if err != nil {
return nil, err
}
}
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return output, nil
}
// LimitReached - returns true if the number of records output has
// reached the value of the `LIMIT` clause.
func (e *SelectStatement) LimitReached() bool {
if e.limitValue == -1 {
return false
}
return e.outputCount >= e.limitValue
}

View File

@@ -0,0 +1,200 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"errors"
"strings"
)
var (
errMalformedEscapeSequence = errors.New("Malformed escape sequence in LIKE clause")
errInvalidTrimArg = errors.New("Trim argument is invalid - this should not happen")
errInvalidSubstringIndexLen = errors.New("Substring start index or length falls outside the string")
)
const (
percent rune = '%'
underscore rune = '_'
runeZero rune = 0
)
func evalSQLLike(text, pattern string, escape rune) (match bool, err error) {
s := []rune{}
prev := runeZero
hasLeadingPercent := false
patLen := len([]rune(pattern))
for i, r := range pattern {
if i > 0 && prev == escape {
switch r {
case percent, escape, underscore:
s = append(s, r)
prev = r
if r == escape {
prev = runeZero
}
default:
return false, errMalformedEscapeSequence
}
continue
}
prev = r
var ok bool
switch r {
case percent:
if len(s) == 0 {
hasLeadingPercent = true
continue
}
text, ok = matcher(text, string(s), hasLeadingPercent)
if !ok {
return false, nil
}
hasLeadingPercent = true
s = []rune{}
if i == patLen-1 {
// Last pattern character is a %, so
// we are done.
return true, nil
}
case underscore:
if len(s) == 0 {
text, ok = dropRune(text)
if !ok {
return false, nil
}
continue
}
text, ok = matcher(text, string(s), hasLeadingPercent)
if !ok {
return false, nil
}
hasLeadingPercent = false
text, ok = dropRune(text)
if !ok {
return false, nil
}
s = []rune{}
case escape:
if i == patLen-1 {
return false, errMalformedEscapeSequence
}
// Otherwise do nothing.
default:
s = append(s, r)
}
}
if hasLeadingPercent {
return strings.HasSuffix(text, string(s)), nil
}
return string(s) == text, nil
}
// matcher - Finds `pat` in `text`, and returns the part remainder of
// `text`, after the match. If leadingPercent is false, `pat` must be
// the prefix of `text`, otherwise it must be a substring.
func matcher(text, pat string, leadingPercent bool) (res string, match bool) {
if !leadingPercent {
res = strings.TrimPrefix(text, pat)
if len(text) == len(res) {
return "", false
}
} else {
parts := strings.SplitN(text, pat, 2)
if len(parts) == 1 {
return "", false
}
res = parts[1]
}
return res, true
}
func dropRune(text string) (res string, ok bool) {
r := []rune(text)
if len(r) == 0 {
return "", false
}
return string(r[1:]), true
}
func evalSQLSubstring(s string, startIdx, length int) (res string, err error) {
rs := []rune(s)
// According to s3 document, if startIdx < 1, it is set to 1.
if startIdx < 1 {
startIdx = 1
}
if startIdx > len(rs) {
startIdx = len(rs) + 1
}
// StartIdx is 1-based in the input
startIdx--
endIdx := len(rs)
if length != -1 {
if length < 0 {
return "", errInvalidSubstringIndexLen
}
if length > (endIdx - startIdx) {
length = endIdx - startIdx
}
endIdx = startIdx + length
}
return string(rs[startIdx:endIdx]), nil
}
const (
trimLeading = "LEADING"
trimTrailing = "TRAILING"
trimBoth = "BOTH"
)
func evalSQLTrim(where *string, trimChars, text string) (result string, err error) {
cutSet := " "
if trimChars != "" {
cutSet = trimChars
}
trimFunc := strings.Trim
switch {
case where == nil:
case *where == trimBoth:
case *where == trimLeading:
trimFunc = strings.TrimLeft
case *where == trimTrailing:
trimFunc = strings.TrimRight
default:
return "", errInvalidTrimArg
}
return trimFunc(text, cutSet), nil
}

View File

@@ -0,0 +1,43 @@
/*
* MinIO Object Storage (c) 2021 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "testing"
func TestEvalSQLSubstring(t *testing.T) {
evalCases := []struct {
s string
startIdx int
length int
resExpected string
errExpected error
}{
{"abcd", 1, 1, "a", nil},
{"abcd", -1, 1, "a", nil},
{"abcd", 999, 999, "", nil},
{"", 999, 999, "", nil},
{"测试abc", 1, 1, "测", nil},
{"测试abc", 5, 5, "c", nil},
}
for i, tc := range evalCases {
res, err := evalSQLSubstring(tc.s, tc.startIdx, tc.length)
if res != tc.resExpected || err != tc.errExpected {
t.Errorf("Eval Case %d failed: %v %v", i, res, err)
}
}
}

View File

@@ -0,0 +1,108 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"testing"
)
func TestEvalSQLLike(t *testing.T) {
dropCases := []struct {
input, resultExpected string
matchExpected bool
}{
{"", "", false},
{"a", "", true},
{"ab", "b", true},
{"தமிழ்", "மிழ்", true},
}
for i, tc := range dropCases {
res, ok := dropRune(tc.input)
if res != tc.resultExpected || ok != tc.matchExpected {
t.Errorf("DropRune Case %d failed", i)
}
}
matcherCases := []struct {
iText, iPat string
iHasLeadingPercent bool
resultExpected string
matchExpected bool
}{
{"abcd", "bcd", false, "", false},
{"abcd", "bcd", true, "", true},
{"abcd", "abcd", false, "", true},
{"abcd", "abcd", true, "", true},
{"abcd", "ab", false, "cd", true},
{"abcd", "ab", true, "cd", true},
{"abcd", "bc", false, "", false},
{"abcd", "bc", true, "d", true},
}
for i, tc := range matcherCases {
res, ok := matcher(tc.iText, tc.iPat, tc.iHasLeadingPercent)
if res != tc.resultExpected || ok != tc.matchExpected {
t.Errorf("Matcher Case %d failed", i)
}
}
evalCases := []struct {
iText, iPat string
iEsc rune
matchExpected bool
errExpected error
}{
{"abcd", "abc", runeZero, false, nil},
{"abcd", "abcd", runeZero, true, nil},
{"abcd", "abc_", runeZero, true, nil},
{"abcd", "_bdd", runeZero, false, nil},
{"abcd", "_b_d", runeZero, true, nil},
{"abcd", "____", runeZero, true, nil},
{"abcd", "____%", runeZero, true, nil},
{"abcd", "%____", runeZero, true, nil},
{"abcd", "%__", runeZero, true, nil},
{"abcd", "____", runeZero, true, nil},
{"", "_", runeZero, false, nil},
{"", "%", runeZero, true, nil},
{"abcd", "%%%%%", runeZero, true, nil},
{"abcd", "_____", runeZero, false, nil},
{"abcd", "%%%%%", runeZero, true, nil},
{"a%%d", `a\%\%d`, '\\', true, nil},
{"a%%d", `a\%d`, '\\', false, nil},
{`a%%\d`, `a\%\%\\d`, '\\', true, nil},
{`a%%\`, `a\%\%\\`, '\\', true, nil},
{`a%__%\`, `a\%\_\_\%\\`, '\\', true, nil},
{`a%__%\`, `a\%\_\_\%_`, '\\', true, nil},
{`a%__%\`, `a\%\_\__`, '\\', false, nil},
{`a%__%\`, `a\%\_\_%`, '\\', true, nil},
{`a%__%\`, `a?%?_?_?%\`, '?', true, nil},
}
for i, tc := range evalCases {
// fmt.Println("Case:", i)
res, err := evalSQLLike(tc.iText, tc.iPat, tc.iEsc)
if res != tc.matchExpected || err != tc.errExpected {
t.Errorf("Eval Case %d failed: %v %v", i, res, err)
}
}
}

View File

@@ -0,0 +1,183 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"time"
)
const (
layoutYear = "2006T"
layoutMonth = "2006-01T"
layoutDay = "2006-01-02T"
layoutMinute = "2006-01-02T15:04Z07:00"
layoutSecond = "2006-01-02T15:04:05Z07:00"
layoutNanosecond = "2006-01-02T15:04:05.999999999Z07:00"
)
var (
tformats = []string{
layoutYear,
layoutMonth,
layoutDay,
layoutMinute,
layoutSecond,
layoutNanosecond,
}
)
func parseSQLTimestamp(s string) (t time.Time, err error) {
for _, f := range tformats {
t, err = time.Parse(f, s)
if err == nil {
break
}
}
return
}
// FormatSQLTimestamp - returns the a string representation of the
// timestamp as used in S3 Select
func FormatSQLTimestamp(t time.Time) string {
_, zoneOffset := t.Zone()
hasZone := zoneOffset != 0
hasFracSecond := t.Nanosecond() != 0
hasSecond := t.Second() != 0
hasTime := t.Hour() != 0 || t.Minute() != 0
hasDay := t.Day() != 1
hasMonth := t.Month() != 1
switch {
case hasFracSecond:
return t.Format(layoutNanosecond)
case hasSecond:
return t.Format(layoutSecond)
case hasTime || hasZone:
return t.Format(layoutMinute)
case hasDay:
return t.Format(layoutDay)
case hasMonth:
return t.Format(layoutMonth)
default:
return t.Format(layoutYear)
}
}
const (
timePartYear = "YEAR"
timePartMonth = "MONTH"
timePartDay = "DAY"
timePartHour = "HOUR"
timePartMinute = "MINUTE"
timePartSecond = "SECOND"
timePartTimezoneHour = "TIMEZONE_HOUR"
timePartTimezoneMinute = "TIMEZONE_MINUTE"
)
func extract(what string, t time.Time) (v *Value, err error) {
switch what {
case timePartYear:
return FromInt(int64(t.Year())), nil
case timePartMonth:
return FromInt(int64(t.Month())), nil
case timePartDay:
return FromInt(int64(t.Day())), nil
case timePartHour:
return FromInt(int64(t.Hour())), nil
case timePartMinute:
return FromInt(int64(t.Minute())), nil
case timePartSecond:
return FromInt(int64(t.Second())), nil
case timePartTimezoneHour:
_, zoneOffset := t.Zone()
return FromInt(int64(zoneOffset / 3600)), nil
case timePartTimezoneMinute:
_, zoneOffset := t.Zone()
return FromInt(int64((zoneOffset % 3600) / 60)), nil
default:
// This does not happen
return nil, errNotImplemented
}
}
func dateAdd(timePart string, qty float64, t time.Time) (*Value, error) {
var duration time.Duration
switch timePart {
case timePartYear:
return FromTimestamp(t.AddDate(int(qty), 0, 0)), nil
case timePartMonth:
return FromTimestamp(t.AddDate(0, int(qty), 0)), nil
case timePartDay:
return FromTimestamp(t.AddDate(0, 0, int(qty))), nil
case timePartHour:
duration = time.Duration(qty) * time.Hour
case timePartMinute:
duration = time.Duration(qty) * time.Minute
case timePartSecond:
duration = time.Duration(qty) * time.Second
default:
return nil, errNotImplemented
}
return FromTimestamp(t.Add(duration)), nil
}
// dateDiff computes the difference between two times in terms of the
// `timePart` which can be years, months, days, hours, minutes or
// seconds. For difference in years, months or days, the time part,
// including timezone is ignored.
func dateDiff(timePart string, ts1, ts2 time.Time) (*Value, error) {
if ts2.Before(ts1) {
v, err := dateDiff(timePart, ts2, ts1)
if err == nil {
v.negate()
}
return v, err
}
duration := ts2.Sub(ts1)
y1, m1, d1 := ts1.Date()
y2, m2, d2 := ts2.Date()
switch timePart {
case timePartYear:
dy := int64(y2 - y1)
if m2 > m1 || (m2 == m1 && d2 >= d1) {
return FromInt(dy), nil
}
return FromInt(dy - 1), nil
case timePartMonth:
m1 += time.Month(12 * y1)
m2 += time.Month(12 * y2)
return FromInt(int64(m2 - m1)), nil
case timePartDay:
return FromInt(int64(duration / (24 * time.Hour))), nil
case timePartHour:
hours := duration / time.Hour
return FromInt(int64(hours)), nil
case timePartMinute:
minutes := duration / time.Minute
return FromInt(int64(minutes)), nil
case timePartSecond:
seconds := duration / time.Second
return FromInt(int64(seconds)), nil
default:
}
return nil, errNotImplemented
}

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"testing"
"time"
)
func TestParseAndDisplaySQLTimestamp(t *testing.T) {
beijing := time.FixedZone("", int((8 * time.Hour).Seconds()))
fakeLosAngeles := time.FixedZone("", -int((8 * time.Hour).Seconds()))
cases := []struct {
s string
t time.Time
}{
{"2010T", time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC)},
{"2010-02T", time.Date(2010, 2, 1, 0, 0, 0, 0, time.UTC)},
{"2010-02-03T", time.Date(2010, 2, 3, 0, 0, 0, 0, time.UTC)},
{"2010-02-03T04:11Z", time.Date(2010, 2, 3, 4, 11, 0, 0, time.UTC)},
{"2010-02-03T04:11:30Z", time.Date(2010, 2, 3, 4, 11, 30, 0, time.UTC)},
{"2010-02-03T04:11:30.23Z", time.Date(2010, 2, 3, 4, 11, 30, 230000000, time.UTC)},
{"2010-02-03T04:11+08:00", time.Date(2010, 2, 3, 4, 11, 0, 0, beijing)},
{"2010-02-03T04:11:30+08:00", time.Date(2010, 2, 3, 4, 11, 30, 0, beijing)},
{"2010-02-03T04:11:30.23+08:00", time.Date(2010, 2, 3, 4, 11, 30, 230000000, beijing)},
{"2010-02-03T04:11:30-08:00", time.Date(2010, 2, 3, 4, 11, 30, 0, fakeLosAngeles)},
{"2010-02-03T04:11:30.23-08:00", time.Date(2010, 2, 3, 4, 11, 30, 230000000, fakeLosAngeles)},
}
for i, tc := range cases {
tval, err := parseSQLTimestamp(tc.s)
if err != nil {
t.Errorf("Case %d: Unexpected error: %v", i+1, err)
continue
}
if !tval.Equal(tc.t) {
t.Errorf("Case %d: Expected %v got %v", i+1, tc.t, tval)
continue
}
tstr := FormatSQLTimestamp(tc.t)
if tstr != tc.s {
t.Errorf("Case %d: Expected %s got %s", i+1, tc.s, tstr)
continue
}
}
}

View File

@@ -0,0 +1,134 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"fmt"
"strings"
)
// String functions
// String - returns the JSONPath representation
func (e *JSONPath) String() string {
if len(e.pathString) == 0 {
parts := make([]string, len(e.PathExpr)+1)
parts[0] = e.BaseKey.String()
for i, pe := range e.PathExpr {
parts[i+1] = pe.String()
}
e.pathString = strings.Join(parts, "")
}
return e.pathString
}
// StripTableAlias removes a table alias from the path. The result is also
// cached for repeated lookups during SQL query evaluation.
func (e *JSONPath) StripTableAlias(tableAlias string) []*JSONPathElement {
if e.strippedTableAlias == tableAlias {
return e.strippedPathExpr
}
hasTableAlias := e.BaseKey.String() == tableAlias || strings.ToLower(e.BaseKey.String()) == baseTableName
var pathExpr []*JSONPathElement
if hasTableAlias {
pathExpr = e.PathExpr
} else {
pathExpr = make([]*JSONPathElement, len(e.PathExpr)+1)
pathExpr[0] = &JSONPathElement{Key: &ObjectKey{ID: e.BaseKey}}
copy(pathExpr[1:], e.PathExpr)
}
e.strippedTableAlias = tableAlias
e.strippedPathExpr = pathExpr
return e.strippedPathExpr
}
func (e *JSONPathElement) String() string {
switch {
case e.Key != nil:
return e.Key.String()
case e.Index != nil:
return fmt.Sprintf("[%d]", *e.Index)
case e.ObjectWildcard:
return ".*"
case e.ArrayWildcard:
return "[*]"
}
return ""
}
// String removes double quotes in quoted identifiers
func (i *Identifier) String() string {
if i.Unquoted != nil {
return *i.Unquoted
}
return string(*i.Quoted)
}
func (o *ObjectKey) String() string {
if o.Lit != nil {
return fmt.Sprintf("['%s']", string(*o.Lit))
}
return fmt.Sprintf(".%s", o.ID.String())
}
func (o *ObjectKey) keyString() string {
if o.Lit != nil {
return string(*o.Lit)
}
return o.ID.String()
}
// getLastKeypathComponent checks if the given expression is a path
// expression, and if so extracts the last dot separated component of
// the path. Otherwise it returns false.
func getLastKeypathComponent(e *Expression) (string, bool) {
if len(e.And) > 1 ||
len(e.And[0].Condition) > 1 ||
e.And[0].Condition[0].Not != nil ||
e.And[0].Condition[0].Operand.ConditionRHS != nil {
return "", false
}
operand := e.And[0].Condition[0].Operand.Operand
if operand.Right != nil ||
operand.Left.Right != nil ||
operand.Left.Left.Negated != nil ||
operand.Left.Left.Primary.JPathExpr == nil {
return "", false
}
// Check if path expression ends in a key
jpath := operand.Left.Left.Primary.JPathExpr
n := len(jpath.PathExpr)
if n > 0 && jpath.PathExpr[n-1].Key == nil {
return "", false
}
ps := jpath.String()
if idx := strings.LastIndex(ps, "."); idx >= 0 {
// Get last part of path string.
ps = ps[idx+1:]
}
return ps, true
}
// HasKeypath returns if the from clause has a key path -
// e.g. S3object[*].id
func (from *TableExpression) HasKeypath() bool {
return len(from.Table.PathExpr) > 1
}

View File

@@ -0,0 +1,914 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
"strconv"
"strings"
"time"
"unicode/utf8"
)
var (
errArithMismatchedTypes = errors.New("cannot perform arithmetic on mismatched types")
errArithInvalidOperator = errors.New("invalid arithmetic operator")
errArithDivideByZero = errors.New("cannot divide by 0")
errCmpMismatchedTypes = errors.New("cannot compare values of different types")
errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments")
)
// Value represents a value of restricted type reduced from an
// expression represented by an ASTNode. Only one of the fields is
// non-nil.
//
// In cases where we are fetching data from a data source (like csv),
// the type may not be determined yet. In these cases, a byte-slice is
// used.
type Value struct {
value interface{}
}
// MarshalJSON provides json marshaling of values.
func (v Value) MarshalJSON() ([]byte, error) {
if b, ok := v.ToBytes(); ok {
return b, nil
}
return json.Marshal(v.value)
}
// GetTypeString returns a string representation for vType
func (v Value) GetTypeString() string {
switch v.value.(type) {
case nil:
return "NULL"
case bool:
return "BOOL"
case string:
return "STRING"
case int64:
return "INT"
case float64:
return "FLOAT"
case time.Time:
return "TIMESTAMP"
case []byte:
return "BYTES"
case []Value:
return "ARRAY"
}
return "--"
}
// Repr returns a string representation of value.
func (v Value) Repr() string {
switch x := v.value.(type) {
case nil:
return ":NULL"
case bool, int64, float64:
return fmt.Sprintf("%v:%s", v.value, v.GetTypeString())
case time.Time:
return fmt.Sprintf("%s:TIMESTAMP", x)
case string:
return fmt.Sprintf("\"%s\":%s", x, v.GetTypeString())
case []byte:
return fmt.Sprintf("\"%s\":BYTES", string(x))
case []Value:
var s strings.Builder
s.WriteByte('[')
for i, v := range x {
s.WriteString(v.Repr())
if i < len(x)-1 {
s.WriteByte(',')
}
}
s.WriteString("]:ARRAY")
return s.String()
default:
return fmt.Sprintf("%v:INVALID", v.value)
}
}
// FromFloat creates a Value from a number
func FromFloat(f float64) *Value {
return &Value{value: f}
}
// FromInt creates a Value from an int
func FromInt(f int64) *Value {
return &Value{value: f}
}
// FromString creates a Value from a string
func FromString(str string) *Value {
return &Value{value: str}
}
// FromBool creates a Value from a bool
func FromBool(b bool) *Value {
return &Value{value: b}
}
// FromTimestamp creates a Value from a timestamp
func FromTimestamp(t time.Time) *Value {
return &Value{value: t}
}
// FromNull creates a Value with Null value
func FromNull() *Value {
return &Value{value: nil}
}
// FromBytes creates a Value from a []byte
func FromBytes(b []byte) *Value {
return &Value{value: b}
}
// FromArray creates a Value from an array of values.
func FromArray(a []Value) *Value {
return &Value{value: a}
}
// ToFloat works for int and float values
func (v Value) ToFloat() (val float64, ok bool) {
switch x := v.value.(type) {
case float64:
return x, true
case int64:
return float64(x), true
}
return 0, false
}
// ToInt returns the value if int.
func (v Value) ToInt() (val int64, ok bool) {
val, ok = v.value.(int64)
return
}
// ToString returns the value if string.
func (v Value) ToString() (val string, ok bool) {
val, ok = v.value.(string)
return
}
// Equals returns whether the values strictly match.
// Both type and value must match.
func (v Value) Equals(b Value) (ok bool) {
if !v.SameTypeAs(b) {
return false
}
return reflect.DeepEqual(v.value, b.value)
}
// SameTypeAs return whether the two types are strictly the same.
func (v Value) SameTypeAs(b Value) (ok bool) {
switch v.value.(type) {
case bool:
_, ok = b.value.(bool)
case string:
_, ok = b.value.(string)
case int64:
_, ok = b.value.(int64)
case float64:
_, ok = b.value.(float64)
case time.Time:
_, ok = b.value.(time.Time)
case []byte:
_, ok = b.value.([]byte)
case []Value:
_, ok = b.value.([]Value)
default:
ok = reflect.TypeOf(v.value) == reflect.TypeOf(b.value)
}
return ok
}
// ToBool returns the bool value; second return value refers to if the bool
// conversion succeeded.
func (v Value) ToBool() (val bool, ok bool) {
val, ok = v.value.(bool)
return
}
// ToTimestamp returns the timestamp value if present.
func (v Value) ToTimestamp() (t time.Time, ok bool) {
t, ok = v.value.(time.Time)
return
}
// ToBytes returns the value if byte-slice.
func (v Value) ToBytes() (val []byte, ok bool) {
val, ok = v.value.([]byte)
return
}
// ToArray returns the value if it is a slice of values.
func (v Value) ToArray() (val []Value, ok bool) {
val, ok = v.value.([]Value)
return
}
// IsNull - checks if value is missing.
func (v Value) IsNull() bool {
switch v.value.(type) {
case nil:
return true
}
return false
}
// IsArray returns whether the value is an array.
func (v Value) IsArray() (ok bool) {
_, ok = v.value.([]Value)
return ok
}
func (v Value) isNumeric() bool {
switch v.value.(type) {
case int64, float64:
return true
}
return false
}
// setters used internally to mutate values
func (v *Value) setInt(i int64) {
v.value = i
}
func (v *Value) setFloat(f float64) {
v.value = f
}
func (v *Value) setString(s string) {
v.value = s
}
func (v *Value) setBool(b bool) {
v.value = b
}
func (v *Value) setTimestamp(t time.Time) {
v.value = t
}
func (v Value) String() string {
return fmt.Sprintf("%#v", v.value)
}
// CSVString - convert to string for CSV serialization
func (v Value) CSVString() string {
switch x := v.value.(type) {
case nil:
return ""
case bool:
if x {
return "true"
}
return "false"
case string:
return x
case int64:
return strconv.FormatInt(x, 10)
case float64:
return strconv.FormatFloat(x, 'g', -1, 64)
case time.Time:
return FormatSQLTimestamp(x)
case []byte:
return string(x)
case []Value:
b, _ := json.Marshal(x)
return string(b)
default:
return "CSV serialization not implemented for this type"
}
}
// negate negates a numeric value
func (v *Value) negate() {
switch x := v.value.(type) {
case float64:
v.value = -x
case int64:
v.value = -x
}
}
// Value comparison functions: we do not expose them outside the
// module. Logical operators "<", ">", ">=", "<=" work on strings and
// numbers. Equality operators "=", "!=" work on strings,
// numbers and booleans.
// Supported comparison operators
const (
opLt = "<"
opLte = "<="
opGt = ">"
opGte = ">="
opEq = "="
opIneq = "!="
)
// InferBytesType will attempt to infer the data type of bytes.
// Will fail if value type is not bytes or it would result in invalid utf8.
// ORDER: int, float, bool, JSON (object or array), timestamp, string
// If the content is valid JSON, the type will still be bytes.
func (v *Value) InferBytesType() (err error) {
b, ok := v.ToBytes()
if !ok {
return fmt.Errorf("InferByteType: Input is not bytes, but %v", v.GetTypeString())
}
// Check for numeric inference
if x, ok := v.bytesToInt(); ok {
v.setInt(x)
return nil
}
if x, ok := v.bytesToFloat(); ok {
v.setFloat(x)
return nil
}
if x, ok := v.bytesToBool(); ok {
v.setBool(x)
return nil
}
asString := strings.TrimSpace(v.bytesToString())
if len(b) > 0 &&
(strings.HasPrefix(asString, "{") || strings.HasPrefix(asString, "[")) {
return nil
}
if t, err := parseSQLTimestamp(asString); err == nil {
v.setTimestamp(t)
return nil
}
if !utf8.Valid(b) {
return errors.New("value is not valid utf-8")
}
// Fallback to string
v.setString(asString)
return
}
// When numeric types are compared, type promotions could happen. If
// values do not have types (e.g. when reading from CSV), for
// comparison operations, automatic type conversion happens by trying
// to check if the value is a number (first an integer, then a float),
// and falling back to string.
func (v *Value) compareOp(op string, a *Value) (res bool, err error) {
if !isValidComparisonOperator(op) {
return false, errArithInvalidOperator
}
// Check if type conversion/inference is needed - it is needed
// if the Value is a byte-slice.
err = inferTypesForCmp(v, a)
if err != nil {
return false, err
}
// Check if either is nil
if v.IsNull() || a.IsNull() {
// If one is, both must be.
return boolCompare(op, v.IsNull(), a.IsNull())
}
// Check array values
aArr, aOK := a.ToArray()
vArr, vOK := v.ToArray()
if aOK && vOK {
return arrayCompare(op, aArr, vArr)
}
isNumeric := v.isNumeric() && a.isNumeric()
if isNumeric {
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
if ok1i && ok2i {
return intCompare(op, intV, intA), nil
}
// If both values are numeric, then at least one is
// float since we got here, so we convert.
flV, _ := v.ToFloat()
flA, _ := a.ToFloat()
return floatCompare(op, flV, flA), nil
}
strV, ok1s := v.ToString()
strA, ok2s := a.ToString()
if ok1s && ok2s {
return stringCompare(op, strV, strA), nil
}
boolV, ok1b := v.ToBool()
boolA, ok2b := a.ToBool()
if ok1b && ok2b {
return boolCompare(op, boolV, boolA)
}
timestampV, ok1t := v.ToTimestamp()
timestampA, ok2t := a.ToTimestamp()
if ok1t && ok2t {
return timestampCompare(op, timestampV, timestampA), nil
}
// Types cannot be compared, they do not match.
switch op {
case opEq:
return false, nil
case opIneq:
return true, nil
}
return false, errCmpInvalidBoolOperator
}
func inferTypesForCmp(a *Value, b *Value) error {
_, okA := a.ToBytes()
_, okB := b.ToBytes()
switch {
case !okA && !okB:
// Both Values already have types
return nil
case okA && okB:
// Both Values are untyped so try the types in order:
// int, float, bool, string
// Check for numeric inference
iA, okAi := a.bytesToInt()
iB, okBi := b.bytesToInt()
if okAi && okBi {
a.setInt(iA)
b.setInt(iB)
return nil
}
fA, okAf := a.bytesToFloat()
fB, okBf := b.bytesToFloat()
if okAf && okBf {
a.setFloat(fA)
b.setFloat(fB)
return nil
}
// Check if they int and float combination.
if okAi && okBf {
a.setInt(iA)
b.setFloat(fA)
return nil
}
if okBi && okAf {
a.setFloat(fA)
b.setInt(iB)
return nil
}
// Not numeric types at this point.
// Check for bool inference
bA, okAb := a.bytesToBool()
bB, okBb := b.bytesToBool()
if okAb && okBb {
a.setBool(bA)
b.setBool(bB)
return nil
}
// Fallback to string
sA := a.bytesToString()
sB := b.bytesToString()
a.setString(sA)
b.setString(sB)
return nil
case okA && !okB:
// Here a has `a` is untyped, but `b` has a fixed
// type.
switch b.value.(type) {
case string:
s := a.bytesToString()
a.setString(s)
case int64, float64:
if iA, ok := a.bytesToInt(); ok {
a.setInt(iA)
} else if fA, ok := a.bytesToFloat(); ok {
a.setFloat(fA)
} else {
return fmt.Errorf("Could not convert %s to a number", a.String())
}
case bool:
if bA, ok := a.bytesToBool(); ok {
a.setBool(bA)
} else {
return fmt.Errorf("Could not convert %s to a boolean", a.String())
}
default:
return errCmpMismatchedTypes
}
return nil
case !okA && okB:
// swap arguments to avoid repeating code
return inferTypesForCmp(b, a)
default:
// Does not happen
return nil
}
}
// Value arithmetic functions: we do not expose them outside the
// module. All arithmetic works only on numeric values with automatic
// promotion to the "larger" type that can represent the value. TODO:
// Add support for large number arithmetic.
// Supported arithmetic operators
const (
opPlus = "+"
opMinus = "-"
opDivide = "/"
opMultiply = "*"
opModulo = "%"
)
// For arithmetic operations, if both values are numeric then the
// operation shall succeed. If the types are unknown automatic type
// conversion to a number is attempted.
func (v *Value) arithOp(op string, a *Value) error {
err := inferTypeForArithOp(v)
if err != nil {
return err
}
err = inferTypeForArithOp(a)
if err != nil {
return err
}
if !v.isNumeric() || !a.isNumeric() {
return errInvalidDataType(errArithMismatchedTypes)
}
if !isValidArithOperator(op) {
return errInvalidDataType(errArithMismatchedTypes)
}
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
switch {
case ok1i && ok2i:
res, err := intArithOp(op, intV, intA)
v.setInt(res)
return err
default:
// Convert arguments to float
flV, _ := v.ToFloat()
flA, _ := a.ToFloat()
res, err := floatArithOp(op, flV, flA)
v.setFloat(res)
return err
}
}
func inferTypeForArithOp(a *Value) error {
if _, ok := a.ToBytes(); !ok {
return nil
}
if i, ok := a.bytesToInt(); ok {
a.setInt(i)
return nil
}
if f, ok := a.bytesToFloat(); ok {
a.setFloat(f)
return nil
}
err := fmt.Errorf("Could not convert %q to a number", string(a.value.([]byte)))
return errInvalidDataType(err)
}
// All the bytesTo* functions defined below assume the value is a byte-slice.
// Converts untyped value into int. The bool return implies success -
// it returns false only if there is a conversion failure.
func (v Value) bytesToInt() (int64, bool) {
bytes, _ := v.ToBytes()
i, err := strconv.ParseInt(strings.TrimSpace(string(bytes)), 10, 64)
return i, err == nil
}
// Converts untyped value into float. The bool return implies success
// - it returns false only if there is a conversion failure.
func (v Value) bytesToFloat() (float64, bool) {
bytes, _ := v.ToBytes()
i, err := strconv.ParseFloat(strings.TrimSpace(string(bytes)), 64)
return i, err == nil
}
// Converts untyped value into bool. The second bool return implies
// success - it returns false in case of a conversion failure.
func (v Value) bytesToBool() (val bool, ok bool) {
bytes, _ := v.ToBytes()
ok = true
switch strings.ToLower(strings.TrimSpace(string(bytes))) {
case "t", "true", "1":
val = true
case "f", "false", "0":
val = false
default:
ok = false
}
return val, ok
}
// bytesToString - never fails, but returns empty string if value is not bytes.
func (v Value) bytesToString() string {
bytes, _ := v.ToBytes()
return string(bytes)
}
// Calculates minimum or maximum of v and a and assigns the result to
// v - it works only on numeric arguments, where `v` is already
// assumed to be numeric. Attempts conversion to numeric type for `a`
// (first int, then float) only if the underlying values do not have a
// type.
func (v *Value) minmax(a *Value, isMax, isFirstRow bool) error {
err := inferTypeForArithOp(a)
if err != nil {
return err
}
if !a.isNumeric() {
return errArithMismatchedTypes
}
// In case of first row, set v to a.
if isFirstRow {
intA, okI := a.ToInt()
if okI {
v.setInt(intA)
return nil
}
floatA, _ := a.ToFloat()
v.setFloat(floatA)
return nil
}
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
if ok1i && ok2i {
result := intV
if !isMax {
if intA < result {
result = intA
}
} else {
if intA > result {
result = intA
}
}
v.setInt(result)
return nil
}
floatV, _ := v.ToFloat()
floatA, _ := a.ToFloat()
var result float64
if !isMax {
result = math.Min(floatV, floatA)
} else {
result = math.Max(floatV, floatA)
}
v.setFloat(result)
return nil
}
func inferTypeAsTimestamp(v *Value) error {
if s, ok := v.ToString(); ok {
t, err := parseSQLTimestamp(s)
if err != nil {
return err
}
v.setTimestamp(t)
} else if b, ok := v.ToBytes(); ok {
s := string(b)
t, err := parseSQLTimestamp(s)
if err != nil {
return err
}
v.setTimestamp(t)
}
return nil
}
// inferTypeAsString is used to convert untyped values to string - it
// is called when the caller requires a string context to proceed.
func inferTypeAsString(v *Value) {
b, ok := v.ToBytes()
if !ok {
return
}
v.setString(string(b))
}
func isValidComparisonOperator(op string) bool {
switch op {
case opLt:
case opLte:
case opGt:
case opGte:
case opEq:
case opIneq:
default:
return false
}
return true
}
func intCompare(op string, left, right int64) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func floatCompare(op string, left, right float64) bool {
diff := math.Abs(left - right)
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return diff < floatCmpTolerance
case opIneq:
return diff > floatCmpTolerance
}
// This case does not happen
return false
}
func stringCompare(op string, left, right string) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func boolCompare(op string, left, right bool) (bool, error) {
switch op {
case opEq:
return left == right, nil
case opIneq:
return left != right, nil
default:
return false, errCmpInvalidBoolOperator
}
}
func arrayCompare(op string, left, right []Value) (bool, error) {
switch op {
case opEq:
if len(left) != len(right) {
return false, nil
}
for i, l := range left {
eq, err := l.compareOp(op, &right[i])
if !eq || err != nil {
return eq, err
}
}
return true, nil
case opIneq:
for i, l := range left {
eq, err := l.compareOp(op, &right[i])
if eq || err != nil {
return eq, err
}
}
return false, nil
default:
return false, errCmpInvalidBoolOperator
}
}
func isValidArithOperator(op string) bool {
switch op {
case opPlus:
case opMinus:
case opDivide:
case opMultiply:
case opModulo:
default:
return false
}
return true
}
// Overflow errors are ignored.
func intArithOp(op string, left, right int64) (int64, error) {
switch op {
case opPlus:
return left + right, nil
case opMinus:
return left - right, nil
case opDivide:
if right == 0 {
return 0, errArithDivideByZero
}
return left / right, nil
case opMultiply:
return left * right, nil
case opModulo:
if right == 0 {
return 0, errArithDivideByZero
}
return left % right, nil
}
// This does not happen
return 0, nil
}
// Overflow errors are ignored.
func floatArithOp(op string, left, right float64) (float64, error) {
switch op {
case opPlus:
return left + right, nil
case opMinus:
return left - right, nil
case opDivide:
if right == 0 {
return 0, errArithDivideByZero
}
return left / right, nil
case opMultiply:
return left * right, nil
case opModulo:
if right == 0 {
return 0, errArithDivideByZero
}
return math.Mod(left, right), nil
}
// This does not happen
return 0, nil
}

View File

@@ -0,0 +1,38 @@
/*
* MinIO Object Storage (c) 2021 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "time"
func timestampCompare(op string, left, right time.Time) bool {
switch op {
case opLt:
return left.Before(right)
case opLte:
return left.Before(right) || left.Equal(right)
case opGt:
return left.After(right)
case opGte:
return left.After(right) || left.Equal(right)
case opEq:
return left.Equal(right)
case opIneq:
return !left.Equal(right)
}
// This case does not happen
return false
}

View File

@@ -0,0 +1,687 @@
// Copyright (c) 2015-2021 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package sql
import (
"fmt"
"math"
"strconv"
"testing"
"time"
)
// valueBuilders contains one constructor for each value type.
// Values should match if type is the same.
var valueBuilders = []func() *Value{
func() *Value {
return FromNull()
},
func() *Value {
return FromBool(true)
},
func() *Value {
return FromBytes([]byte("byte contents"))
},
func() *Value {
return FromFloat(math.Pi)
},
func() *Value {
return FromInt(0x1337)
},
func() *Value {
t, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
if err != nil {
panic(err)
}
return FromTimestamp(t)
},
func() *Value {
return FromString("string contents")
},
}
// altValueBuilders contains one constructor for each value type.
// Values are zero values and should NOT match the values in valueBuilders, except Null type.
var altValueBuilders = []func() *Value{
func() *Value {
return FromNull()
},
func() *Value {
return FromBool(false)
},
func() *Value {
return FromBytes(nil)
},
func() *Value {
return FromFloat(0)
},
func() *Value {
return FromInt(0)
},
func() *Value {
return FromTimestamp(time.Time{})
},
func() *Value {
return FromString("")
},
}
func TestValue_SameTypeAs(t *testing.T) {
type fields struct {
a, b Value
}
type test struct {
name string
fields fields
wantOk bool
}
var tests []test
for i := range valueBuilders {
a := valueBuilders[i]()
for j := range valueBuilders {
b := valueBuilders[j]()
tests = append(tests, test{
name: fmt.Sprint(a.GetTypeString(), "==", b.GetTypeString()),
fields: fields{
a: *a, b: *b,
},
wantOk: i == j,
})
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotOk := tt.fields.a.SameTypeAs(tt.fields.b); gotOk != tt.wantOk {
t.Errorf("SameTypeAs() = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestValue_Equals(t *testing.T) {
type fields struct {
a, b Value
}
type test struct {
name string
fields fields
wantOk bool
}
var tests []test
for i := range valueBuilders {
a := valueBuilders[i]()
for j := range valueBuilders {
b := valueBuilders[j]()
tests = append(tests, test{
name: fmt.Sprint(a.GetTypeString(), "==", b.GetTypeString()),
fields: fields{
a: *a, b: *b,
},
wantOk: i == j,
})
}
}
for i := range valueBuilders {
a := valueBuilders[i]()
for j := range altValueBuilders {
b := altValueBuilders[j]()
tests = append(tests, test{
name: fmt.Sprint(a.GetTypeString(), "!=", b.GetTypeString()),
fields: fields{
a: *a, b: *b,
},
// Only Null == Null
wantOk: a.IsNull() && b.IsNull() && i == 0 && j == 0,
})
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotOk := tt.fields.a.Equals(tt.fields.b); gotOk != tt.wantOk {
t.Errorf("Equals() = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func TestValue_CSVString(t *testing.T) {
type test struct {
name string
want string
wantAlt string
}
tests := []test{
{
name: valueBuilders[0]().String(),
want: "",
wantAlt: "",
},
{
name: valueBuilders[1]().String(),
want: "true",
wantAlt: "false",
},
{
name: valueBuilders[2]().String(),
want: "byte contents",
wantAlt: "",
},
{
name: valueBuilders[3]().String(),
want: "3.141592653589793",
wantAlt: "0",
},
{
name: valueBuilders[4]().String(),
want: "4919",
wantAlt: "0",
},
{
name: valueBuilders[5]().String(),
want: "2006-01-02T15:04:05Z",
wantAlt: "0001T",
},
{
name: valueBuilders[6]().String(),
want: "string contents",
wantAlt: "",
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := valueBuilders[i]()
vAlt := altValueBuilders[i]()
if got := v.CSVString(); got != tt.want {
t.Errorf("CSVString() = %v, want %v", got, tt.want)
}
if got := vAlt.CSVString(); got != tt.wantAlt {
t.Errorf("CSVString() = %v, want %v", got, tt.wantAlt)
}
})
}
}
func TestValue_bytesToInt(t *testing.T) {
type fields struct {
value interface{}
}
tests := []struct {
name string
fields fields
want int64
wantOK bool
}{
{
name: "zero",
fields: fields{
value: []byte("0"),
},
want: 0,
wantOK: true,
},
{
name: "minuszero",
fields: fields{
value: []byte("-0"),
},
want: 0,
wantOK: true,
},
{
name: "one",
fields: fields{
value: []byte("1"),
},
want: 1,
wantOK: true,
},
{
name: "minusone",
fields: fields{
value: []byte("-1"),
},
want: -1,
wantOK: true,
},
{
name: "plusone",
fields: fields{
value: []byte("+1"),
},
want: 1,
wantOK: true,
},
{
name: "max",
fields: fields{
value: []byte(strconv.FormatInt(math.MaxInt64, 10)),
},
want: math.MaxInt64,
wantOK: true,
},
{
name: "min",
fields: fields{
value: []byte(strconv.FormatInt(math.MinInt64, 10)),
},
want: math.MinInt64,
wantOK: true,
},
{
name: "max-overflow",
fields: fields{
value: []byte("9223372036854775808"),
},
// Seems to be what strconv.ParseInt returns
want: math.MaxInt64,
wantOK: false,
},
{
name: "min-underflow",
fields: fields{
value: []byte("-9223372036854775809"),
},
// Seems to be what strconv.ParseInt returns
want: math.MinInt64,
wantOK: false,
},
{
name: "zerospace",
fields: fields{
value: []byte(" 0"),
},
want: 0,
wantOK: true,
},
{
name: "onespace",
fields: fields{
value: []byte("1 "),
},
want: 1,
wantOK: true,
},
{
name: "minusonespace",
fields: fields{
value: []byte(" -1 "),
},
want: -1,
wantOK: true,
},
{
name: "plusonespace",
fields: fields{
value: []byte("\t+1\t"),
},
want: 1,
wantOK: true,
},
{
name: "scientific",
fields: fields{
value: []byte("3e5"),
},
want: 0,
wantOK: false,
},
{
// No support for prefixes
name: "hex",
fields: fields{
value: []byte("0xff"),
},
want: 0,
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Value{
value: tt.fields.value,
}
got, got1 := v.bytesToInt()
if got != tt.want {
t.Errorf("bytesToInt() got = %v, want %v", got, tt.want)
}
if got1 != tt.wantOK {
t.Errorf("bytesToInt() got1 = %v, want %v", got1, tt.wantOK)
}
})
}
}
func TestValue_bytesToFloat(t *testing.T) {
type fields struct {
value interface{}
}
tests := []struct {
name string
fields fields
want float64
wantOK bool
}{
// Copied from TestValue_bytesToInt.
{
name: "zero",
fields: fields{
value: []byte("0"),
},
want: 0,
wantOK: true,
},
{
name: "minuszero",
fields: fields{
value: []byte("-0"),
},
want: 0,
wantOK: true,
},
{
name: "one",
fields: fields{
value: []byte("1"),
},
want: 1,
wantOK: true,
},
{
name: "minusone",
fields: fields{
value: []byte("-1"),
},
want: -1,
wantOK: true,
},
{
name: "plusone",
fields: fields{
value: []byte("+1"),
},
want: 1,
wantOK: true,
},
{
name: "maxint",
fields: fields{
value: []byte(strconv.FormatInt(math.MaxInt64, 10)),
},
want: math.MaxInt64,
wantOK: true,
},
{
name: "minint",
fields: fields{
value: []byte(strconv.FormatInt(math.MinInt64, 10)),
},
want: math.MinInt64,
wantOK: true,
},
{
name: "max-overflow-int",
fields: fields{
value: []byte("9223372036854775808"),
},
// Seems to be what strconv.ParseInt returns
want: math.MaxInt64,
wantOK: true,
},
{
name: "min-underflow-int",
fields: fields{
value: []byte("-9223372036854775809"),
},
// Seems to be what strconv.ParseInt returns
want: math.MinInt64,
wantOK: true,
},
{
name: "max",
fields: fields{
value: []byte(strconv.FormatFloat(math.MaxFloat64, 'g', -1, 64)),
},
want: math.MaxFloat64,
wantOK: true,
},
{
name: "min",
fields: fields{
value: []byte(strconv.FormatFloat(-math.MaxFloat64, 'g', -1, 64)),
},
want: -math.MaxFloat64,
wantOK: true,
},
{
name: "max-overflow",
fields: fields{
value: []byte("1.797693134862315708145274237317043567981e+309"),
},
// Seems to be what strconv.ParseInt returns
want: math.Inf(1),
wantOK: false,
},
{
name: "min-underflow",
fields: fields{
value: []byte("-1.797693134862315708145274237317043567981e+309"),
},
// Seems to be what strconv.ParseInt returns
want: math.Inf(-1),
wantOK: false,
},
{
name: "smallest-pos",
fields: fields{
value: []byte(strconv.FormatFloat(math.SmallestNonzeroFloat64, 'g', -1, 64)),
},
want: math.SmallestNonzeroFloat64,
wantOK: true,
},
{
name: "smallest-pos",
fields: fields{
value: []byte(strconv.FormatFloat(-math.SmallestNonzeroFloat64, 'g', -1, 64)),
},
want: -math.SmallestNonzeroFloat64,
wantOK: true,
},
{
name: "zerospace",
fields: fields{
value: []byte(" 0"),
},
want: 0,
wantOK: true,
},
{
name: "onespace",
fields: fields{
value: []byte("1 "),
},
want: 1,
wantOK: true,
},
{
name: "minusonespace",
fields: fields{
value: []byte(" -1 "),
},
want: -1,
wantOK: true,
},
{
name: "plusonespace",
fields: fields{
value: []byte("\t+1\t"),
},
want: 1,
wantOK: true,
},
{
name: "scientific",
fields: fields{
value: []byte("3e5"),
},
want: 300000,
wantOK: true,
},
{
// No support for prefixes
name: "hex",
fields: fields{
value: []byte("0xff"),
},
want: 0,
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := Value{
value: tt.fields.value,
}
got, got1 := v.bytesToFloat()
diff := math.Abs(got - tt.want)
if diff > floatCmpTolerance {
t.Errorf("bytesToFloat() got = %v, want %v", got, tt.want)
}
if got1 != tt.wantOK {
t.Errorf("bytesToFloat() got1 = %v, want %v", got1, tt.wantOK)
}
})
}
}
func TestValue_bytesToBool(t *testing.T) {
type fields struct {
value interface{}
}
tests := []struct {
name string
fields fields
wantVal bool
wantOk bool
}{
{
name: "true",
fields: fields{
value: []byte("true"),
},
wantVal: true,
wantOk: true,
},
{
name: "false",
fields: fields{
value: []byte("false"),
},
wantVal: false,
wantOk: true,
},
{
name: "t",
fields: fields{
value: []byte("t"),
},
wantVal: true,
wantOk: true,
},
{
name: "f",
fields: fields{
value: []byte("f"),
},
wantVal: false,
wantOk: true,
},
{
name: "1",
fields: fields{
value: []byte("1"),
},
wantVal: true,
wantOk: true,
},
{
name: "0",
fields: fields{
value: []byte("0"),
},
wantVal: false,
wantOk: true,
},
{
name: "truespace",
fields: fields{
value: []byte(" true "),
},
wantVal: true,
wantOk: true,
},
{
name: "truetabs",
fields: fields{
value: []byte("\ttrue\t"),
},
wantVal: true,
wantOk: true,
},
{
name: "TRUE",
fields: fields{
value: []byte("TRUE"),
},
wantVal: true,
wantOk: true,
},
{
name: "FALSE",
fields: fields{
value: []byte("FALSE"),
},
wantVal: false,
wantOk: true,
},
{
name: "invalid",
fields: fields{
value: []byte("no"),
},
wantVal: false,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := Value{
value: tt.fields.value,
}
gotVal, gotOk := v.bytesToBool()
if gotVal != tt.wantVal {
t.Errorf("bytesToBool() gotVal = %v, want %v", gotVal, tt.wantVal)
}
if gotOk != tt.wantOk {
t.Errorf("bytesToBool() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
}