Add new SQL parser to support S3 Select syntax (#7102)

- New parser written from scratch, allowing easier and more complete
  parsing of the full S3 Select SQL syntax. The parser definition is
  derived directly from the AST defined for the SQL grammar (see the
  sketch below).

- Add support for parsing and interpreting SQL involving JSON path
  expressions; evaluation of JSON path expressions will be added
  subsequently.

- Add automatic type inference and conversion for untyped
  values (e.g. CSV data).
Aditya Manthramurthy 2019-01-28 17:59:48 -08:00 committed by Harshavardhana
parent 0a28c28a8c
commit 2786055df4
65 changed files with 6405 additions and 18231 deletions
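The grammar-from-the-AST approach can be illustrated with a minimal, hypothetical sketch. It assumes an older (circa-2019) release of github.com/alecthomas/participle and uses made-up node names; it is not the grammar shipped in this commit, only the same technique of letting struct tags on AST nodes double as the parser definition.

```go
package main

import (
	"fmt"

	"github.com/alecthomas/participle"
)

// LitValue is an illustrative literal node; the struct tags are the grammar.
type LitValue struct {
	Number *float64 `  @(Float | Int)`
	String *string  `| @String`
}

// MiniSelect parses only `SELECT <literal> FROM <table>` for illustration.
type MiniSelect struct {
	Expr  *LitValue `"SELECT" @@`
	Table string    `"FROM" @Ident`
}

func main() {
	// MustBuild derives the parser directly from the annotated AST types.
	parser := participle.MustBuild(&MiniSelect{})

	stmt := &MiniSelect{}
	if err := parser.ParseString(`SELECT 42 FROM S3Object`, stmt); err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Println(*stmt.Expr.Number, stmt.Table) // 42 S3Object
}
```

The real AST in this commit (Select, Expression, Condition, PrimaryTerm, FuncExpr, ...) is built the same way, which is why the parser and the evaluator can share a single set of node types.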

View File

@@ -185,6 +185,12 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
return getObjectNInfo(ctx, bucket, object, rs, r.Header, readLock, ObjectOptions{})
}
objInfo, err := getObjectInfo(ctx, bucket, object, opts)
if err != nil {
writeErrorResponse(w, toAPIErrorCode(ctx, err), r.URL, guessIsBrowserReq(r))
return
}
if err = s3Select.Open(getObject); err != nil {
if serr, ok := err.(s3select.SelectError); ok {
w.WriteHeader(serr.HTTPStatusCode())
@@ -198,12 +204,6 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
s3Select.Evaluate(w)
s3Select.Close()
objInfo, err := getObjectInfo(ctx, bucket, object, opts)
if err != nil {
logger.LogIf(ctx, err)
return
}
// Get host and port from Request.RemoteAddr.
host, port, err := net.SplitHostPort(handlers.GetSourceIP(r))
if err != nil {

View File

@@ -3,6 +3,15 @@ Traditional retrieval of objects is always as whole entities, i.e GetObject for
> This implementation is compatible with the AWS S3 Select API
### Implementation status:
- Full S3 SQL syntax is supported
- All aggregation, conditional, type-conversion and strings SQL functions are supported
- JSONPath expressions are not yet evaluated
- Large numbers (more than 64-bit) are not yet supported
- Date-related functions are not yet supported (EXTRACT, DATE_DIFF, etc.)
- S3's reserved keywords list is not yet respected
## 1. Prerequisites
- Install Minio Server from [here](http://docs.minio.io/docs/minio-quickstart-guide).
- Familiarity with AWS S3 API
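Because the feature is wire-compatible with AWS S3 Select, any S3 SDK can exercise it. The hedged sketch below uses aws-sdk-go (v1) against a local Minio endpoint; the endpoint, credentials, bucket and object names are placeholders.

```go
package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	// Placeholder endpoint and credentials for a local Minio server.
	sess := session.Must(session.NewSession(&aws.Config{
		Region:           aws.String("us-east-1"),
		Endpoint:         aws.String("http://localhost:9000"),
		Credentials:      credentials.NewStaticCredentials("ACCESS-KEY", "SECRET-KEY", ""),
		S3ForcePathStyle: aws.Bool(true),
	}))
	svc := s3.New(sess)

	resp, err := svc.SelectObjectContent(&s3.SelectObjectContentInput{
		Bucket:         aws.String("mybucket"),
		Key:            aws.String("people.csv"),
		Expression:     aws.String(`SELECT s.name, s.age FROM S3Object s WHERE s.age > 25`),
		ExpressionType: aws.String(s3.ExpressionTypeSql),
		InputSerialization: &s3.InputSerialization{
			CSV: &s3.CSVInput{FileHeaderInfo: aws.String(s3.FileHeaderInfoUse)},
		},
		OutputSerialization: &s3.OutputSerialization{
			CSV: &s3.CSVOutput{},
		},
	})
	if err != nil {
		fmt.Println("select failed:", err)
		return
	}
	defer resp.EventStream.Close()

	// Matching rows arrive as a stream of events; print the raw record payloads.
	for event := range resp.EventStream.Events() {
		if records, ok := event.(*s3.RecordsEvent); ok {
			fmt.Print(string(records.Payload))
		}
	}
}
```

Note that the comparison `s.age > 25` on a CSV column relies on the automatic type inference added by this commit, since CSV fields carry no declared type.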

View File

@@ -32,7 +32,11 @@ type Record struct {
nameIndexMap map[string]int64
}
// Get - gets the value for a column name.
// Get - gets the value for a column name. CSV fields do not have any
// defined type (other than the default string). So this function
// always returns fields using sql.FromBytes so that the type
// specified/implied by the query can be used, or can be automatically
// converted based on the query.
func (r *Record) Get(name string) (*sql.Value, error) {
index, found := r.nameIndexMap[name]
if !found {
@@ -40,11 +44,12 @@ func (r *Record) Get(name string) (*sql.Value, error) {
}
if index >= int64(len(r.csvRecord)) {
// No value found for column 'name', hence return empty string for compatibility.
return sql.NewString(""), nil
// No value found for column 'name', hence return null
// value
return sql.FromNull(), nil
}
return sql.NewString(r.csvRecord[index]), nil
return sql.FromBytes([]byte(r.csvRecord[index])), nil
}
// Set - sets the value for a column name.

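The comment above is the reason CSV values flow through sql.FromBytes: a CSV field has no declared type, so the concrete type is inferred only when the query needs it. A standalone sketch of that inference idea (not the sql package's actual code):

```go
package main

import (
	"fmt"
	"strconv"
)

// inferCSVField mimics the idea behind untyped byte values: a CSV field has
// no declared type, so try integer, then float, and fall back to string.
func inferCSVField(raw []byte) interface{} {
	s := string(raw)
	if i, err := strconv.ParseInt(s, 10, 64); err == nil {
		return i
	}
	if f, err := strconv.ParseFloat(s, 64); err == nil {
		return f
	}
	return s
}

func main() {
	for _, field := range []string{"42", "3.14", "hello"} {
		v := inferCSVField([]byte(field))
		fmt.Printf("%q -> %T(%v)\n", field, v, v)
	}
}
```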
View File

@@ -37,15 +37,15 @@ func (r *Record) Get(name string) (*sql.Value, error) {
result := gjson.GetBytes(r.data, name)
switch result.Type {
case gjson.Null:
return sql.NewNull(), nil
return sql.FromNull(), nil
case gjson.False:
return sql.NewBool(false), nil
return sql.FromBool(false), nil
case gjson.Number:
return sql.NewFloat(result.Float()), nil
return sql.FromFloat(result.Float()), nil
case gjson.String:
return sql.NewString(result.String()), nil
return sql.FromString(result.String()), nil
case gjson.True:
return sql.NewBool(true), nil
return sql.FromBool(true), nil
}
return nil, fmt.Errorf("unsupported gjson value %v; %v", result, result.Type)
@@ -54,19 +54,20 @@ func (r *Record) Get(name string) (*sql.Value, error) {
// Set - sets the value for a column name.
func (r *Record) Set(name string, value *sql.Value) (err error) {
var v interface{}
switch value.Type() {
case sql.Null:
v = value.NullValue()
case sql.Bool:
v = value.BoolValue()
case sql.Int:
v = value.IntValue()
case sql.Float:
v = value.FloatValue()
case sql.String:
v = value.StringValue()
default:
return fmt.Errorf("unsupported sql value %v and type %v", value, value.Type())
if b, ok := value.ToBool(); ok {
v = b
} else if f, ok := value.ToFloat(); ok {
v = f
} else if i, ok := value.ToInt(); ok {
v = i
} else if s, ok := value.ToString(); ok {
v = s
} else if value.IsNull() {
v = nil
} else if b, ok := value.ToBytes(); ok {
v = string(b)
} else {
return fmt.Errorf("unsupported sql value %v and type %v", value, value.GetTypeString())
}
name = strings.Replace(name, "*", "__ALL__", -1)

View File

@@ -32,7 +32,7 @@ type Reader struct {
}
// Read - reads single record.
func (r *Reader) Read() (sql.Record, error) {
func (r *Reader) Read() (rec sql.Record, rerr error) {
parquetRecord, err := r.file.Read()
if err != nil {
if err != io.EOF {
@@ -43,39 +43,41 @@ }
}
record := json.NewRecord()
for name, v := range parquetRecord {
f := func(name string, v parquetgo.Value) bool {
if v.Value == nil {
if err = record.Set(name, sql.NewNull()); err != nil {
return nil, errParquetParsingError(err)
if err := record.Set(name, sql.FromNull()); err != nil {
rerr = errParquetParsingError(err)
}
continue
return rerr == nil
}
var value *sql.Value
switch v.Type {
case parquetgen.Type_BOOLEAN:
value = sql.NewBool(v.Value.(bool))
value = sql.FromBool(v.Value.(bool))
case parquetgen.Type_INT32:
value = sql.NewInt(int64(v.Value.(int32)))
value = sql.FromInt(int64(v.Value.(int32)))
case parquetgen.Type_INT64:
value = sql.NewInt(v.Value.(int64))
value = sql.FromInt(int64(v.Value.(int64)))
case parquetgen.Type_FLOAT:
value = sql.NewFloat(float64(v.Value.(float32)))
value = sql.FromFloat(float64(v.Value.(float32)))
case parquetgen.Type_DOUBLE:
value = sql.NewFloat(v.Value.(float64))
value = sql.FromFloat(v.Value.(float64))
case parquetgen.Type_INT96, parquetgen.Type_BYTE_ARRAY, parquetgen.Type_FIXED_LEN_BYTE_ARRAY:
value = sql.NewString(string(v.Value.([]byte)))
value = sql.FromString(string(v.Value.([]byte)))
default:
return nil, errParquetParsingError(nil)
rerr = errParquetParsingError(nil)
return false
}
if err = record.Set(name, value); err != nil {
return nil, errParquetParsingError(err)
rerr = errParquetParsingError(err)
}
return rerr == nil
}
return record, nil
parquetRecord.Range(f)
return record, rerr
}
// Close - closes underlying readers.

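The rewrite above switches from ranging over a map to parquetRecord.Range(f), so errors can no longer be returned from the loop body directly; instead they are captured in the named return rerr and the callback returns false to stop iteration. A self-contained sketch of the same pattern, using sync.Map as a stand-in for the parquet record type:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// processAll visits every entry via a Range-style callback. The callback
// cannot return an error, so the error is captured in rerr and iteration
// stops by returning false.
func processAll(m *sync.Map) (rerr error) {
	m.Range(func(key, value interface{}) bool {
		if value == nil {
			rerr = errors.New("nil value for " + fmt.Sprint(key))
			return false // stop iterating
		}
		// ... handle the value ...
		return rerr == nil // continue only while no error has occurred
	})
	return rerr
}

func main() {
	var m sync.Map
	m.Store("name", "alice")
	m.Store("age", nil)
	fmt.Println(processAll(&m)) // nil value for age
}
```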
View File

@@ -105,7 +105,7 @@ func (input *InputSerialization) UnmarshalXML(d *xml.Decoder, start xml.StartEle
found++
}
if !parsedInput.ParquetArgs.IsEmpty() {
if parsedInput.CompressionType != noneType {
if parsedInput.CompressionType != "" && parsedInput.CompressionType != noneType {
return errInvalidRequestParameter(fmt.Errorf("CompressionType must be NONE for Parquet format"))
}
@@ -178,7 +178,7 @@ type S3Select struct {
Output OutputSerialization `xml:"OutputSerialization"`
Progress RequestProgress `xml:"RequestProgress"`
statement *sql.Select
statement *sql.SelectStatement
progressReader *progressReader
recordReader recordReader
}
@@ -209,12 +209,12 @@ func (s3Select *S3Select) UnmarshalXML(d *xml.Decoder, start xml.StartElement) e
return errMissingRequiredParameter(fmt.Errorf("OutputSerialization must be provided"))
}
statement, err := sql.NewSelect(parsedS3Select.Expression)
statement, err := sql.ParseSelectStatement(parsedS3Select.Expression)
if err != nil {
return err
}
parsedS3Select.statement = statement
parsedS3Select.statement = &statement
*s3Select = S3Select(parsedS3Select)
return nil
@@ -334,6 +334,14 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
}
for {
if s3Select.statement.LimitReached() {
if err = writer.SendStats(s3Select.getProgress()); err != nil {
// FIXME: log this error.
err = nil
}
break
}
if inputRecord, err = s3Select.recordReader.Read(); err != nil {
if err != io.EOF {
break
@@ -358,19 +366,25 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
break
}
outputRecord = s3Select.outputRecord()
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
break
}
if s3Select.statement.IsAggregated() {
if err = s3Select.statement.AggregateRow(inputRecord); err != nil {
break
}
} else {
outputRecord = s3Select.outputRecord()
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
break
}
if !s3Select.statement.IsAggregated() {
if !sendRecord() {
break
}
}
}
if err != nil {
fmt.Printf("SQL Err: %#v\n", err)
if serr := writer.SendError("InternalError", err.Error()); serr != nil {
// FIXME: log errors.
}

View File

@@ -0,0 +1,318 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
)
// Aggregation Function name constants
const (
aggFnAvg FuncName = "AVG"
aggFnCount FuncName = "COUNT"
aggFnMax FuncName = "MAX"
aggFnMin FuncName = "MIN"
aggFnSum FuncName = "SUM"
)
var (
errNonNumericArg = func(fnStr FuncName) error {
return fmt.Errorf("%s() requires a numeric argument", fnStr)
}
errInvalidAggregation = errors.New("Invalid aggregation seen")
)
type aggVal struct {
runningSum *Value
runningCount int64
runningMax, runningMin *Value
// Stores if at least one record has been seen
seen bool
}
func newAggVal(fn FuncName) *aggVal {
switch fn {
case aggFnAvg, aggFnSum:
return &aggVal{runningSum: FromInt(0)}
case aggFnMin:
return &aggVal{runningMin: FromInt(0)}
case aggFnMax:
return &aggVal{runningMax: FromInt(0)}
default:
return &aggVal{}
}
}
// evalAggregationNode - performs partial computation using the
// current row and stores the result.
//
// On success, it returns nil.
func (e *FuncExpr) evalAggregationNode(r Record) error {
// It is assumed that this function is called only when
// `e` is an aggregation function.
var val *Value
var err error
funcName := e.getFunctionName()
if aggFnCount == funcName {
if e.Count.StarArg {
// Handle COUNT(*)
e.aggregate.runningCount++
return nil
}
val, err = e.Count.ExprArg.evalNode(r)
if err != nil {
return err
}
} else {
// Evaluate the (only) argument
val, err = e.SFunc.ArgsList[0].evalNode(r)
if err != nil {
return err
}
}
if val.IsNull() {
// E.g. the column or field does not exist in the
// record - in all such cases the aggregation is not
// updated.
return nil
}
argVal := val
if funcName != aggFnCount {
// All aggregation functions, except COUNT require a
// numeric argument.
// Here, we diverge from Amazon S3 behavior by
// inferring untyped values are numbers.
if i, ok := argVal.bytesToInt(); ok {
argVal.setInt(i)
} else if f, ok := argVal.bytesToFloat(); ok {
argVal.setFloat(f)
} else {
return errNonNumericArg(funcName)
}
}
// Mark that we have seen one non-null value.
isFirstRow := false
if !e.aggregate.seen {
e.aggregate.seen = true
isFirstRow = true
}
switch funcName {
case aggFnCount:
// For all non-null values, the count is incremented.
e.aggregate.runningCount++
case aggFnAvg:
e.aggregate.runningCount++
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
case aggFnMin:
err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
case aggFnMax:
err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
case aggFnSum:
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
default:
err = errInvalidAggregation
}
return err
}
func (e *AliasedExpression) aggregateRow(r Record) error {
return e.Expression.aggregateRow(r)
}
func (e *Expression) aggregateRow(r Record) error {
for _, ex := range e.And {
err := ex.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *AndCondition) aggregateRow(r Record) error {
for _, ex := range e.Condition {
err := ex.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *Condition) aggregateRow(r Record) error {
if e.Operand != nil {
return e.Operand.aggregateRow(r)
}
return e.Not.aggregateRow(r)
}
func (e *ConditionOperand) aggregateRow(r Record) error {
err := e.Operand.aggregateRow(r)
if err != nil {
return err
}
if e.ConditionRHS == nil {
return nil
}
switch {
case e.ConditionRHS.Compare != nil:
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
case e.ConditionRHS.Between != nil:
err = e.ConditionRHS.Between.Start.aggregateRow(r)
if err != nil {
return err
}
return e.ConditionRHS.Between.End.aggregateRow(r)
case e.ConditionRHS.In != nil:
for _, elt := range e.ConditionRHS.In.Expressions {
err = elt.aggregateRow(r)
if err != nil {
return err
}
}
return nil
case e.ConditionRHS.Like != nil:
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
if err != nil {
return err
}
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
default:
return errInvalidASTNode
}
}
func (e *Operand) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *MultOp) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
if err != nil {
return err
}
}
return nil
}
func (e *UnaryTerm) aggregateRow(r Record) error {
if e.Negated != nil {
return e.Negated.Term.aggregateRow(r)
}
return e.Primary.aggregateRow(r)
}
func (e *PrimaryTerm) aggregateRow(r Record) error {
switch {
case e.SubExpression != nil:
return e.SubExpression.aggregateRow(r)
case e.FuncCall != nil:
return e.FuncCall.aggregateRow(r)
}
return nil
}
func (e *FuncExpr) aggregateRow(r Record) error {
switch e.getFunctionName() {
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
return e.evalAggregationNode(r)
default:
// TODO: traverse arguments and call aggregateRow on
// them if they could be an ancestor of an
// aggregation.
}
return nil
}
// getAggregate() implementation for each AST node follows. This is
// called after calling aggregateRow() on each input row, to calculate
// the final aggregate result.
func (e *Expression) getAggregate() (*Value, error) {
return e.evalNode(nil)
}
func (e *FuncExpr) getAggregate() (*Value, error) {
switch e.getFunctionName() {
case aggFnCount:
return FromFloat(float64(e.aggregate.runningCount)), nil
case aggFnAvg:
if e.aggregate.runningCount == 0 {
// No rows were seen by AVG.
return FromNull(), nil
}
err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
return e.aggregate.runningSum, err
case aggFnMin:
if !e.aggregate.seen {
// No rows were seen by MIN
return FromNull(), nil
}
return e.aggregate.runningMin, nil
case aggFnMax:
if !e.aggregate.seen {
// No rows were seen by MAX
return FromNull(), nil
}
return e.aggregate.runningMax, nil
case aggFnSum:
// TODO: check if returning 0 when no rows were seen
// by SUM is expected behavior.
return e.aggregate.runningSum, nil
default:
// TODO:
}
return nil, errInvalidAggregation
}

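The aggVal/aggregateRow/getAggregate split above boils down to keeping running partial state per row and computing the final value once after all rows are seen. A standalone sketch of that idea for AVG (not the package's code):

```go
package main

import "fmt"

// avgAccumulator mirrors the idea of aggVal for AVG: partial state per row,
// final result computed once after all rows are processed.
type avgAccumulator struct {
	sum   float64
	count int64
}

func (a *avgAccumulator) addRow(v float64) {
	a.sum += v
	a.count++
}

// result returns (avg, true), or (0, false) when no rows were seen,
// analogous to returning a NULL value.
func (a *avgAccumulator) result() (float64, bool) {
	if a.count == 0 {
		return 0, false
	}
	return a.sum / float64(a.count), true
}

func main() {
	var acc avgAccumulator
	for _, age := range []float64{30, 40, 50} {
		acc.addRow(age)
	}
	if avg, ok := acc.result(); ok {
		fmt.Println("AVG =", avg) // AVG = 40
	}
}
```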
View File

@@ -0,0 +1,290 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
)
// Query analysis - The query is analyzed to determine if it involves
// aggregation.
//
// Aggregation functions - An expression that involves aggregation of
// rows in some manner. Requires all input rows to be processed,
// before a result is returned.
//
// Row function - An expression that depends on a value in the
// row. They have an output for each input row.
//
// Some types of queries are not valid. For example, an aggregation
// function combined with a row function is meaningless ("AVG(s.Age) +
// s.Salary"). Analysis determines if such a scenario exists so an
// error can be returned.
var (
// Fatal error for query processing.
errNestedAggregation = errors.New("Cannot nest aggregations")
errFunctionNotImplemented = errors.New("Function is not yet implemented")
errUnexpectedInvalidNode = errors.New("Unexpected node value")
errInvalidKeypath = errors.New("A provided keypath is invalid")
)
// qProp contains analysis info about an SQL term.
type qProp struct {
isAggregation, isRowFunc bool
err error
}
// `combine` combines a pair of `qProp`s, so that errors are
// propagated correctly, and checks that an aggregation is not being
// combined with a row-function term.
func (p *qProp) combine(q qProp) {
switch {
case p.err != nil:
// Do nothing
case q.err != nil:
p.err = q.err
default:
p.isAggregation = p.isAggregation || q.isAggregation
p.isRowFunc = p.isRowFunc || q.isRowFunc
if p.isAggregation && p.isRowFunc {
p.err = errNestedAggregation
}
}
}
func (e *SelectExpression) analyze(s *Select) (result qProp) {
if e.All {
return qProp{isRowFunc: true}
}
for _, ex := range e.Expressions {
result.combine(ex.analyze(s))
}
return
}
func (e *AliasedExpression) analyze(s *Select) qProp {
return e.Expression.analyze(s)
}
func (e *Expression) analyze(s *Select) (result qProp) {
for _, ac := range e.And {
result.combine(ac.analyze(s))
}
return
}
func (e *AndCondition) analyze(s *Select) (result qProp) {
for _, ac := range e.Condition {
result.combine(ac.analyze(s))
}
return
}
func (e *Condition) analyze(s *Select) (result qProp) {
if e.Operand != nil {
result = e.Operand.analyze(s)
} else {
result = e.Not.analyze(s)
}
return
}
func (e *ConditionOperand) analyze(s *Select) (result qProp) {
if e.ConditionRHS == nil {
result = e.Operand.analyze(s)
} else {
result.combine(e.Operand.analyze(s))
result.combine(e.ConditionRHS.analyze(s))
}
return
}
func (e *ConditionRHS) analyze(s *Select) (result qProp) {
switch {
case e.Compare != nil:
result = e.Compare.Operand.analyze(s)
case e.Between != nil:
result.combine(e.Between.Start.analyze(s))
result.combine(e.Between.End.analyze(s))
case e.In != nil:
for _, elt := range e.In.Expressions {
result.combine(elt.analyze(s))
}
case e.Like != nil:
result.combine(e.Like.Pattern.analyze(s))
if e.Like.EscapeChar != nil {
result.combine(e.Like.EscapeChar.analyze(s))
}
default:
result = qProp{err: errUnexpectedInvalidNode}
}
return
}
func (e *Operand) analyze(s *Select) (result qProp) {
result.combine(e.Left.analyze(s))
for _, r := range e.Right {
result.combine(r.Right.analyze(s))
}
return
}
func (e *MultOp) analyze(s *Select) (result qProp) {
result.combine(e.Left.analyze(s))
for _, r := range e.Right {
result.combine(r.Right.analyze(s))
}
return
}
func (e *UnaryTerm) analyze(s *Select) (result qProp) {
if e.Negated != nil {
result = e.Negated.Term.analyze(s)
} else {
result = e.Primary.analyze(s)
}
return
}
func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
switch {
case e.Value != nil:
result = qProp{}
case e.JPathExpr != nil:
// Check if the path expression is valid
if len(e.JPathExpr.PathExpr) > 0 {
if e.JPathExpr.BaseKey.String() != s.From.As {
result = qProp{err: errInvalidKeypath}
return
}
}
result = qProp{isRowFunc: true}
case e.SubExpression != nil:
result = e.SubExpression.analyze(s)
case e.FuncCall != nil:
result = e.FuncCall.analyze(s)
default:
result = qProp{err: errUnexpectedInvalidNode}
}
return
}
func (e *FuncExpr) analyze(s *Select) (result qProp) {
funcName := e.getFunctionName()
switch funcName {
case sqlFnCast:
return e.Cast.Expr.analyze(s)
case sqlFnExtract:
return e.Extract.From.analyze(s)
// Handle aggregation function calls
case aggFnAvg, aggFnMax, aggFnMin, aggFnSum, aggFnCount:
// Initialize accumulator
e.aggregate = newAggVal(funcName)
var exprA qProp
if funcName == aggFnCount {
if e.Count.StarArg {
return qProp{isAggregation: true}
}
exprA = e.Count.ExprArg.analyze(s)
} else {
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s takes exactly one argument", funcName)}
}
exprA = e.SFunc.ArgsList[0].analyze(s)
}
if exprA.err != nil {
return exprA
}
if exprA.isAggregation {
return qProp{err: errNestedAggregation}
}
return qProp{isAggregation: true}
case sqlFnCoalesce:
if len(e.SFunc.ArgsList) == 0 {
return qProp{err: fmt.Errorf("%s needs at least one argument", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnNullIf:
if len(e.SFunc.ArgsList) != 2 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnCharLength, sqlFnCharacterLength:
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnLower, sqlFnUpper:
if len(e.SFunc.ArgsList) != 1 {
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
}
for _, arg := range e.SFunc.ArgsList {
result.combine(arg.analyze(s))
}
return result
case sqlFnSubstring:
errVal := fmt.Errorf("Invalid argument(s) to %s", string(funcName))
result.combine(e.Substring.Expr.analyze(s))
switch {
case e.Substring.From != nil:
result.combine(e.Substring.From.analyze(s))
if e.Substring.For != nil {
result.combine(e.Substring.For.analyze(s))
}
case e.Substring.Arg2 != nil:
result.combine(e.Substring.Arg2.analyze(s))
if e.Substring.Arg3 != nil {
result.combine(e.Substring.Arg3.analyze(s))
}
default:
result.err = errVal
}
return result
}
// TODO: implement other functions
return qProp{err: errFunctionNotImplemented}
}

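The header comment's rule that an aggregation cannot be combined with a row-level term (e.g. "AVG(s.Age) + s.Salary") is exactly what qProp.combine enforces. A toy, self-contained version of that check, for illustration only:

```go
package main

import (
	"errors"
	"fmt"
)

// prop is a toy version of qProp: it tracks whether a subexpression
// aggregates rows and whether it reads per-row values.
type prop struct {
	isAggregation, isRowFunc bool
	err                      error
}

// merge combines two subexpression properties, flagging the invalid case
// where an aggregate and a per-row term are mixed in one expression.
func merge(a, b prop) prop {
	out := prop{
		isAggregation: a.isAggregation || b.isAggregation,
		isRowFunc:     a.isRowFunc || b.isRowFunc,
	}
	if out.isAggregation && out.isRowFunc {
		out.err = errors.New("cannot mix aggregations with row-level terms")
	}
	return out
}

func main() {
	avgAge := prop{isAggregation: true} // AVG(s.Age)
	salary := prop{isRowFunc: true}     // s.Salary

	fmt.Println(merge(avgAge, avgAge).err) // <nil>: AVG(s.Age) + AVG(s.Age) is accepted
	fmt.Println(merge(avgAge, salary).err) // error: AVG(s.Age) + s.Salary is rejected
}
```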
View File

@@ -1,175 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "fmt"
// ArithOperator - arithmetic operator.
type ArithOperator string
const (
// Add operator '+'.
Add ArithOperator = "+"
// Subtract operator '-'.
Subtract ArithOperator = "-"
// Multiply operator '*'.
Multiply ArithOperator = "*"
// Divide operator '/'.
Divide ArithOperator = "/"
// Modulo operator '%'.
Modulo ArithOperator = "%"
)
// arithExpr - arithmetic function.
type arithExpr struct {
left Expr
right Expr
operator ArithOperator
funcType Type
}
// String - returns string representation of this function.
func (f *arithExpr) String() string {
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
}
func (f *arithExpr) compute(lv, rv *Value) (*Value, error) {
leftValueType := lv.Type()
rightValueType := rv.Type()
if !leftValueType.isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValueType)
return nil, errExternalEvalException(err)
}
if !rightValueType.isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValueType)
return nil, errExternalEvalException(err)
}
leftValue := lv.FloatValue()
rightValue := rv.FloatValue()
var result float64
switch f.operator {
case Add:
result = leftValue + rightValue
case Subtract:
result = leftValue - rightValue
case Multiply:
result = leftValue * rightValue
case Divide:
result = leftValue / rightValue
case Modulo:
result = float64(int64(leftValue) % int64(rightValue))
}
if leftValueType == Float || rightValueType == Float {
return NewFloat(result), nil
}
return NewInt(int64(result)), nil
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *arithExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
return nil, nil
}
return f.compute(leftValue, rightValue)
}
// AggregateValue - returns aggregated value.
func (f *arithExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
lv, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
rv, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
return f.compute(lv, rv)
}
// Type - returns arithmeticFunction or aggregateFunction type.
func (f *arithExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Float as return type.
func (f *arithExpr) ReturnType() Type {
return Float
}
// newArithExpr - creates new arithmetic function.
func newArithExpr(operator ArithOperator, left, right Expr) (*arithExpr, error) {
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := arithmeticFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &arithExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
}

View File

@@ -1,636 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"regexp"
"strings"
)
// ComparisonOperator - comparison operator.
type ComparisonOperator string
const (
// Equal operator '='.
Equal ComparisonOperator = "="
// NotEqual operator '!=' or '<>'.
NotEqual ComparisonOperator = "!="
// LessThan operator '<'.
LessThan ComparisonOperator = "<"
// GreaterThan operator '>'.
GreaterThan ComparisonOperator = ">"
// LessThanEqual operator '<='.
LessThanEqual ComparisonOperator = "<="
// GreaterThanEqual operator '>='.
GreaterThanEqual ComparisonOperator = ">="
// Between operator 'BETWEEN'
Between ComparisonOperator = "between"
// In operator 'IN'
In ComparisonOperator = "in"
// Like operator 'LIKE'
Like ComparisonOperator = "like"
// NotBetween operator 'NOT BETWEEN'
NotBetween ComparisonOperator = "not between"
// NotIn operator 'NOT IN'
NotIn ComparisonOperator = "not in"
// NotLike operator 'NOT LIKE'
NotLike ComparisonOperator = "not like"
// IsNull operator 'IS NULL'
IsNull ComparisonOperator = "is null"
// IsNotNull operator 'IS NOT NULL'
IsNotNull ComparisonOperator = "is not null"
)
// String - returns string representation of this operator.
func (operator ComparisonOperator) String() string {
return strings.ToUpper((string(operator)))
}
func equal(leftValue, rightValue *Value) (bool, error) {
switch {
case leftValue.Type() == Null && rightValue.Type() == Null:
return true, nil
case leftValue.Type() == Bool && rightValue.Type() == Bool:
return leftValue.BoolValue() == rightValue.BoolValue(), nil
case (leftValue.Type() == Int || leftValue.Type() == Float) &&
(rightValue.Type() == Int || rightValue.Type() == Float):
return leftValue.FloatValue() == rightValue.FloatValue(), nil
case leftValue.Type() == String && rightValue.Type() == String:
return leftValue.StringValue() == rightValue.StringValue(), nil
case leftValue.Type() == Timestamp && rightValue.Type() == Timestamp:
return leftValue.TimeValue() == rightValue.TimeValue(), nil
}
return false, fmt.Errorf("left value type %v and right value type %v are incompatible for equality check", leftValue.Type(), rightValue.Type())
}
// comparisonExpr - comparison function.
type comparisonExpr struct {
left Expr
right Expr
to Expr
operator ComparisonOperator
funcType Type
}
// String - returns string representation of this function.
func (f *comparisonExpr) String() string {
switch f.operator {
case Equal, NotEqual, LessThan, GreaterThan, LessThanEqual, GreaterThanEqual, In, Like, NotIn, NotLike:
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
case Between, NotBetween:
return fmt.Sprintf("(%v %v %v AND %v)", f.left, f.operator, f.right, f.to)
}
return fmt.Sprintf("(%v %v %v %v)", f.left, f.right, f.to, f.operator)
}
func (f *comparisonExpr) equal(leftValue, rightValue *Value) (*Value, error) {
result, err := equal(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(result), nil
}
func (f *comparisonExpr) notEqual(leftValue, rightValue *Value) (*Value, error) {
result, err := equal(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(!result), nil
}
func (f *comparisonExpr) lessThan(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() < rightValue.FloatValue()), nil
}
func (f *comparisonExpr) greaterThan(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() > rightValue.FloatValue()), nil
}
func (f *comparisonExpr) lessThanEqual(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() <= rightValue.FloatValue()), nil
}
func (f *comparisonExpr) greaterThanEqual(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() >= rightValue.FloatValue()), nil
}
func (f *comparisonExpr) computeBetween(leftValue, fromValue, toValue *Value) (bool, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return false, errExternalEvalException(err)
}
if !fromValue.Type().isNumber() {
err := fmt.Errorf("%v: from side expression evaluated to %v; not to number", f, fromValue.Type())
return false, errExternalEvalException(err)
}
if !toValue.Type().isNumber() {
err := fmt.Errorf("%v: to side expression evaluated to %v; not to number", f, toValue.Type())
return false, errExternalEvalException(err)
}
return leftValue.FloatValue() >= fromValue.FloatValue() &&
leftValue.FloatValue() <= toValue.FloatValue(), nil
}
func (f *comparisonExpr) between(leftValue, fromValue, toValue *Value) (*Value, error) {
result, err := f.computeBetween(leftValue, fromValue, toValue)
if err != nil {
return nil, err
}
return NewBool(result), nil
}
func (f *comparisonExpr) notBetween(leftValue, fromValue, toValue *Value) (*Value, error) {
result, err := f.computeBetween(leftValue, fromValue, toValue)
if err != nil {
return nil, err
}
return NewBool(!result), nil
}
func (f *comparisonExpr) computeIn(leftValue, rightValue *Value) (found bool, err error) {
if rightValue.Type() != Array {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to Array", f, rightValue.Type())
return false, errExternalEvalException(err)
}
values := rightValue.ArrayValue()
for i := range values {
found, err = equal(leftValue, values[i])
if err != nil {
return false, err
}
if found {
return true, nil
}
}
return false, nil
}
func (f *comparisonExpr) in(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeIn(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(result), nil
}
func (f *comparisonExpr) notIn(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeIn(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(!result), nil
}
func (f *comparisonExpr) computeLike(leftValue, rightValue *Value) (matched bool, err error) {
if leftValue.Type() != String {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to string", f, leftValue.Type())
return false, errExternalEvalException(err)
}
if rightValue.Type() != String {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to string", f, rightValue.Type())
return false, errExternalEvalException(err)
}
matched, err = regexp.MatchString(rightValue.StringValue(), leftValue.StringValue())
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return false, errExternalEvalException(err)
}
return matched, nil
}
func (f *comparisonExpr) like(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeLike(leftValue, rightValue)
if err != nil {
return nil, err
}
return NewBool(result), nil
}
func (f *comparisonExpr) notLike(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeLike(leftValue, rightValue)
if err != nil {
return nil, err
}
return NewBool(!result), nil
}
func (f *comparisonExpr) compute(leftValue, rightValue, toValue *Value) (*Value, error) {
switch f.operator {
case Equal:
return f.equal(leftValue, rightValue)
case NotEqual:
return f.notEqual(leftValue, rightValue)
case LessThan:
return f.lessThan(leftValue, rightValue)
case GreaterThan:
return f.greaterThan(leftValue, rightValue)
case LessThanEqual:
return f.lessThanEqual(leftValue, rightValue)
case GreaterThanEqual:
return f.greaterThanEqual(leftValue, rightValue)
case Between:
return f.between(leftValue, rightValue, toValue)
case In:
return f.in(leftValue, rightValue)
case Like:
return f.like(leftValue, rightValue)
case NotBetween:
return f.notBetween(leftValue, rightValue, toValue)
case NotIn:
return f.notIn(leftValue, rightValue)
case NotLike:
return f.notLike(leftValue, rightValue)
}
panic(fmt.Errorf("unexpected expression %v", f))
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *comparisonExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
var toValue *Value
if f.to != nil {
toValue, err = f.to.Eval(record)
if err != nil {
return nil, err
}
}
if f.funcType == aggregateFunction {
return nil, nil
}
return f.compute(leftValue, rightValue, toValue)
}
// AggregateValue - returns aggregated value.
func (f *comparisonExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
var toValue *Value
if f.to != nil {
toValue, err = f.to.AggregateValue()
if err != nil {
return nil, err
}
}
return f.compute(leftValue, rightValue, toValue)
}
// Type - returns comparisonFunction or aggregateFunction type.
func (f *comparisonExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *comparisonExpr) ReturnType() Type {
return Bool
}
// newComparisonExpr - creates new comparison function.
func newComparisonExpr(operator ComparisonOperator, funcs ...Expr) (*comparisonExpr, error) {
funcType := comparisonFunction
switch operator {
case Equal, NotEqual:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if !right.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for equality check", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for equality check", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case LessThan, GreaterThan, LessThanEqual, GreaterThanEqual:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if !right.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case In, NotIn:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if right.ReturnType() != Array {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Array, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case Like, NotLike:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isStringKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not string", operator, left, left.ReturnType())
return nil, errLikeInvalidInputs(err)
}
right := funcs[1]
if !right.ReturnType().isStringKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v, not string", operator, right, right.ReturnType())
return nil, errLikeInvalidInputs(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case String, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case String, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case Between, NotBetween:
if len(funcs) != 3 {
panic(fmt.Sprintf("too many values in funcs %v", funcs))
}
left := funcs[0]
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
from := funcs[1]
if !from.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: from expression %v evaluate to %v, not number", operator, from, from.ReturnType())
return nil, errInvalidDataType(err)
}
to := funcs[2]
if !to.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: to expression %v evaluate to %v, not number", operator, to, to.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || from.Type() == aggregateFunction || to.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch from.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: from expression %v return type %v is incompatible for aggregate evaluation", operator, from, from.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch to.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: to expression %v return type %v is incompatible for aggregate evaluation", operator, to, to.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: from,
to: to,
operator: operator,
funcType: funcType,
}, nil
case IsNull, IsNotNull:
if len(funcs) != 1 {
panic(fmt.Sprintf("too many values in funcs %v", funcs))
}
if funcs[0].Type() == aggregateFunction {
funcType = aggregateFunction
}
if operator == IsNull {
operator = Equal
} else {
operator = NotEqual
}
return &comparisonExpr{
left: funcs[0],
right: newValueExpr(NewNull()),
operator: operator,
funcType: funcType,
}, nil
}
return nil, errParseUnknownOperator(fmt.Errorf("unknown operator %v", operator))
}

View File

@@ -43,42 +43,6 @@ func (err *s3Error) Error() string {
return err.message
}
func errUnsupportedSQLStructure(err error) *s3Error {
return &s3Error{
code: "UnsupportedSqlStructure",
message: "Encountered an unsupported SQL structure. Check the SQL Reference.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedSelect(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedSelect",
message: "The SQL expression contains an unsupported use of SELECT.",
statusCode: 400,
cause: err,
}
}
func errParseAsteriskIsNotAloneInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseAsteriskIsNotAloneInSelectList",
message: "Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseInvalidContextForWildcardInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseInvalidContextForWildcardInSelectList",
message: "Invalid use of * in SELECT list in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInvalidDataType(err error) *s3Error {
return &s3Error{
code: "InvalidDataType",
@@ -88,55 +52,10 @@ func errInvalidDataType(err error) *s3Error {
}
}
func errUnsupportedFunction(err error) *s3Error {
return &s3Error{
code: "UnsupportedFunction",
message: "Encountered an unsupported SQL function.",
statusCode: 400,
cause: err,
}
}
func errParseNonUnaryAgregateFunctionCall(err error) *s3Error {
return &s3Error{
code: "ParseNonUnaryAgregateFunctionCall",
message: "Only one argument is supported for aggregate functions in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errIncorrectSQLFunctionArgumentType(err error) *s3Error {
return &s3Error{
code: "IncorrectSqlFunctionArgumentType",
message: "Incorrect type of arguments in function call in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidArguments(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidArguments",
message: "Incorrect number of arguments in the function call in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedSQLOperation(err error) *s3Error {
return &s3Error{
code: "UnsupportedSqlOperation",
message: "Encountered an unsupported SQL operation.",
statusCode: 400,
cause: err,
}
}
func errParseUnknownOperator(err error) *s3Error {
return &s3Error{
code: "ParseUnknownOperator",
message: "The SQL expression contains an invalid operator.",
message: "Incorrect type of arguments in function call.",
statusCode: 400,
cause: err,
}
@@ -151,64 +70,28 @@ func errLikeInvalidInputs(err error) *s3Error {
}
}
func errExternalEvalException(err error) *s3Error {
func errQueryParseFailure(err error) *s3Error {
return &s3Error{
code: "ExternalEvalException",
message: "The query cannot be evaluated. Check the file and try again.",
code: "ParseSelectFailure",
message: err.Error(),
statusCode: 400,
cause: err,
}
}
func errValueParseFailure(err error) *s3Error {
func errQueryAnalysisFailure(err error) *s3Error {
return &s3Error{
code: "ValueParseFailure",
message: "Time stamp parse failure in the SQL expression.",
code: "InvalidQuery",
message: err.Error(),
statusCode: 400,
cause: err,
}
}
func errEvaluatorBindingDoesNotExist(err error) *s3Error {
func errBadTableName(err error) *s3Error {
return &s3Error{
code: "EvaluatorBindingDoesNotExist",
message: "A column name or a path provided does not exist in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInternalError(err error) *s3Error {
return &s3Error{
code: "InternalError",
message: "Encountered an internal error.",
statusCode: 500,
cause: err,
}
}
func errParseInvalidTypeParam(err error) *s3Error {
return &s3Error{
code: "ParseInvalidTypeParam",
message: "The SQL expression contains an invalid parameter value.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedSyntax(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedSyntax",
message: "The SQL expression contains unsupported syntax.",
statusCode: 400,
cause: err,
}
}
func errInvalidKeyPath(err error) *s3Error {
return &s3Error{
code: "InvalidKeyPath",
message: "Key path in the SQL expression is invalid.",
code: "BadTableName",
message: "The table name is not supported",
statusCode: 400,
cause: err,
}

View File

@@ -0,0 +1,361 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"strings"
)
var (
errInvalidASTNode = errors.New("invalid AST Node")
errExpectedBool = errors.New("expected bool")
errLikeNonStrArg = errors.New("LIKE clause requires string arguments")
errLikeInvalidEscape = errors.New("LIKE clause has invalid ESCAPE character")
errNotImplemented = errors.New("not implemented")
)
// AST Node Evaluation functions
//
// During evaluation, the query is known to be valid, as analysis is
// complete. The only errors possible are due to value type
// mismatches, etc.
//
// If an aggregation node is present as a descendant (when
// e.prop.isAggregation is true), we call evalNode on all child nodes,
// check for errors, but do not perform any combining of the results
// of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
return e.Expression.evalNode(r)
}
func (e *Expression) evalNode(r Record) (*Value, error) {
if len(e.And) == 1 {
// In this case, result is not required to be boolean
// type.
return e.And[0].evalNode(r)
}
// Compute OR of conditions
result := false
for _, ex := range e.And {
res, err := ex.evalNode(r)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result || b
}
return FromBool(result), nil
}
func (e *AndCondition) evalNode(r Record) (*Value, error) {
if len(e.Condition) == 1 {
// In this case, result does not have to be boolean
return e.Condition[0].evalNode(r)
}
// Compute AND of conditions
result := true
for _, ex := range e.Condition {
res, err := ex.evalNode(r)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result && b
}
return FromBool(result), nil
}
func (e *Condition) evalNode(r Record) (*Value, error) {
if e.Operand != nil {
// In this case, result does not have to be boolean
return e.Operand.evalNode(r)
}
// Compute NOT of condition
res, err := e.Not.evalNode(r)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
return FromBool(!b), nil
}
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r)
if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr
}
// Need to evaluate the ConditionRHS
switch {
case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
if cmpRErr != nil {
return nil, cmpRErr
}
b, err := opVal.compareOp(e.ConditionRHS.Compare.Operator, cmpRight)
return FromBool(b), err
case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal)
default:
return nil, errInvalidASTNode
}
}
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
stVal, stErr := e.Start.evalNode(r)
if stErr != nil {
return nil, stErr
}
endVal, endErr := e.End.evalNode(r)
if endErr != nil {
return nil, endErr
}
part1, err1 := stVal.compareOp(opLte, arg)
if err1 != nil {
return nil, err1
}
part2, err2 := arg.compareOp(opLte, endVal)
if err2 != nil {
return nil, err2
}
result := part1 && part2
if e.Not {
result = !result
}
return FromBool(result), nil
}
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
inferTypeAsString(arg)
s, ok := arg.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
pattern, err1 := e.Pattern.evalNode(r)
if err1 != nil {
return nil, err1
}
// Infer pattern as string (in case it is untyped)
inferTypeAsString(pattern)
patternStr, ok := pattern.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
escape := runeZero
if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r)
if err2 != nil {
return nil, err2
}
inferTypeAsString(escapeVal)
escapeStr, ok := escapeVal.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
if len([]rune(escapeStr)) > 1 {
err := errLikeInvalidEscape
return nil, errLikeInvalidInputs(err)
}
}
matchResult, err := evalSQLLike(s, patternStr, escape)
if err != nil {
return nil, err
}
if e.Not {
matchResult = !matchResult
}
return FromBool(matchResult), nil
}
func (e *In) evalInNode(r Record, arg *Value) (*Value, error) {
result := false
for _, elt := range e.Expressions {
eltVal, err := elt.evalNode(r)
if err != nil {
return nil, err
}
// FIXME: type inference?
// Types must match.
if arg.vType != eltVal.vType {
// match failed.
continue
}
if arg.value == eltVal.value {
result = true
break
}
}
return FromBool(result), nil
}
func (e *Operand) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process remaining child nodes - result must be
// numeric. This AST node is for terms separated by + or -
// symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *MultOp) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process other child nodes - result must be numeric. This
// AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
if e.Negated == nil {
return e.Primary.evalNode(r)
}
v, err := e.Negated.Term.evalNode(r)
if err != nil {
return nil, err
}
inferTypeForArithOp(v)
if ival, ok := v.ToInt(); ok {
return FromInt(-ival), nil
} else if fval, ok := v.ToFloat(); ok {
return FromFloat(-fval), nil
}
return nil, errArithMismatchedTypes
}
func (e *JSONPath) evalNode(r Record) (*Value, error) {
// Strip the table name from the keypath.
keypath := e.String()
ps := strings.SplitN(keypath, ".", 2)
if len(ps) == 2 {
keypath = ps[1]
}
return r.Get(keypath)
}
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
switch {
case e.Value != nil:
return e.Value.evalNode(r)
case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r)
case e.SubExpression != nil:
return e.SubExpression.evalNode(r)
case e.FuncCall != nil:
return e.FuncCall.evalNode(r)
}
return nil, errInvalidASTNode
}
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate()
default:
return e.evalSQLFnNode(r)
}
}
// evalNode on a literal value is independent of the node being an
// aggregation or a row function - it always returns a value.
func (e *LitValue) evalNode(_ Record) (res *Value, err error) {
switch {
case e.Number != nil:
return floatToValue(*e.Number), nil
case e.String != nil:
return FromString(string(*e.String)), nil
case e.Boolean != nil:
return FromBool(bool(*e.Boolean)), nil
}
return FromNull(), nil
}

View File

@ -1,160 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
)
// Expr - a SQL expression type.
type Expr interface {
AggregateValue() (*Value, error)
Eval(record Record) (*Value, error)
ReturnType() Type
Type() Type
}
// aliasExpr - aliases expression by alias.
type aliasExpr struct {
alias string
expr Expr
}
// String - returns string representation of this expression.
func (expr *aliasExpr) String() string {
return fmt.Sprintf("(%v AS %v)", expr.expr, expr.alias)
}
// Eval - evaluates underlying expression for given record and returns evaluated result.
func (expr *aliasExpr) Eval(record Record) (*Value, error) {
return expr.expr.Eval(record)
}
// AggregateValue - returns aggregated value from underlying expression.
func (expr *aliasExpr) AggregateValue() (*Value, error) {
return expr.expr.AggregateValue()
}
// Type - returns underlying expression type.
func (expr *aliasExpr) Type() Type {
return expr.expr.Type()
}
// ReturnType - returns underlying expression's return type.
func (expr *aliasExpr) ReturnType() Type {
return expr.expr.ReturnType()
}
// newAliasExpr - creates new alias expression.
func newAliasExpr(alias string, expr Expr) *aliasExpr {
return &aliasExpr{alias, expr}
}
// starExpr - asterisk (*) expression.
type starExpr struct {
}
// String - returns string representation of this expression.
func (expr *starExpr) String() string {
return "*"
}
// Eval - returns given args as map value.
func (expr *starExpr) Eval(record Record) (*Value, error) {
return newRecordValue(record), nil
}
// AggregateValue - returns nil value.
func (expr *starExpr) AggregateValue() (*Value, error) {
return nil, nil
}
// Type - returns record type.
func (expr *starExpr) Type() Type {
return record
}
// ReturnType - returns record as return type.
func (expr *starExpr) ReturnType() Type {
return record
}
// newStarExpr - returns new asterisk (*) expression.
func newStarExpr() *starExpr {
return &starExpr{}
}
type valueExpr struct {
value *Value
}
func (expr *valueExpr) String() string {
return expr.value.String()
}
func (expr *valueExpr) Eval(record Record) (*Value, error) {
return expr.value, nil
}
func (expr *valueExpr) AggregateValue() (*Value, error) {
return expr.value, nil
}
func (expr *valueExpr) Type() Type {
return expr.value.Type()
}
func (expr *valueExpr) ReturnType() Type {
return expr.value.Type()
}
func newValueExpr(value *Value) *valueExpr {
return &valueExpr{value: value}
}
type columnExpr struct {
name string
}
func (expr *columnExpr) String() string {
return expr.name
}
func (expr *columnExpr) Eval(record Record) (*Value, error) {
value, err := record.Get(expr.name)
if err != nil {
return nil, errEvaluatorBindingDoesNotExist(err)
}
return value, nil
}
func (expr *columnExpr) AggregateValue() (*Value, error) {
return nil, nil
}
func (expr *columnExpr) Type() Type {
return column
}
func (expr *columnExpr) ReturnType() Type {
return column
}
func newColumnExpr(columnName string) *columnExpr {
return &columnExpr{name: columnName}
}

View File

@ -0,0 +1,433 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
"strconv"
"strings"
)
// FuncName - SQL function name.
type FuncName string
// SQL Function name constants
const (
// Conditionals
sqlFnCoalesce FuncName = "COALESCE"
sqlFnNullIf FuncName = "NULLIF"
// Conversion
sqlFnCast FuncName = "CAST"
// Date and time
sqlFnDateAdd FuncName = "DATE_ADD"
sqlFnDateDiff FuncName = "DATE_DIFF"
sqlFnExtract FuncName = "EXTRACT"
sqlFnToString FuncName = "TO_STRING"
sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
sqlFnUTCNow FuncName = "UTCNOW"
// String
sqlFnCharLength FuncName = "CHAR_LENGTH"
sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
sqlFnLower FuncName = "LOWER"
sqlFnSubstring FuncName = "SUBSTRING"
sqlFnTrim FuncName = "TRIM"
sqlFnUpper FuncName = "UPPER"
)
// Allowed cast types
const (
castBool = "BOOL"
castInt = "INT"
castInteger = "INTEGER"
castString = "STRING"
castFloat = "FLOAT"
castDecimal = "DECIMAL"
castNumeric = "NUMERIC"
castTimestamp = "TIMESTAMP"
)
var (
	errUnimplementedCast = errors.New("This cast is not yet implemented")
errNonStringTrimArg = errors.New("TRIM() received a non-string argument")
)
func (e *FuncExpr) getFunctionName() FuncName {
switch {
case e.SFunc != nil:
return FuncName(strings.ToUpper(e.SFunc.FunctionName))
case e.Count != nil:
return FuncName(aggFnCount)
case e.Cast != nil:
return sqlFnCast
case e.Substring != nil:
return sqlFnSubstring
case e.Extract != nil:
return sqlFnExtract
case e.Trim != nil:
return sqlFnTrim
default:
return ""
}
}
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
// function.
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
// Handle functions that have phrase arguments
switch e.getFunctionName() {
case sqlFnCast:
expr := e.Cast.Expr
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType))
return
case sqlFnSubstring:
return handleSQLSubstring(r, e.Substring)
case sqlFnExtract:
return nil, errNotImplemented
case sqlFnTrim:
return handleSQLTrim(r, e.Trim)
}
// For all simple argument functions, we evaluate the arguments here
argVals := make([]*Value, len(e.SFunc.ArgsList))
for i, arg := range e.SFunc.ArgsList {
argVals[i], err = arg.evalNode(r)
if err != nil {
return nil, err
}
}
switch e.getFunctionName() {
case sqlFnCoalesce:
return coalesce(r, argVals)
case sqlFnNullIf:
return nullif(r, argVals[0], argVals[1])
case sqlFnCharLength, sqlFnCharacterLength:
return charlen(r, argVals[0])
case sqlFnLower:
return lowerCase(r, argVals[0])
case sqlFnUpper:
return upperCase(r, argVals[0])
case sqlFnDateAdd, sqlFnDateDiff, sqlFnToString, sqlFnToTimestamp, sqlFnUTCNow:
// TODO: implement
fallthrough
default:
return nil, errInvalidASTNode
}
}
func coalesce(r Record, args []*Value) (res *Value, err error) {
for _, arg := range args {
if arg.IsNull() {
continue
}
return arg, nil
}
return FromNull(), nil
}
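// Illustration (editor's note): COALESCE returns the first non-NULL argument,
// e.g. COALESCE(NULL, 2, 2 + 3) evaluates to 2; if every argument is NULL the
// result is NULL.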
func nullif(r Record, v1, v2 *Value) (res *Value, err error) {
// Handle Null cases
if v1.IsNull() || v2.IsNull() {
return v1, nil
}
err = inferTypesForCmp(v1, v2)
if err != nil {
return nil, err
}
atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
bothNumeric := v1.isNumeric() && v2.isNumeric()
	// If exactly one of the two values is numeric, they cannot compare
	// equal, so return the first argument unchanged.
	if atleastOneNumeric && !bothNumeric {
return v1, nil
}
if v1.vType != v2.vType {
return v1, nil
}
cmpResult, cmpErr := v1.compareOp(opEq, v2)
if cmpErr != nil {
return nil, cmpErr
}
if cmpResult {
return FromNull(), nil
}
return v1, nil
}
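// Illustration (editor's note): NULLIF(a, b) yields NULL when the arguments
// compare equal and the first argument otherwise, e.g. NULLIF(1, 1) -> NULL,
// NULLIF(1, 2) -> 1, NULLIF(NULL, 1) -> NULL.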
func charlen(r Record, v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
	// Count characters rather than bytes, per SQL CHAR_LENGTH semantics.
	return FromInt(int64(len([]rune(s)))), nil
}
func lowerCase(r Record, v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s expects a string argument", sqlFnLower)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromString(strings.ToLower(s)), nil
}
func upperCase(r Record, v *Value) (*Value, error) {
inferTypeAsString(v)
s, ok := v.ToString()
if !ok {
err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return FromString(strings.ToUpper(s)), nil
}
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
// SUBSTRING('abc', 2, 1) are supported.
// Evaluate the string argument
v1, err := e.Expr.evalNode(r)
if err != nil {
return nil, err
}
inferTypeAsString(v1)
s, ok := v1.ToString()
if !ok {
err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
// Assemble other arguments
arg2, arg3 := e.From, e.For
// Check if the second form of substring is being used
if e.From == nil {
arg2, arg3 = e.Arg2, e.Arg3
}
// Evaluate the FROM argument
v2, err := arg2.evalNode(r)
if err != nil {
return nil, err
}
inferTypeForArithOp(v2)
startIdx, ok := v2.ToInt()
if !ok {
err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
length := -1
// Evaluate the optional FOR argument
if arg3 != nil {
v3, err := arg3.evalNode(r)
if err != nil {
return nil, err
}
inferTypeForArithOp(v3)
lenInt, ok := v3.ToInt()
if !ok {
err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
length = int(lenInt)
if length < 0 {
err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
res, err := evalSQLSubstring(s, int(startIdx), length)
return FromString(res), err
}
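// Illustration (editor's note): both surface forms reach the same helper, so
// SUBSTRING('abcd' FROM 2 FOR 2) and SUBSTRING('abcd', 2, 2) are expected to
// yield 'bc' (SQL positions are 1-based); the exact boundary handling lives
// in evalSQLSubstring.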
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
	// The character-set argument is optional in the grammar (e.g. plain
	// TRIM(' ab ')), so guard against a nil TrimChars before evaluating.
	chars := ""
	ok := false
	if e.TrimChars != nil {
		charsV, cerr := e.TrimChars.evalNode(r)
		if cerr != nil {
			return nil, cerr
		}
		inferTypeAsString(charsV)
		chars, ok = charsV.ToString()
		if !ok {
			return nil, errNonStringTrimArg
		}
	}
	fromV, ferr := e.TrimFrom.evalNode(r)
	if ferr != nil {
		return nil, ferr
	}
	inferTypeAsString(fromV)
	from, ok := fromV.ToString()
	if !ok {
		return nil, errNonStringTrimArg
	}
result, terr := evalSQLTrim(e.TrimWhere, chars, from)
if terr != nil {
return nil, terr
}
return FromString(result), nil
}
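// Illustration (editor's note): TRIM defaults to trimming both ends when no
// LEADING or TRAILING keyword is given, so TRIM('   ab   ') is expected to
// yield 'ab' and TRIM(LEADING '1' FROM '11ab') to yield 'ab'; the trimming
// itself is delegated to evalSQLTrim.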
func errUnsupportedCast(fromType, toType string) error {
return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
}
func errCastFailure(msg string) error {
return fmt.Errorf("Error casting: %s", msg)
}
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
v, err := e.evalNode(r)
if err != nil {
return nil, err
}
fmt.Println("Cast to ", castType)
switch castType {
case castInt, castInteger:
i, err := intCast(v)
return FromInt(i), err
case castFloat:
f, err := floatCast(v)
return FromFloat(f), err
case castString:
s, err := stringCast(v)
return FromString(s), err
case castBool, castDecimal, castNumeric, castTimestamp:
fallthrough
default:
return nil, errUnimplementedCast
}
}
func intCast(v *Value) (int64, error) {
// This conversion truncates floating point numbers to
// integer.
strToInt := func(s string) (int64, bool) {
i, errI := strconv.ParseInt(s, 10, 64)
if errI == nil {
return i, true
}
f, errF := strconv.ParseFloat(s, 64)
if errF == nil {
return int64(f), true
}
return 0, false
}
switch v.vType {
case typeFloat:
// Truncate fractional part
return int64(v.value.(float64)), nil
case typeInt:
return v.value.(int64), nil
case typeString:
// Parse as number, truncate floating point if
// needed.
s, _ := v.ToString()
res, ok := strToInt(s)
if !ok {
return 0, errCastFailure("could not parse as int")
}
return res, nil
case typeBytes:
// Parse as number, truncate floating point if
// needed.
b, _ := v.ToBytes()
s := string(b)
res, ok := strToInt(s)
if !ok {
return 0, errCastFailure("could not parse as int")
}
return res, nil
default:
return 0, errUnsupportedCast(v.GetTypeString(), castInt)
}
}
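// Illustration (editor's note): the INT/INTEGER cast truncates toward zero
// rather than rounding, so CAST(3.9 AS INT) and CAST('3.9' AS INT) both give
// 3, and CAST('-3.9' AS INT) gives -3.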
func floatCast(v *Value) (float64, error) {
switch v.vType {
case typeFloat:
return v.value.(float64), nil
case typeInt:
return float64(v.value.(int64)), nil
case typeString:
f, err := strconv.ParseFloat(v.value.(string), 64)
if err != nil {
return 0, errCastFailure("could not parse as float")
}
return f, nil
case typeBytes:
b, _ := v.ToBytes()
f, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return 0, errCastFailure("could not parse as float")
}
return f, nil
default:
return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
}
}
func stringCast(v *Value) (string, error) {
switch v.vType {
case typeFloat:
f, _ := v.ToFloat()
return fmt.Sprintf("%v", f), nil
case typeInt:
i, _ := v.ToInt()
return fmt.Sprintf("%v", i), nil
case typeString:
s, _ := v.ToString()
return s, nil
case typeBytes:
b, _ := v.ToBytes()
return string(b), nil
case typeBool:
b, _ := v.ToBool()
return fmt.Sprintf("%v", b), nil
case typeNull:
// FIXME: verify this case is correct
return fmt.Sprintf("NULL"), nil
}
// This does not happen
return "", nil
}

View File

@ -1,550 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
"time"
)
// FuncName - SQL function name.
type FuncName string
const (
// Avg - aggregate SQL function AVG().
Avg FuncName = "AVG"
// Count - aggregate SQL function COUNT().
Count FuncName = "COUNT"
// Max - aggregate SQL function MAX().
Max FuncName = "MAX"
// Min - aggregate SQL function MIN().
Min FuncName = "MIN"
// Sum - aggregate SQL function SUM().
Sum FuncName = "SUM"
// Coalesce - conditional SQL function COALESCE().
Coalesce FuncName = "COALESCE"
// NullIf - conditional SQL function NULLIF().
NullIf FuncName = "NULLIF"
// ToTimestamp - conversion SQL function TO_TIMESTAMP().
ToTimestamp FuncName = "TO_TIMESTAMP"
// UTCNow - date SQL function UTCNOW().
UTCNow FuncName = "UTCNOW"
// CharLength - string SQL function CHAR_LENGTH().
CharLength FuncName = "CHAR_LENGTH"
// CharacterLength - string SQL function CHARACTER_LENGTH() same as CHAR_LENGTH().
CharacterLength FuncName = "CHARACTER_LENGTH"
// Lower - string SQL function LOWER().
Lower FuncName = "LOWER"
// Substring - string SQL function SUBSTRING().
Substring FuncName = "SUBSTRING"
// Trim - string SQL function TRIM().
Trim FuncName = "TRIM"
// Upper - string SQL function UPPER().
Upper FuncName = "UPPER"
// DateAdd FuncName = "DATE_ADD"
// DateDiff FuncName = "DATE_DIFF"
// Extract FuncName = "EXTRACT"
// ToString FuncName = "TO_STRING"
// Cast FuncName = "CAST" // CAST('2007-04-05T14:30Z' AS TIMESTAMP)
)
func isAggregateFuncName(s string) bool {
switch FuncName(s) {
case Avg, Count, Max, Min, Sum:
return true
}
return false
}
func callForNumber(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if !value.Type().isNumber() {
err := fmt.Errorf("%v evaluated to %v; not to number", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
func callForInt(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if value.Type() != Int {
err := fmt.Errorf("%v evaluated to %v; not to int", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
func callForString(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if value.Type() != String {
err := fmt.Errorf("%v evaluated to %v; not to string", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
// funcExpr - SQL function.
type funcExpr struct {
args []Expr
name FuncName
sumValue float64
countValue int64
maxValue float64
minValue float64
}
// String - returns string representation of this function.
func (f *funcExpr) String() string {
var argStrings []string
for _, arg := range f.args {
argStrings = append(argStrings, fmt.Sprintf("%v", arg))
}
return fmt.Sprintf("%v(%v)", f.name, strings.Join(argStrings, ","))
}
func (f *funcExpr) sum(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
f.sumValue += value.FloatValue()
f.countValue++
return nil, nil
}
func (f *funcExpr) count(record Record) (*Value, error) {
value, err := f.args[0].Eval(record)
if err != nil {
return nil, err
}
if value.valueType != Null {
f.countValue++
}
return nil, nil
}
func (f *funcExpr) max(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
v := value.FloatValue()
if v > f.maxValue {
f.maxValue = v
}
return nil, nil
}
func (f *funcExpr) min(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
v := value.FloatValue()
if v < f.minValue {
f.minValue = v
}
return nil, nil
}
func (f *funcExpr) charLength(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewInt(int64(len(value.StringValue()))), nil
}
func (f *funcExpr) trim(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.TrimSpace(value.StringValue())), nil
}
func (f *funcExpr) lower(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.ToLower(value.StringValue())), nil
}
func (f *funcExpr) upper(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.ToUpper(value.StringValue())), nil
}
func (f *funcExpr) substring(record Record) (*Value, error) {
stringValue, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
offsetValue, err := callForInt(f.args[1], record)
if err != nil {
return nil, err
}
var lengthValue *Value
if len(f.args) == 3 {
lengthValue, err = callForInt(f.args[2], record)
if err != nil {
return nil, err
}
}
value := stringValue.StringValue()
offset := int(offsetValue.FloatValue())
if offset < 0 || offset > len(value) {
offset = 0
}
length := len(value)
if lengthValue != nil {
length = int(lengthValue.FloatValue())
if length < 0 || length > len(value) {
length = len(value)
}
}
return NewString(value[offset:length]), nil
}
func (f *funcExpr) coalesce(record Record) (*Value, error) {
values := make([]*Value, len(f.args))
var err error
for i := range f.args {
values[i], err = f.args[i].Eval(record)
if err != nil {
return nil, err
}
}
for i := range values {
if values[i].Type() != Null {
return values[i], nil
}
}
return values[0], nil
}
func (f *funcExpr) nullIf(record Record) (*Value, error) {
value1, err := f.args[0].Eval(record)
if err != nil {
return nil, err
}
value2, err := f.args[1].Eval(record)
if err != nil {
return nil, err
}
result, err := equal(value1, value2)
if err != nil {
return nil, err
}
if result {
return NewNull(), nil
}
return value1, nil
}
func (f *funcExpr) toTimeStamp(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
t, err := time.Parse(time.RFC3339, value.StringValue())
if err != nil {
err := fmt.Errorf("%v: value '%v': %v", f, value, err)
return nil, errValueParseFailure(err)
}
return NewTime(t), nil
}
func (f *funcExpr) utcNow(record Record) (*Value, error) {
return NewTime(time.Now().UTC()), nil
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *funcExpr) Eval(record Record) (*Value, error) {
switch f.name {
case Avg, Sum:
return f.sum(record)
case Count:
return f.count(record)
case Max:
return f.max(record)
case Min:
return f.min(record)
case Coalesce:
return f.coalesce(record)
case NullIf:
return f.nullIf(record)
case ToTimestamp:
return f.toTimeStamp(record)
case UTCNow:
return f.utcNow(record)
case Substring:
return f.substring(record)
case CharLength, CharacterLength:
return f.charLength(record)
case Trim:
return f.trim(record)
case Lower:
return f.lower(record)
case Upper:
return f.upper(record)
}
panic(fmt.Sprintf("unsupported aggregate function %v", f.name))
}
// AggregateValue - returns aggregated value.
func (f *funcExpr) AggregateValue() (*Value, error) {
switch f.name {
case Avg:
return NewFloat(f.sumValue / float64(f.countValue)), nil
case Count:
return NewInt(f.countValue), nil
case Max:
return NewFloat(f.maxValue), nil
case Min:
return NewFloat(f.minValue), nil
case Sum:
return NewFloat(f.sumValue), nil
}
err := fmt.Errorf("%v is not aggreate function", f)
return nil, errExternalEvalException(err)
}
// Type - returns Function or aggregateFunction type.
func (f *funcExpr) Type() Type {
switch f.name {
case Avg, Count, Max, Min, Sum:
return aggregateFunction
}
return function
}
// ReturnType - returns respective primitive type depending on SQL function.
func (f *funcExpr) ReturnType() Type {
switch f.name {
case Avg, Max, Min, Sum:
return Float
case Count:
return Int
case CharLength, CharacterLength, Trim, Lower, Upper, Substring:
return String
case ToTimestamp, UTCNow:
return Timestamp
case Coalesce, NullIf:
return column
}
return function
}
// newFuncExpr - creates new SQL function.
func newFuncExpr(funcName FuncName, funcs ...Expr) (*funcExpr, error) {
switch funcName {
case Avg, Max, Min, Sum:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errParseNonUnaryAgregateFunctionCall(err)
}
if !funcs[0].ReturnType().isNumberKind() {
err := fmt.Errorf("%v(): argument %v evaluate to %v, not number", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Count:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errParseNonUnaryAgregateFunctionCall(err)
}
switch funcs[0].ReturnType() {
case Null, Bool, Int, Float, String, Timestamp, column, record:
default:
err := fmt.Errorf("%v(): argument %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case CharLength, CharacterLength, Trim, Lower, Upper, ToTimestamp:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isStringKind() {
err := fmt.Errorf("%v(): argument %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Coalesce:
if len(funcs) < 1 {
err := fmt.Errorf("%v(): one or more argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
for i := range funcs {
if !funcs[i].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-%v %v evaluate to %v is incompatible", funcName, i+1, funcs[i], funcs[i].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case NullIf:
if len(funcs) != 2 {
err := fmt.Errorf("%v(): exactly two arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if !funcs[1].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v is incompatible", funcName, funcs[1], funcs[1].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case UTCNow:
if len(funcs) != 0 {
err := fmt.Errorf("%v(): no argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Substring:
if len(funcs) < 2 || len(funcs) > 3 {
err := fmt.Errorf("%v(): exactly two or three arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isStringKind() {
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if !funcs[1].ReturnType().isIntKind() {
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v, not int", funcName, funcs[1], funcs[1].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if len(funcs) > 2 {
if !funcs[2].ReturnType().isIntKind() {
err := fmt.Errorf("%v(): argument-3 %v evaluate to %v, not int", funcName, funcs[2], funcs[2].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
}
return nil, errUnsupportedFunction(fmt.Errorf("unknown function name %v", funcName))
}

View File

@ -1,336 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "fmt"
// andExpr - logical AND function.
type andExpr struct {
left Expr
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *andExpr) String() string {
return fmt.Sprintf("(%v AND %v)", f.left, f.right)
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *andExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
_, err = f.right.Eval(record)
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// AggregateValue - returns aggregated value.
func (f *andExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *andExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *andExpr) ReturnType() Type {
return Bool
}
// newAndExpr - creates new AND logical function.
func newAndExpr(left, right Expr) (*andExpr, error) {
if !left.ReturnType().isBoolKind() {
err := fmt.Errorf("operator AND: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator AND: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
if left.Type() == column {
err := fmt.Errorf("operator AND: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
if right.Type() == column {
err := fmt.Errorf("operator AND: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &andExpr{
left: left,
right: right,
funcType: funcType,
}, nil
}
// orExpr - logical OR function.
type orExpr struct {
left Expr
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *orExpr) String() string {
return fmt.Sprintf("(%v OR %v)", f.left, f.right)
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *orExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
_, err = f.right.Eval(record)
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// AggregateValue - returns aggregated value.
func (f *orExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *orExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *orExpr) ReturnType() Type {
return Bool
}
// newOrExpr - creates new OR logical function.
func newOrExpr(left, right Expr) (*orExpr, error) {
if !left.ReturnType().isBoolKind() {
err := fmt.Errorf("operator OR: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator OR: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
if left.Type() == column {
err := fmt.Errorf("operator OR: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
if right.Type() == column {
err := fmt.Errorf("operator OR: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &orExpr{
left: left,
right: right,
funcType: funcType,
}, nil
}
// notExpr - logical NOT function.
type notExpr struct {
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *notExpr) String() string {
return fmt.Sprintf("(%v)", f.right)
}
// Call - evaluates this function for given arg values and returns result as Value.
func (f *notExpr) Eval(record Record) (*Value, error) {
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
return nil, nil
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(!rightValue.BoolValue()), nil
}
// AggregateValue - returns aggregated value.
func (f *notExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(!rightValue.BoolValue()), nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *notExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *notExpr) ReturnType() Type {
return Bool
}
// newNotExpr - creates new NOT logical function.
func newNotExpr(right Expr) (*notExpr, error) {
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator NOT: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if right.Type() == aggregateFunction {
funcType = aggregateFunction
}
return &notExpr{
right: right,
funcType: funcType,
}, nil
}

pkg/s3select/sql/parser.go (new file, 329 lines)
View File

@ -0,0 +1,329 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"strings"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
)
// Types with custom Capture interface for parsing
// Boolean is a type for a parsed Boolean literal
type Boolean bool
// Capture interface used by participle
func (b *Boolean) Capture(values []string) error {
*b = strings.ToLower(values[0]) == "true"
return nil
}
// LiteralString is a type for parsed SQL string literals
type LiteralString string
// Capture interface used by participle
func (ls *LiteralString) Capture(values []string) error {
// Remove enclosing single quote
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*ls = LiteralString(strings.Replace(r, "''", "'", -1))
return nil
}
// ObjectKey is a type for parsed strings occurring in key paths
type ObjectKey struct {
Lit *LiteralString `parser:" \"[\" @LitString \"]\""`
ID *Identifier `parser:"| \".\" @@"`
}
// QuotedIdentifier is a type for parsed strings that are double
// quoted.
type QuotedIdentifier string
// Capture interface used by participle
func (qi *QuotedIdentifier) Capture(values []string) error {
// Remove enclosing quotes
n := len(values[0])
r := values[0][1 : n-1]
// Translate doubled quotes
*qi = QuotedIdentifier(strings.Replace(r, `""`, `"`, -1))
return nil
}
// Types representing AST of SQL statement. Only SELECT is supported.
// Select is the top level AST node type
type Select struct {
Expression *SelectExpression `parser:"\"SELECT\" @@"`
From *TableExpression `parser:"\"FROM\" @@"`
Where *Expression `parser:"[ \"WHERE\" @@ ]"`
Limit *LitValue `parser:"[ \"LIMIT\" @@ ]"`
}
// SelectExpression represents the items requested in the select
// statement
type SelectExpression struct {
All bool `parser:" @\"*\""`
Expressions []*AliasedExpression `parser:"| @@ { \",\" @@ }"`
prop qProp
}
// TableExpression represents the FROM clause
type TableExpression struct {
Table *JSONPath `parser:"@@"`
As string `parser:"( \"AS\"? @Ident )?"`
}
// JSONPathElement represents a keypath component
type JSONPathElement struct {
Key *ObjectKey `parser:" @@"` // ['name'] and .name forms
Index *uint64 `parser:"| \"[\" @Number \"]\""` // [3] form
ObjectWildcard bool `parser:"| @\".*\""` // .* form
ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form
}
// JSONPath represents a keypath
type JSONPath struct {
BaseKey *Identifier `parser:" @@"`
PathExpr []*JSONPathElement `parser:"(@@)*"`
}
// AliasedExpression is an expression that can be optionally named
type AliasedExpression struct {
Expression *Expression `parser:"@@"`
As string `parser:"[ \"AS\" @Ident ]"`
}
// Grammar for Expression
//
// Expression → AndCondition ("OR" AndCondition)*
// AndCondition → Condition ("AND" Condition)*
// Condition → "NOT" Condition | ConditionExpression
// ConditionExpression → ValueExpression ("=" | "<>" | "<=" | ">=" | "<" | ">") ValueExpression
// | ValueExpression "LIKE" ValueExpression ("ESCAPE" LitString)?
// | ValueExpression ("NOT"? "BETWEEN" ValueExpression "AND" ValueExpression)
// | ValueExpression "IN" "(" Expression ("," Expression)* ")"
// | ValueExpression
// ValueExpression → Operand
//
// Operand grammar follows below
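// For example (editor's illustration), under this grammar the condition
//
//	a = 1 OR b = 2 AND c = 3
//
// groups as a = 1 OR (b = 2 AND c = 3), since AND binds tighter than OR,
// and NOT applies to the single condition that follows it.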
// Expression represents a logical disjunction of clauses
type Expression struct {
And []*AndCondition `parser:"@@ ( \"OR\" @@ )*"`
}
// AndCondition represents logical conjunction of clauses
type AndCondition struct {
Condition []*Condition `parser:"@@ ( \"AND\" @@ )*"`
}
// Condition represents a negation or a condition operand
type Condition struct {
Operand *ConditionOperand `parser:" @@"`
Not *Condition `parser:"| \"NOT\" @@"`
}
// ConditionOperand is an operand followed by an optional operation
// expression
type ConditionOperand struct {
Operand *Operand `parser:"@@"`
ConditionRHS *ConditionRHS `parser:"@@?"`
}
// ConditionRHS represents the right-hand-side of Compare, Between, In
// or Like expressions.
type ConditionRHS struct {
Compare *Compare `parser:" @@"`
Between *Between `parser:"| @@"`
In *In `parser:"| \"IN\" \"(\" @@ \")\""`
Like *Like `parser:"| @@"`
}
// Compare represents the RHS of a comparison expression
type Compare struct {
Operator string `parser:"@( \"<>\" | \"<=\" | \">=\" | \"=\" | \"<\" | \">\" | \"!=\" )"`
Operand *Operand `parser:" @@"`
}
// Like represents the RHS of a LIKE expression
type Like struct {
Not bool `parser:" @\"NOT\"? "`
Pattern *Operand `parser:" \"LIKE\" @@ "`
EscapeChar *Operand `parser:" (\"ESCAPE\" @@)? "`
}
// Between represents the RHS of a BETWEEN expression
type Between struct {
Not bool `parser:" @\"NOT\"? "`
Start *Operand `parser:" \"BETWEEN\" @@ "`
End *Operand `parser:" \"AND\" @@ "`
}
// In represents the RHS of an IN expression
type In struct {
Expressions []*Expression `parser:"@@ ( \",\" @@ )*"`
}
// Grammar for Operand:
//
// operand → multOp ( ("-" | "+") multOp )*
// multOp → unary ( ("/" | "*" | "%") unary )*
// unary → "-" unary | primary
// primary → Value | Variable | "(" expression ")"
//
// An Operand is a single term followed by an optional sequence of
// terms separated by +/-
type Operand struct {
Left *MultOp `parser:"@@"`
Right []*OpFactor `parser:"(@@)*"`
}
// OpFactor represents the right-side of a +/- operation.
type OpFactor struct {
Op string `parser:"@(\"+\" | \"-\")"`
Right *MultOp `parser:"@@"`
}
// MultOp represents a single term followed by an optional sequence of
// terms separated by *, / or % operators.
type MultOp struct {
Left *UnaryTerm `parser:"@@"`
Right []*OpUnaryTerm `parser:"(@@)*"`
}
// OpUnaryTerm represents the right side of *, / or % binary operations.
type OpUnaryTerm struct {
Op string `parser:"@(\"*\" | \"/\" | \"%\")"`
Right *UnaryTerm `parser:"@@"`
}
// UnaryTerm represents a single negated term or a primary term
type UnaryTerm struct {
Negated *NegatedTerm `parser:" @@"`
Primary *PrimaryTerm `parser:"| @@"`
}
// NegatedTerm has a leading minus sign.
type NegatedTerm struct {
Term *PrimaryTerm `parser:"\"-\" @@"`
}
// PrimaryTerm represents a Value, Path expression, a Sub-expression
// or a function call.
type PrimaryTerm struct {
Value *LitValue `parser:" @@"`
JPathExpr *JSONPath `parser:"| @@"`
SubExpression *Expression `parser:"| \"(\" @@ \")\""`
// Include function expressions here.
FuncCall *FuncExpr `parser:"| @@"`
}
// FuncExpr represents a function call
type FuncExpr struct {
SFunc *SimpleArgFunc `parser:" @@"`
Count *CountFunc `parser:"| @@"`
Cast *CastFunc `parser:"| @@"`
Substring *SubstringFunc `parser:"| @@"`
Extract *ExtractFunc `parser:"| @@"`
Trim *TrimFunc `parser:"| @@"`
// Used during evaluation for aggregation funcs
aggregate *aggVal
}
// SimpleArgFunc represents functions with simple expression
// arguments.
type SimpleArgFunc struct {
FunctionName string `parser:" @(\"AVG\" | \"MAX\" | \"MIN\" | \"SUM\" | \"COALESCE\" | \"NULLIF\" | \"DATE_ADD\" | \"DATE_DIFF\" | \"TO_STRING\" | \"TO_TIMESTAMP\" | \"UTCNOW\" | \"CHAR_LENGTH\" | \"CHARACTER_LENGTH\" | \"LOWER\" | \"UPPER\") "`
ArgsList []*Expression `parser:"\"(\" (@@ (\",\" @@)*)?\")\""`
}
// CountFunc represents the COUNT sql function
type CountFunc struct {
StarArg bool `parser:" \"COUNT\" \"(\" ( @\"*\"?"`
ExprArg *Expression `parser:" @@? )! \")\""`
}
// CastFunc represents CAST sql function
type CastFunc struct {
Expr *Expression `parser:" \"CAST\" \"(\" @@ "`
CastType string `parser:" \"AS\" @(\"BOOL\" | \"INT\" | \"INTEGER\" | \"STRING\" | \"FLOAT\" | \"DECIMAL\" | \"NUMERIC\" | \"TIMESTAMP\") \")\" "`
}
// SubstringFunc represents SUBSTRING sql function
type SubstringFunc struct {
Expr *PrimaryTerm `parser:" \"SUBSTRING\" \"(\" @@ "`
From *Operand `parser:" ( \"FROM\" @@ "`
For *Operand `parser:" (\"FOR\" @@)? \")\" "`
Arg2 *Operand `parser:" | \",\" @@ "`
Arg3 *Operand `parser:" (\",\" @@)? \")\" )"`
}
// ExtractFunc represents EXTRACT sql function
type ExtractFunc struct {
Timeword string `parser:" \"EXTRACT\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword | \"TIMEZONE_HOUR\":Timeword | \"TIMEZONE_MINUTE\":Timeword ) "`
From *PrimaryTerm `parser:" \"FROM\" @@ \")\" "`
}
// TrimFunc represents TRIM sql function
type TrimFunc struct {
TrimWhere *string `parser:" \"TRIM\" \"(\" ( @( \"LEADING\" | \"TRAILING\" | \"BOTH\" ) "`
TrimChars *PrimaryTerm `parser:" @@? "`
TrimFrom *PrimaryTerm `parser:" \"FROM\" )? @@ \")\" "`
}
// LitValue represents a literal value parsed from the sql
type LitValue struct {
Number *float64 `parser:"( @Number"`
String *LiteralString `parser:" | @LitString"`
Boolean *Boolean `parser:" | @(\"TRUE\" | \"FALSE\")"`
Null bool `parser:" | @\"NULL\")"`
}
// Identifier represents a parsed identifier
type Identifier struct {
Unquoted *string `parser:" @Ident"`
Quoted *QuotedIdentifier `parser:"| @QuotIdent"`
}
var (
sqlLexer = lexer.Must(lexer.Regexp(`(\s+)` +
`|(?P<Timeword>(?i)\b(?:YEAR|MONTH|DAY|HOUR|MINUTE|SECOND|TIMEZONE_HOUR|TIMEZONE_MINUTE)\b)` +
`|(?P<Keyword>(?i)\b(?:SELECT|FROM|TOP|DISTINCT|ALL|WHERE|GROUP|BY|HAVING|UNION|MINUS|EXCEPT|INTERSECT|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|ANY|SOME|BETWEEN|AND|OR|LIKE|ESCAPE|AS|IN|BOOL|INT|INTEGER|STRING|FLOAT|DECIMAL|NUMERIC|TIMESTAMP|AVG|COUNT|MAX|MIN|SUM|COALESCE|NULLIF|CAST|DATE_ADD|DATE_DIFF|EXTRACT|TO_STRING|TO_TIMESTAMP|UTCNOW|CHAR_LENGTH|CHARACTER_LENGTH|LOWER|SUBSTRING|TRIM|UPPER|LEADING|TRAILING|BOTH|FOR)\b)` +
`|(?P<Ident>[a-zA-Z_][a-zA-Z0-9_]*)` +
`|(?P<QuotIdent>"([^"]*("")?)*")` +
`|(?P<Number>\d*\.?\d+([eE][-+]?\d+)?)` +
`|(?P<LitString>'([^']*('')?)*')` +
`|(?P<Operators><>|!=|<=|>=|\.\*|\[\*\]|[-+*/%,.()=<>\[\]])`,
))
// SQLParser is used to parse SQL statements
SQLParser = participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
)
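// parseSelectExample is an editor's illustrative sketch and is not part of
// the original change. It shows how the generated parser is driven: the
// Select AST above is filled in directly by participle's ParseString, the
// same call the package's unit tests use. The query text is an arbitrary
// example.
func parseSelectExample() (*Select, error) {
	stmt := &Select{}
	err := SQLParser.ParseString("SELECT s.title FROM S3Object s WHERE s.id > 2 LIMIT 5", stmt)
	if err != nil {
		return nil, err
	}
	return stmt, nil
}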

View File

@ -0,0 +1,383 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"bytes"
"testing"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
)
func TestJSONPathElement(t *testing.T) {
p := participle.MustBuild(
&JSONPathElement{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
j := JSONPathElement{}
cases := []string{
// Key
"['name']", ".name", `."name"`,
// Index
"[2]", "[0]", "[100]",
// Object wildcard
".*",
// array wildcard
"[*]",
}
for i, tc := range cases {
err := p.ParseString(tc, &j)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestJSONPath(t *testing.T) {
p := participle.MustBuild(
&JSONPath{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
j := JSONPath{}
cases := []string{
"S3Object",
"S3Object.id",
"S3Object.book.title",
"S3Object.id[1]",
"S3Object.id['abc']",
"S3Object.id['ab']",
"S3Object.words.*.id",
"S3Object.words.name[*].val",
"S3Object.words.name[*].val[*]",
"S3Object.words.name[*].val.*",
}
for i, tc := range cases {
err := p.ParseString(tc, &j)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestIdentifierParsing(t *testing.T) {
p := participle.MustBuild(
&Identifier{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
id := Identifier{}
validCases := []string{
"a",
"_a",
"abc_a",
"a2",
`"abc"`,
`"abc\a""ac"`,
}
for i, tc := range validCases {
err := p.ParseString(tc, &id)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(id, repr.Indent(" "), repr.OmitEmpty(true))
}
invalidCases := []string{
"+a",
"-a",
"1a",
`"ab`,
`abc"`,
`aa""a`,
`"a"a"`,
}
for i, tc := range invalidCases {
err := p.ParseString(tc, &id)
if err == nil {
t.Fatalf("%d: %v", i, err)
}
// fmt.Println(tc, err)
}
}
func TestLiteralStringParsing(t *testing.T) {
var k ObjectKey
p := participle.MustBuild(
&ObjectKey{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
validCases := []string{
"['abc']",
"['ab''c']",
"['a''b''c']",
"['abc-x_1##@(*&(#*))/\\']",
}
for i, tc := range validCases {
err := p.ParseString(tc, &k)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
if string(*k.Lit) == "" {
t.Fatalf("Incorrect parse %#v", k)
}
// repr.Println(k, repr.Indent(" "), repr.OmitEmpty(true))
}
invalidCases := []string{
"['abc'']",
"['-abc'sc']",
"[abc']",
"['ac]",
}
for i, tc := range invalidCases {
err := p.ParseString(tc, &k)
if err == nil {
t.Fatalf("%d: %v", i, err)
}
// fmt.Println(tc, err)
}
}
func TestFunctionParsing(t *testing.T) {
var fex FuncExpr
p := participle.MustBuild(
&FuncExpr{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
participle.CaseInsensitive("Timeword"),
)
validCases := []string{
"count(*)",
"sum(2 + s.id)",
"sum(t)",
"avg(s.id[1])",
"coalesce(s.id[1], 2, 2 + 3)",
"cast(s as string)",
"cast(s AS INT)",
"cast(s as DECIMAL)",
"extract(YEAR from '2018-01-09')",
"extract(month from '2018-01-09')",
"extract(hour from '2018-01-09')",
"extract(day from '2018-01-09')",
"substring('abcd' from 2 for 2)",
"substring('abcd' from 2)",
"substring('abcd' , 2 , 2)",
"substring('abcd' , 22 )",
"trim(' aab ')",
"trim(leading from ' aab ')",
"trim(trailing from ' aab ')",
"trim(both from ' aab ')",
"trim(both '12' from ' aab ')",
"trim(leading '12' from ' aab ')",
"trim(trailing '12' from ' aab ')",
"count(23)",
}
for i, tc := range validCases {
err := p.ParseString(tc, &fex)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(fex, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSqlLexer(t *testing.T) {
// s := bytes.NewBuffer([]byte("s.['name'].*.[*].abc.[\"abc\"]"))
s := bytes.NewBuffer([]byte("S3Object.words.*.id"))
// s := bytes.NewBuffer([]byte("COUNT(Id)"))
lex, err := sqlLexer.Lex(s)
if err != nil {
t.Fatal(err)
}
tokens, err := lexer.ConsumeAll(lex)
if err != nil {
t.Fatal(err)
}
// for i, t := range tokens {
// fmt.Printf("%d: %#v\n", i, t)
// }
if len(tokens) != 7 {
t.Fatalf("Expected 7 got %d", len(tokens))
}
}
func TestSelectWhere(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object",
"select a, b from s3object s",
"select a, b from s3object as s",
"select a, b from s3object as s where a = 1",
"select a, b from s3object s where a = 1",
"select a, b from s3object where a = 1",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestLikeClause(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
`select * from s3object where Name like 'abcd'`,
`select Name like 'abc' from s3object`,
`select * from s3object where Name not like 'abc'`,
`select * from s3object where Name like 'abc' escape 't'`,
`select * from s3object where Name like 'a\%' escape '?'`,
`select * from s3object where Name not like 'abc\' escape '?'`,
`select * from s3object where Name like 'a\%' escape LOWER('?')`,
`select * from s3object where Name not like LOWER('Bc\') escape '?'`,
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Errorf("%d: %v", i, err)
}
}
}
func TestBetweenClause(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
`select * from s3object where Id between 1 and 2`,
`select * from s3object where Id between 1 and 2 and name = 'Ab'`,
`select * from s3object where Id not between 1 and 2`,
`select * from s3object where Id not between 1 and 2 and name = 'Bc'`,
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Errorf("%d: %v", i, err)
}
}
}
func TestFromClauseJSONPath(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object",
"select * from s3object[*].name",
"select * from s3object[*].books[*]",
"select * from s3object[*].books[*].name",
"select * from s3object where name > 2",
"select * from s3object[*].name where name > 2",
"select * from s3object[*].books[*] where name > 2",
"select * from s3object[*].books[*].name where name > 2",
"select * from s3object[*].books[*] s",
"select * from s3object[*].books[*].name as s",
"select * from s3object s where name > 2",
"select * from s3object[*].name as s where name > 2",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSelectParsing(t *testing.T) {
p := participle.MustBuild(
&Select{},
participle.Lexer(sqlLexer),
participle.CaseInsensitive("Keyword"),
)
s := Select{}
cases := []string{
"select * from s3object where name > 2 or value > 1 or word > 2",
"select s.word.id + 2 from s3object s",
"select 1-2-3 from s3object s limit 1",
}
for i, tc := range cases {
err := p.ParseString(tc, &s)
if err != nil {
t.Fatalf("%d: %v", i, err)
}
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
}
}
func TestSqlLexerArithOps(t *testing.T) {
s := bytes.NewBuffer([]byte("year from select month hour distinct"))
lex, err := sqlLexer.Lex(s)
if err != nil {
t.Fatal(err)
}
tokens, err := lexer.ConsumeAll(lex)
if err != nil {
t.Fatal(err)
}
if len(tokens) != 7 {
t.Errorf("Expected 7 got %d", len(tokens))
}
// for i, t := range tokens {
// fmt.Printf("%d: %#v\n", i, t)
// }
}

View File

@ -1,529 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
"github.com/xwb1989/sqlparser"
)
func getColumnName(colName *sqlparser.ColName) string {
columnName := colName.Qualifier.Name.String()
if qualifier := colName.Qualifier.Qualifier.String(); qualifier != "" {
columnName = qualifier + "." + columnName
}
if columnName == "" {
columnName = colName.Name.String()
} else {
columnName = columnName + "." + colName.Name.String()
}
return columnName
}
func newLiteralExpr(parserExpr sqlparser.Expr, tableAlias string) (Expr, error) {
switch parserExpr.(type) {
case *sqlparser.NullVal:
return newValueExpr(NewNull()), nil
case sqlparser.BoolVal:
return newValueExpr(NewBool((bool(parserExpr.(sqlparser.BoolVal))))), nil
case *sqlparser.SQLVal:
sqlValue := parserExpr.(*sqlparser.SQLVal)
value, err := NewValue(sqlValue)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
case *sqlparser.ColName:
columnName := getColumnName(parserExpr.(*sqlparser.ColName))
if tableAlias != "" {
if !strings.HasPrefix(columnName, tableAlias+".") {
err := fmt.Errorf("column name %v does not start with table alias %v", columnName, tableAlias)
return nil, errInvalidKeyPath(err)
}
columnName = strings.TrimPrefix(columnName, tableAlias+".")
}
return newColumnExpr(columnName), nil
case sqlparser.ValTuple:
var valueType Type
var values []*Value
for i, valExpr := range parserExpr.(sqlparser.ValTuple) {
sqlVal, ok := valExpr.(*sqlparser.SQLVal)
if !ok {
return nil, errParseInvalidTypeParam(fmt.Errorf("value %v in Tuple should be primitive value", i+1))
}
val, err := NewValue(sqlVal)
if err != nil {
return nil, err
}
if i == 0 {
valueType = val.Type()
} else if valueType != val.Type() {
return nil, errParseInvalidTypeParam(fmt.Errorf("mixed value type is not allowed in Tuple"))
}
values = append(values, val)
}
return newValueExpr(NewArray(values)), nil
}
return nil, nil
}
func isExprToComparisonExpr(parserExpr *sqlparser.IsExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func rangeCondToComparisonFunc(parserExpr *sqlparser.RangeCond, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
fromExpr, err := newExpr(parserExpr.From, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
toExpr, err := newExpr(parserExpr.To, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, fromExpr, toExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !fromExpr.Type().isBase() || !toExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toComparisonExpr(parserExpr *sqlparser.ComparisonExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, rightExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toArithExpr(parserExpr *sqlparser.BinaryExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newArithExpr(ArithOperator(parserExpr.Operator), leftExpr, rightExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toFuncExpr(parserExpr *sqlparser.FuncExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
funcName := strings.ToUpper(parserExpr.Name.String())
if !isSelectExpr && isAggregateFuncName(funcName) {
return nil, errUnsupportedSQLOperation(fmt.Errorf("%v() must be used in select expression", funcName))
}
funcs, aggregatedExprFound, err := newSelectExprs(parserExpr.Exprs, tableAlias)
if err != nil {
return nil, err
}
if aggregatedExprFound {
return nil, errIncorrectSQLFunctionArgumentType(fmt.Errorf("%v(): aggregated expression must not be used as argument", funcName))
}
return newFuncExpr(FuncName(funcName), funcs...)
}
func toAndExpr(parserExpr *sqlparser.AndExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newAndExpr(leftExpr, rightExpr)
if err != nil {
return nil, err
}
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toOrExpr(parserExpr *sqlparser.OrExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newOrExpr(leftExpr, rightExpr)
if err != nil {
return nil, err
}
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toNotExpr(parserExpr *sqlparser.NotExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
rightExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newNotExpr(rightExpr)
if err != nil {
return nil, err
}
if rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func newExpr(parserExpr sqlparser.Expr, tableAlias string, isSelectExpr bool) (Expr, error) {
f, err := newLiteralExpr(parserExpr, tableAlias)
if err != nil {
return nil, err
}
if f != nil {
return f, nil
}
switch parserExpr.(type) {
case *sqlparser.ParenExpr:
return newExpr(parserExpr.(*sqlparser.ParenExpr).Expr, tableAlias, isSelectExpr)
case *sqlparser.IsExpr:
return isExprToComparisonExpr(parserExpr.(*sqlparser.IsExpr), tableAlias, isSelectExpr)
case *sqlparser.RangeCond:
return rangeCondToComparisonFunc(parserExpr.(*sqlparser.RangeCond), tableAlias, isSelectExpr)
case *sqlparser.ComparisonExpr:
return toComparisonExpr(parserExpr.(*sqlparser.ComparisonExpr), tableAlias, isSelectExpr)
case *sqlparser.BinaryExpr:
return toArithExpr(parserExpr.(*sqlparser.BinaryExpr), tableAlias, isSelectExpr)
case *sqlparser.FuncExpr:
return toFuncExpr(parserExpr.(*sqlparser.FuncExpr), tableAlias, isSelectExpr)
case *sqlparser.AndExpr:
return toAndExpr(parserExpr.(*sqlparser.AndExpr), tableAlias, isSelectExpr)
case *sqlparser.OrExpr:
return toOrExpr(parserExpr.(*sqlparser.OrExpr), tableAlias, isSelectExpr)
case *sqlparser.NotExpr:
return toNotExpr(parserExpr.(*sqlparser.NotExpr), tableAlias, isSelectExpr)
}
return nil, errParseUnsupportedSyntax(fmt.Errorf("unknown expression type %T; %v", parserExpr, parserExpr))
}
func newSelectExprs(parserSelectExprs []sqlparser.SelectExpr, tableAlias string) ([]Expr, bool, error) {
var funcs []Expr
starExprFound := false
aggregatedExprFound := false
for _, selectExpr := range parserSelectExprs {
switch selectExpr.(type) {
case *sqlparser.AliasedExpr:
if starExprFound {
return nil, false, errParseAsteriskIsNotAloneInSelectList(nil)
}
aliasedExpr := selectExpr.(*sqlparser.AliasedExpr)
f, err := newExpr(aliasedExpr.Expr, tableAlias, true)
if err != nil {
return nil, false, err
}
if f.Type() == aggregateFunction {
if !aggregatedExprFound {
aggregatedExprFound = true
if len(funcs) > 0 {
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
}
}
} else if aggregatedExprFound {
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
}
alias := aliasedExpr.As.String()
if alias != "" {
f = newAliasExpr(alias, f)
}
funcs = append(funcs, f)
case *sqlparser.StarExpr:
if starExprFound {
err := fmt.Errorf("only single star expression allowed")
return nil, false, errParseInvalidContextForWildcardInSelectList(err)
}
starExprFound = true
funcs = append(funcs, newStarExpr())
default:
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("unknown select expression %v", selectExpr))
}
}
return funcs, aggregatedExprFound, nil
}
// Select - SQL Select statement.
type Select struct {
tableName string
tableAlias string
selectExprs []Expr
aggregatedExprFound bool
whereExpr Expr
}
// TableAlias - returns table alias name.
func (statement *Select) TableAlias() string {
return statement.tableAlias
}
// IsSelectAll - returns whether '*' is used in select expression or not.
func (statement *Select) IsSelectAll() bool {
if len(statement.selectExprs) == 1 {
_, ok := statement.selectExprs[0].(*starExpr)
return ok
}
return false
}
// IsAggregated - returns whether aggregated functions are used in select expression or not.
func (statement *Select) IsAggregated() bool {
return statement.aggregatedExprFound
}
// AggregateResult - returns aggregate result as record.
func (statement *Select) AggregateResult(output Record) error {
if !statement.aggregatedExprFound {
return nil
}
for i, expr := range statement.selectExprs {
value, err := expr.AggregateValue()
if err != nil {
return err
}
if value == nil {
return errInternalError(fmt.Errorf("%v returns <nil> for AggregateValue()", expr))
}
name := fmt.Sprintf("_%v", i+1)
if _, ok := expr.(*aliasExpr); ok {
name = expr.(*aliasExpr).alias
}
if err = output.Set(name, value); err != nil {
return errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
}
}
return nil
}
// Eval - evaluates this Select expressions for given record.
func (statement *Select) Eval(input, output Record) (Record, error) {
if statement.whereExpr != nil {
value, err := statement.whereExpr.Eval(input)
if err != nil {
return nil, err
}
if value == nil || value.valueType != Bool {
err = fmt.Errorf("WHERE expression %v returns invalid bool value %v", statement.whereExpr, value)
return nil, errInternalError(err)
}
if !value.BoolValue() {
return nil, nil
}
}
// Call selectExprs
for i, expr := range statement.selectExprs {
value, err := expr.Eval(input)
if err != nil {
return nil, err
}
if statement.aggregatedExprFound {
continue
}
name := fmt.Sprintf("_%v", i+1)
switch expr.(type) {
case *starExpr:
return value.recordValue(), nil
case *aliasExpr:
name = expr.(*aliasExpr).alias
case *columnExpr:
name = expr.(*columnExpr).name
}
if err = output.Set(name, value); err != nil {
return nil, errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
}
}
return output, nil
}
// NewSelect - creates new Select by parsing sql.
func NewSelect(sql string) (*Select, error) {
stmt, err := sqlparser.Parse(sql)
if err != nil {
return nil, errUnsupportedSQLStructure(err)
}
selectStmt, ok := stmt.(*sqlparser.Select)
if !ok {
return nil, errParseUnsupportedSelect(fmt.Errorf("unsupported SQL statement %v", sql))
}
var tableName, tableAlias string
for _, fromExpr := range selectStmt.From {
tableExpr := fromExpr.(*sqlparser.AliasedTableExpr)
tableName = tableExpr.Expr.(sqlparser.TableName).Name.String()
tableAlias = tableExpr.As.String()
}
selectExprs, aggregatedExprFound, err := newSelectExprs(selectStmt.SelectExprs, tableAlias)
if err != nil {
return nil, err
}
var whereExpr Expr
if selectStmt.Where != nil {
whereExpr, err = newExpr(selectStmt.Where.Expr, tableAlias, false)
if err != nil {
return nil, err
}
}
return &Select{
tableName: tableName,
tableAlias: tableAlias,
selectExprs: selectExprs,
aggregatedExprFound: aggregatedExprFound,
whereExpr: whereExpr,
}, nil
}

View File

@ -0,0 +1,202 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"fmt"
"strings"
)
var (
errBadLimitSpecified = errors.New("Limit value must be a positive integer")
)
// SelectStatement is the top level parsed and analyzed structure
type SelectStatement struct {
selectAST *Select
// Analysis result of the statement
selectQProp qProp
// Result of parsing the limit clause if one is present
// (otherwise -1)
limitValue int64
// Count of rows that have been output.
outputCount int64
}
// ParseSelectStatement - parses a select query from the given string
// and analyzes it.
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
var selectAST Select
err = SQLParser.ParseString(s, &selectAST)
if err != nil {
err = errQueryParseFailure(err)
return
}
stmt.selectAST = &selectAST
// Check the parsed limit value
stmt.limitValue, err = parseLimit(selectAST.Limit)
if err != nil {
err = errQueryAnalysisFailure(err)
return
}
// Analyze where clause
if selectAST.Where != nil {
whereQProp := selectAST.Where.analyze(&selectAST)
if whereQProp.err != nil {
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %v", whereQProp.err))
return
}
if whereQProp.isAggregation {
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
return
}
}
// Validate table name
tableString := strings.ToLower(selectAST.From.Table.String())
if !strings.HasPrefix(tableString, "s3object.") && tableString != "s3object" {
err = errBadTableName(errors.New("Table name must be s3object"))
return
}
// Analyze main select expression
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
err = stmt.selectQProp.err
if err != nil {
err = errQueryAnalysisFailure(err)
}
return
}
func parseLimit(v *LitValue) (int64, error) {
switch {
case v == nil:
return -1, nil
case v.Number == nil:
return -1, errBadLimitSpecified
default:
r := int64(*v.Number)
if r < 0 {
return -1, errBadLimitSpecified
}
return r, nil
}
}
// IsAggregated returns whether the statement involves SQL aggregation
func (e *SelectStatement) IsAggregated() bool {
return e.selectQProp.isAggregation
}
// AggregateResult - returns the aggregated result after all input
// records have been processed. Applies only to aggregation queries.
func (e *SelectStatement) AggregateResult(output Record) error {
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(nil)
if err != nil {
return err
}
output.Set(fmt.Sprintf("_%d", i+1), v)
}
return nil
}
// AggregateRow - aggregates the input record. Applies only to
// aggregation queries.
func (e *SelectStatement) AggregateRow(input Record) error {
for _, expr := range e.selectAST.Expression.Expressions {
err := expr.aggregateRow(input)
if err != nil {
return err
}
}
return nil
}
// Eval - evaluates the Select statement for the given record. It
// applies only to non-aggregation queries.
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
if whereExpr := e.selectAST.Where; whereExpr != nil {
value, err := whereExpr.evalNode(input)
if err != nil {
return nil, err
}
b, ok := value.ToBool()
if !ok {
err = fmt.Errorf("WHERE expression did not return bool")
return nil, err
}
if !b {
// Where clause is not satisfied by the row
return nil, nil
}
}
if e.selectAST.Expression.All {
// Return the input record for `SELECT * FROM
// .. WHERE ..`
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return input, nil
}
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input)
if err != nil {
return nil, err
}
// Pick output column names
if expr.As != "" {
output.Set(expr.As, v)
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
output.Set(comp, v)
} else {
output.Set(fmt.Sprintf("_%d", i+1), v)
}
}
// Update count of records output.
if e.limitValue > -1 {
e.outputCount++
}
return output, nil
}
// LimitReached - returns true if the number of records output has
// reached the value of the `LIMIT` clause.
func (e *SelectStatement) LimitReached() bool {
if e.limitValue == -1 {
return false
}
return e.outputCount >= e.limitValue
}
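The exported pieces above (ParseSelectStatement, IsAggregated, AggregateRow, AggregateResult, Eval, LimitReached) are driven by the s3select evaluation loop elsewhere in this commit. For orientation only, a minimal driver could look like the sketch below; the import path, the `records` slice and the `newOutput` factory are assumptions standing in for the package's CSV/JSON readers and writable output records, and real callers serialize the output record rather than printing it.

```go
package sketch

import (
	"fmt"

	// Import path assumed; adjust to wherever the sql package lives.
	"github.com/minio/minio/pkg/s3select/sql"
)

// runQuery sketches the intended flow: parse once, then either aggregate
// each row or evaluate and emit each projected row until LIMIT is reached.
func runQuery(query string, records []sql.Record, newOutput func() sql.Record) error {
	stmt, err := sql.ParseSelectStatement(query)
	if err != nil {
		return err
	}
	for _, rec := range records {
		if stmt.IsAggregated() {
			// Aggregation queries only accumulate here; the result is
			// produced once, after all rows have been seen.
			if err := stmt.AggregateRow(rec); err != nil {
				return err
			}
			continue
		}
		out, err := stmt.Eval(rec, newOutput())
		if err != nil {
			return err
		}
		if out != nil { // nil means the WHERE clause rejected the row
			fmt.Println(out)
		}
		if stmt.LimitReached() {
			break
		}
	}
	if stmt.IsAggregated() {
		out := newOutput()
		if err := stmt.AggregateResult(out); err != nil {
			return err
		}
		fmt.Println(out)
	}
	return nil
}
```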

View File

@ -0,0 +1,188 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"errors"
"strings"
)
var (
errMalformedEscapeSequence = errors.New("Malformed escape sequence in LIKE clause")
errInvalidTrimArg = errors.New("Trim argument is invalid - this should not happen")
errInvalidSubstringIndexLen = errors.New("Substring start index or length falls outside the string")
)
const (
percent rune = '%'
underscore rune = '_'
runeZero rune = 0
)
func evalSQLLike(text, pattern string, escape rune) (match bool, err error) {
s := []rune{}
prev := runeZero
hasLeadingPercent := false
patLen := len([]rune(pattern))
for i, r := range pattern {
if i > 0 && prev == escape {
switch r {
case percent, escape, underscore:
s = append(s, r)
prev = r
if r == escape {
prev = runeZero
}
default:
return false, errMalformedEscapeSequence
}
continue
}
prev = r
var ok bool
switch r {
case percent:
if len(s) == 0 {
hasLeadingPercent = true
continue
}
text, ok = matcher(text, string(s), hasLeadingPercent)
if !ok {
return false, nil
}
hasLeadingPercent = true
s = []rune{}
if i == patLen-1 {
// Last pattern character is a %, so
// we are done.
return true, nil
}
case underscore:
if len(s) == 0 {
text, ok = dropRune(text)
if !ok {
return false, nil
}
continue
}
text, ok = matcher(text, string(s), hasLeadingPercent)
if !ok {
return false, nil
}
hasLeadingPercent = false
text, ok = dropRune(text)
if !ok {
return false, nil
}
s = []rune{}
case escape:
if i == patLen-1 {
return false, errMalformedEscapeSequence
}
// Otherwise do nothing.
default:
s = append(s, r)
}
}
if hasLeadingPercent {
return strings.HasSuffix(text, string(s)), nil
}
return string(s) == text, nil
}
// matcher - finds `pat` in `text` and returns the remainder of
// `text` after the match. If leadingPercent is false, `pat` must be
// a prefix of `text`; otherwise it may occur anywhere as a substring.
func matcher(text, pat string, leadingPercent bool) (res string, match bool) {
if !leadingPercent {
res = strings.TrimPrefix(text, pat)
if len(text) == len(res) {
return "", false
}
} else {
parts := strings.SplitN(text, pat, 2)
if len(parts) == 1 {
return "", false
}
res = parts[1]
}
return res, true
}
func dropRune(text string) (res string, ok bool) {
r := []rune(text)
if len(r) == 0 {
return "", false
}
return string(r[1:]), true
}
func evalSQLSubstring(s string, startIdx, length int) (res string, err error) {
// Work in runes so that the 1-based indices refer to characters, not bytes.
rs := []rune(s)
if startIdx <= 0 || startIdx > len(rs) {
return "", errInvalidSubstringIndexLen
}
// StartIdx is 1-based in the input
startIdx--
endIdx := len(rs)
if length != -1 {
if length < 0 || startIdx+length > len(rs) {
return "", errInvalidSubstringIndexLen
}
endIdx = startIdx + length
}
return string(rs[startIdx:endIdx]), nil
}
const (
trimLeading = "LEADING"
trimTrailing = "TRAILING"
trimBoth = "BOTH"
)
func evalSQLTrim(where *string, trimChars, text string) (result string, err error) {
cutSet := " "
if trimChars != "" {
cutSet = trimChars
}
trimFunc := strings.Trim
switch {
case where == nil:
case *where == trimBoth:
case *where == trimLeading:
trimFunc = strings.TrimLeft
case *where == trimTrailing:
trimFunc = strings.TrimRight
default:
return "", errInvalidTrimArg
}
return trimFunc(text, cutSet), nil
}
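Since these helpers are unexported, the following illustrative cases are written as if they sat inside package sql (e.g. in a test file); the expected values follow from the definitions above and the snippet is a sketch, not part of the commit.

```go
package sql

import "fmt"

// illustrateStringFuncs walks through one LIKE, one SUBSTRING and one TRIM
// case using the helpers defined above.
func illustrateStringFuncs() {
	// LIKE: `\_` matches a literal underscore and `%` matches any run of
	// characters, so this escaped pattern matches "data_2019.csv".
	m, _ := evalSQLLike("data_2019.csv", `data\_%.csv`, '\\')
	fmt.Println(m) // true

	// SUBSTRING: the start index is 1-based, so (3, 4) yields "sele".
	s, _ := evalSQLSubstring("s3select", 3, 4)
	fmt.Println(s) // sele

	// TRIM: with LEADING, only trim characters at the front are removed.
	where := trimLeading
	t, _ := evalSQLTrim(&where, "x", "xxabcxx")
	fmt.Println(t) // abcxx
}
```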

View File

@ -0,0 +1,107 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"testing"
)
func TestEvalSQLLike(t *testing.T) {
dropCases := []struct {
input, resultExpected string
matchExpected bool
}{
{"", "", false},
{"a", "", true},
{"ab", "b", true},
{"தமிழ்", "மிழ்", true},
}
for i, tc := range dropCases {
res, ok := dropRune(tc.input)
if res != tc.resultExpected || ok != tc.matchExpected {
t.Errorf("DropRune Case %d failed", i)
}
}
matcherCases := []struct {
iText, iPat string
iHasLeadingPercent bool
resultExpected string
matchExpected bool
}{
{"abcd", "bcd", false, "", false},
{"abcd", "bcd", true, "", true},
{"abcd", "abcd", false, "", true},
{"abcd", "abcd", true, "", true},
{"abcd", "ab", false, "cd", true},
{"abcd", "ab", true, "cd", true},
{"abcd", "bc", false, "", false},
{"abcd", "bc", true, "d", true},
}
for i, tc := range matcherCases {
res, ok := matcher(tc.iText, tc.iPat, tc.iHasLeadingPercent)
if res != tc.resultExpected || ok != tc.matchExpected {
t.Errorf("Matcher Case %d failed", i)
}
}
evalCases := []struct {
iText, iPat string
iEsc rune
matchExpected bool
errExpected error
}{
{"abcd", "abc", runeZero, false, nil},
{"abcd", "abcd", runeZero, true, nil},
{"abcd", "abc_", runeZero, true, nil},
{"abcd", "_bdd", runeZero, false, nil},
{"abcd", "_b_d", runeZero, true, nil},
{"abcd", "____", runeZero, true, nil},
{"abcd", "____%", runeZero, true, nil},
{"abcd", "%____", runeZero, true, nil},
{"abcd", "%__", runeZero, true, nil},
{"abcd", "____", runeZero, true, nil},
{"", "_", runeZero, false, nil},
{"", "%", runeZero, true, nil},
{"abcd", "%%%%%", runeZero, true, nil},
{"abcd", "_____", runeZero, false, nil},
{"abcd", "%%%%%", runeZero, true, nil},
{"a%%d", `a\%\%d`, '\\', true, nil},
{"a%%d", `a\%d`, '\\', false, nil},
{`a%%\d`, `a\%\%\\d`, '\\', true, nil},
{`a%%\`, `a\%\%\\`, '\\', true, nil},
{`a%__%\`, `a\%\_\_\%\\`, '\\', true, nil},
{`a%__%\`, `a\%\_\_\%_`, '\\', true, nil},
{`a%__%\`, `a\%\_\__`, '\\', false, nil},
{`a%__%\`, `a\%\_\_%`, '\\', true, nil},
{`a%__%\`, `a?%?_?_?%\`, '?', true, nil},
}
for i, tc := range evalCases {
// fmt.Println("Case:", i)
res, err := evalSQLLike(tc.iText, tc.iPat, tc.iEsc)
if res != tc.matchExpected || err != tc.errExpected {
t.Errorf("Eval Case %d failed: %v %v", i, res, err)
}
}
}

View File

@ -1,118 +0,0 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
// Type - value type.
type Type string
const (
// Null - represents NULL value type.
Null Type = "null"
// Bool - represents boolean value type.
Bool Type = "bool"
// Int - represents integer value type.
Int Type = "int"
// Float - represents floating point value type.
Float Type = "float"
// String - represents string value type.
String Type = "string"
// Timestamp - represents time value type.
Timestamp Type = "timestamp"
// Array - represents array of values where each value type is one of above.
Array Type = "array"
column Type = "column"
record Type = "record"
function Type = "function"
aggregateFunction Type = "aggregatefunction"
arithmeticFunction Type = "arithmeticfunction"
comparisonFunction Type = "comparisonfunction"
logicalFunction Type = "logicalfunction"
// Integer Type = "integer" // Same as Int
// Decimal Type = "decimal" // Same as Float
// Numeric Type = "numeric" // Same as Float
)
func (t Type) isBase() bool {
switch t {
case Null, Bool, Int, Float, String, Timestamp:
return true
}
return false
}
func (t Type) isBaseKind() bool {
switch t {
case Null, Bool, Int, Float, String, Timestamp, column:
return true
}
return false
}
func (t Type) isNumber() bool {
switch t {
case Int, Float:
return true
}
return false
}
func (t Type) isNumberKind() bool {
switch t {
case Int, Float, column:
return true
}
return false
}
func (t Type) isIntKind() bool {
switch t {
case Int, column:
return true
}
return false
}
func (t Type) isBoolKind() bool {
switch t {
case Bool, column:
return true
}
return false
}
func (t Type) isStringKind() bool {
switch t {
case String, column:
return true
}
return false
}

87
pkg/s3select/sql/utils.go Normal file
View File

@ -0,0 +1,87 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
)
// String functions
// String - returns the JSONPath representation
func (e *JSONPath) String() string {
parts := make([]string, len(e.PathExpr)+1)
parts[0] = e.BaseKey.String()
for i, pe := range e.PathExpr {
parts[i+1] = pe.String()
}
return strings.Join(parts, "")
}
func (e *JSONPathElement) String() string {
switch {
case e.Key != nil:
return e.Key.String()
case e.Index != nil:
return fmt.Sprintf("[%d]", *e.Index)
case e.ObjectWildcard:
return ".*"
case e.ArrayWildcard:
return "[*]"
}
return ""
}
// Removes double quotes in quoted identifiers
func (i *Identifier) String() string {
if i.Unquoted != nil {
return *i.Unquoted
}
return string(*i.Quoted)
}
func (o *ObjectKey) String() string {
if o.Lit != nil {
return fmt.Sprintf("['%s']", string(*o.Lit))
}
return fmt.Sprintf(".%s", o.ID.String())
}
// getLastKeypathComponent checks if the given expression is a path
// expression, and if so extracts the last dot separated component of
// the path. Otherwise it returns false.
func getLastKeypathComponent(e *Expression) (string, bool) {
if len(e.And) > 1 ||
len(e.And[0].Condition) > 1 ||
e.And[0].Condition[0].Not != nil ||
e.And[0].Condition[0].Operand.ConditionRHS != nil {
return "", false
}
operand := e.And[0].Condition[0].Operand.Operand
if operand.Right != nil ||
operand.Left.Right != nil ||
operand.Left.Left.Negated != nil ||
operand.Left.Left.Primary.JPathExpr == nil {
return "", false
}
keypath := operand.Left.Left.Primary.JPathExpr.String()
ps := strings.Split(keypath, ".")
return ps[len(ps)-1], true
}
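For orientation only: the column-naming rule that Eval applies to bare path expressions (use the last dot-separated component) can be shown in isolation. The helper below is hypothetical and mirrors just the final strings.Split step of getLastKeypathComponent.

```go
package main

import (
	"fmt"
	"strings"
)

// lastComponent mirrors the tail of getLastKeypathComponent: the output
// column for a bare keypath is its last dot-separated component.
func lastComponent(keypath string) string {
	ps := strings.Split(keypath, ".")
	return ps[len(ps)-1]
}

func main() {
	fmt.Println(lastComponent("s3object.name.first")) // first
	fmt.Println(lastComponent("s3object.age"))        // age
}
```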

View File

@ -17,220 +17,710 @@
package sql
import (
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
"github.com/xwb1989/sqlparser"
)
// Value - represents any primitive value of bool, int, float, string and time.
var (
errArithMismatchedTypes = errors.New("cannot perform arithmetic on mismatched types")
errArithInvalidOperator = errors.New("invalid arithmetic operator")
errArithDivideByZero = errors.New("cannot divide by 0")
errCmpMismatchedTypes = errors.New("cannot compare values of different types")
errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments")
)
// vType represents the concrete type of a `Value`
type vType int
// Valid values for Type
const (
typeNull vType = iota + 1
typeBool
typeString
// 64-bit signed integer
typeInt
// 64-bit floating point
typeFloat
// This type refers to untyped values, e.g. as read from CSV
typeBytes
)
// Value represents a value of restricted type reduced from an
// expression represented by an ASTNode. Only one of the fields is
// non-nil.
//
// In cases where we are fetching data from a data source (like csv),
// the type may not be determined yet. In these cases, a byte-slice is
// used.
type Value struct {
value interface{}
valueType Type
value interface{}
vType vType
}
// String - represents value as string.
func (value *Value) String() string {
if value.value == nil {
if value.valueType == Null {
return "NULL"
}
return "<nil>"
// GetTypeString returns a string representation for vType
func (v *Value) GetTypeString() string {
switch v.vType {
case typeNull:
return "NULL"
case typeBool:
return "BOOL"
case typeString:
return "STRING"
case typeInt:
return "INT"
case typeFloat:
return "FLOAT"
case typeBytes:
return "BYTES"
}
switch value.valueType {
case String:
return fmt.Sprintf("'%v'", value.value)
case Array:
var valueStrings []string
for _, v := range value.value.([]*Value) {
valueStrings = append(valueStrings, fmt.Sprintf("%v", v))
}
return fmt.Sprintf("(%v)", strings.Join(valueStrings, ","))
}
return fmt.Sprintf("%v", value.value)
return "--"
}
// CSVString - encodes to CSV string.
func (value *Value) CSVString() string {
if value.valueType == Null {
// Repr returns a string representation of value.
func (v *Value) Repr() string {
switch v.vType {
case typeNull:
return ":NULL"
case typeBool, typeInt, typeFloat:
return fmt.Sprintf("%v:%s", v.value, v.GetTypeString())
case typeString:
return fmt.Sprintf("\"%s\":%s", v.value.(string), v.GetTypeString())
case typeBytes:
return fmt.Sprintf("\"%s\":BYTES", string(v.value.([]byte)))
default:
return fmt.Sprintf("%v:INVALID", v.value)
}
}
// FromFloat creates a Value from a number
func FromFloat(f float64) *Value {
return &Value{value: f, vType: typeFloat}
}
// FromInt creates a Value from an int
func FromInt(f int64) *Value {
return &Value{value: f, vType: typeInt}
}
// FromString creates a Value from a string
func FromString(str string) *Value {
return &Value{value: str, vType: typeString}
}
// FromBool creates a Value from a bool
func FromBool(b bool) *Value {
return &Value{value: b, vType: typeBool}
}
// FromNull creates a Value with Null value
func FromNull() *Value {
return &Value{vType: typeNull}
}
// FromBytes creates a Value from a []byte
func FromBytes(b []byte) *Value {
return &Value{value: b, vType: typeBytes}
}
// ToFloat works for int and float values
func (v *Value) ToFloat() (val float64, ok bool) {
switch v.vType {
case typeFloat:
val, ok = v.value.(float64)
case typeInt:
var i int64
i, ok = v.value.(int64)
val = float64(i)
default:
}
return
}
// ToInt converts value to int.
func (v *Value) ToInt() (val int64, ok bool) {
switch v.vType {
case typeInt:
val, ok = v.value.(int64)
default:
}
return
}
// ToString converts value to string.
func (v *Value) ToString() (val string, ok bool) {
switch v.vType {
case typeString:
val, ok = v.value.(string)
default:
}
return
}
// ToBool returns the bool value; second return value refers to if the bool
// conversion succeeded.
func (v *Value) ToBool() (val bool, ok bool) {
switch v.vType {
case typeBool:
return v.value.(bool), true
}
return false, false
}
// ToBytes converts Value to byte-slice.
func (v *Value) ToBytes() ([]byte, bool) {
switch v.vType {
case typeBytes:
return v.value.([]byte), true
}
return nil, false
}
// IsNull - checks if value is missing.
func (v *Value) IsNull() bool {
return v.vType == typeNull
}
func (v *Value) isNumeric() bool {
return v.vType == typeInt || v.vType == typeFloat
}
// setters used internally to mutate values
func (v *Value) setInt(i int64) {
v.vType = typeInt
v.value = i
}
func (v *Value) setFloat(f float64) {
v.vType = typeFloat
v.value = f
}
func (v *Value) setString(s string) {
v.vType = typeString
v.value = s
}
func (v *Value) setBool(b bool) {
v.vType = typeBool
v.value = b
}
// CSVString - convert to string for CSV serialization
func (v *Value) CSVString() string {
switch v.vType {
case typeNull:
return ""
case typeBool:
return fmt.Sprintf("%v", v.value.(bool))
case typeString:
return fmt.Sprintf("%s", v.value.(string))
case typeInt:
return fmt.Sprintf("%v", v.value.(int64))
case typeFloat:
return fmt.Sprintf("%v", v.value.(float64))
case typeBytes:
return fmt.Sprintf("%v", string(v.value.([]byte)))
default:
return "CSV serialization not implemented for this type"
}
}
// floatToValue converts a float into int representation if needed.
func floatToValue(f float64) *Value {
intPart, fracPart := math.Modf(f)
if fracPart == 0 {
return FromInt(int64(intPart))
}
return FromFloat(f)
}
// Value comparison functions: we do not expose them outside the
// module. Ordering operators "<", ">", ">=", "<=" work on strings and
// numbers. Equality operators "=", "!=" work on strings,
// numbers and booleans.
// Supported comparison operators
const (
opLt = "<"
opLte = "<="
opGt = ">"
opGte = ">="
opEq = "="
opIneq = "!="
)
// When numeric types are compared, type promotion may happen. If
// values do not have types (e.g. when reading from CSV), automatic
// type conversion is attempted before comparing: first as an integer,
// then as a float, then as a boolean, and finally falling back to
// string.
func (v *Value) compareOp(op string, a *Value) (res bool, err error) {
if !isValidComparisonOperator(op) {
return false, errArithInvalidOperator
}
return fmt.Sprintf("%v", value.value)
// Check if type conversion/inference is needed - it is needed
// if the Value is a byte-slice.
err = inferTypesForCmp(v, a)
if err != nil {
return false, err
}
isNumeric := v.isNumeric() && a.isNumeric()
if isNumeric {
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
if ok1i && ok2i {
return intCompare(op, intV, intA), nil
}
// If both values are numeric, then at least one is
// float since we got here, so we convert.
flV, _ := v.ToFloat()
flA, _ := a.ToFloat()
return floatCompare(op, flV, flA), nil
}
strV, ok1s := v.ToString()
strA, ok2s := a.ToString()
if ok1s && ok2s {
return stringCompare(op, strV, strA), nil
}
boolV, ok1b := v.ToBool()
boolA, ok2b := a.ToBool()
if ok1b && ok2b {
return boolCompare(op, boolV, boolA)
}
return false, errCmpMismatchedTypes
}
// MarshalJSON - encodes to JSON data.
func (value *Value) MarshalJSON() ([]byte, error) {
return json.Marshal(value.value)
func inferTypesForCmp(a *Value, b *Value) error {
_, okA := a.ToBytes()
_, okB := b.ToBytes()
switch {
case !okA && !okB:
// Both Values already have types
return nil
case okA && okB:
// Both Values are untyped so try the types in order:
// int, float, bool, string
// Check for numeric inference
iA, okAi := a.bytesToInt()
iB, okBi := b.bytesToInt()
if okAi && okBi {
a.setInt(iA)
b.setInt(iB)
return nil
}
fA, okAf := a.bytesToFloat()
fB, okBf := b.bytesToFloat()
if okAf && okBf {
a.setFloat(fA)
b.setFloat(fB)
return nil
}
// Check if they are an int and float combination.
if okAi && okBf {
a.setInt(iA)
b.setFloat(fB)
return nil
}
if okBi && okAf {
a.setFloat(fA)
b.setInt(iB)
return nil
}
// Not numeric types at this point.
// Check for bool inference
bA, okAb := a.bytesToBool()
bB, okBb := b.bytesToBool()
if okAb && okBb {
a.setBool(bA)
b.setBool(bB)
return nil
}
// Fallback to string
sA := a.bytesToString()
sB := b.bytesToString()
a.setString(sA)
b.setString(sB)
return nil
case okA && !okB:
// Here `a` is untyped, but `b` has a fixed
// type.
switch b.vType {
case typeString:
s := a.bytesToString()
a.setString(s)
case typeInt, typeFloat:
if iA, ok := a.bytesToInt(); ok {
a.setInt(iA)
} else if fA, ok := a.bytesToFloat(); ok {
a.setFloat(fA)
} else {
return fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
}
case typeBool:
if bA, ok := a.bytesToBool(); ok {
a.setBool(bA)
} else {
return fmt.Errorf("Could not convert %s to a boolean", string(a.value.([]byte)))
}
default:
return errCmpMismatchedTypes
}
return nil
case !okA && okB:
// swap arguments to avoid repeating code
return inferTypesForCmp(b, a)
default:
// Does not happen
return nil
}
}
// NullValue - returns underlying null value. It panics if value is not null type.
func (value *Value) NullValue() *struct{} {
if value.valueType == Null {
// Value arithmetic functions: we do not expose them outside the
// module. All arithmetic works only on numeric values with automatic
// promotion to the "larger" type that can represent the value. TODO:
// Add support for large number arithmetic.
// Supported arithmetic operators
const (
opPlus = "+"
opMinus = "-"
opDivide = "/"
opMultiply = "*"
opModulo = "%"
)
// For arithmetic operations, if both values are numeric then the
// operation shall succeed. If the types are unknown automatic type
// conversion to a number is attempted.
func (v *Value) arithOp(op string, a *Value) error {
err := inferTypeForArithOp(v)
if err != nil {
return err
}
err = inferTypeForArithOp(a)
if err != nil {
return err
}
if !v.isNumeric() || !a.isNumeric() {
return errInvalidDataType(errArithMismatchedTypes)
}
if !isValidArithOperator(op) {
return errInvalidDataType(errArithInvalidOperator)
}
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
switch {
case ok1i && ok2i:
res, err := intArithOp(op, intV, intA)
v.setInt(res)
return err
default:
// Convert arguments to float
flV, _ := v.ToFloat()
flA, _ := a.ToFloat()
res, err := floatArithOp(op, flV, flA)
v.setFloat(res)
return err
}
}
func inferTypeForArithOp(a *Value) error {
if _, ok := a.ToBytes(); !ok {
return nil
}
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
}
// BoolValue - returns underlying bool value. It panics if value is not Bool type.
func (value *Value) BoolValue() bool {
if value.valueType == Bool {
return value.value.(bool)
if i, ok := a.bytesToInt(); ok {
a.setInt(i)
return nil
}
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
}
// IntValue - returns underlying int value. It panics if value is not Int type.
func (value *Value) IntValue() int64 {
if value.valueType == Int {
return value.value.(int64)
if f, ok := a.bytesToFloat(); ok {
a.setFloat(f)
return nil
}
panic(fmt.Sprintf("requested int value but found %T type", value.value))
err := fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
return errInvalidDataType(err)
}
// FloatValue - returns underlying int/float value as float64. It panics if value is not Int/Float type.
func (value *Value) FloatValue() float64 {
switch value.valueType {
case Int:
return float64(value.value.(int64))
case Float:
return value.value.(float64)
// All the bytesTo* functions defined below assume the value is a byte-slice.
// Converts untyped value into int. The bool return implies success -
// it returns false only if there is a conversion failure.
func (v *Value) bytesToInt() (int64, bool) {
bytes, _ := v.ToBytes()
i, err := strconv.ParseInt(string(bytes), 10, 64)
return i, err == nil
}
// Converts untyped value into float. The bool return implies success
// - it returns false only if there is a conversion failure.
func (v *Value) bytesToFloat() (float64, bool) {
bytes, _ := v.ToBytes()
i, err := strconv.ParseFloat(string(bytes), 64)
return i, err == nil
}
// Converts untyped value into bool. The second bool return implies
// success - it returns false in case of a conversion failure.
func (v *Value) bytesToBool() (val bool, ok bool) {
bytes, _ := v.ToBytes()
ok = true
switch strings.ToLower(string(bytes)) {
case "t", "true":
val = true
case "f", "false":
val = false
default:
ok = false
}
return val, ok
}
// bytesToString - never fails
func (v *Value) bytesToString() string {
bytes, _ := v.ToBytes()
return string(bytes)
}
// Calculates minimum or maximum of v and a and assigns the result to
// v - it works only on numeric arguments, where `v` is already
// assumed to be numeric. Attempts conversion to numeric type for `a`
// (first int, then float) only if the underlying values do not have a
// type.
func (v *Value) minmax(a *Value, isMax, isFirstRow bool) error {
err := inferTypeForArithOp(a)
if err != nil {
return err
}
panic(fmt.Sprintf("requested float value but found %T type", value.value))
}
// StringValue - returns underlying string value. It panics if value is not String type.
func (value *Value) StringValue() string {
if value.valueType == String {
return value.value.(string)
if !a.isNumeric() {
return errArithMismatchedTypes
}
panic(fmt.Sprintf("requested string value but found %T type", value.value))
}
// TimeValue - returns underlying time value. It panics if value is not Timestamp type.
func (value *Value) TimeValue() time.Time {
if value.valueType == Timestamp {
return value.value.(time.Time)
}
panic(fmt.Sprintf("requested time value but found %T type", value.value))
}
// ArrayValue - returns underlying value array. It panics if value is not Array type.
func (value *Value) ArrayValue() []*Value {
if value.valueType == Array {
return value.value.([]*Value)
}
panic(fmt.Sprintf("requested array value but found %T type", value.value))
}
func (value *Value) recordValue() Record {
if value.valueType == record {
return value.value.(Record)
}
panic(fmt.Sprintf("requested record value but found %T type", value.value))
}
// Type - returns value type.
func (value *Value) Type() Type {
return value.valueType
}
// Value - returns underneath value interface.
func (value *Value) Value() interface{} {
return value.value
}
// NewNull - creates new null value.
func NewNull() *Value {
return &Value{nil, Null}
}
// NewBool - creates new Bool value of b.
func NewBool(b bool) *Value {
return &Value{b, Bool}
}
// NewInt - creates new Int value of i.
func NewInt(i int64) *Value {
return &Value{i, Int}
}
// NewFloat - creates new Float value of f.
func NewFloat(f float64) *Value {
return &Value{f, Float}
}
// NewString - creates new Sring value of s.
func NewString(s string) *Value {
return &Value{s, String}
}
// NewTime - creates new Time value of t.
func NewTime(t time.Time) *Value {
return &Value{t, Timestamp}
}
// NewArray - creates new Array value of values.
func NewArray(values []*Value) *Value {
return &Value{values, Array}
}
func newRecordValue(r Record) *Value {
return &Value{r, record}
}
// NewValue - creates new Value from SQLVal v.
func NewValue(v *sqlparser.SQLVal) (*Value, error) {
switch v.Type {
case sqlparser.StrVal:
return NewString(string(v.Val)), nil
case sqlparser.IntVal:
i64, err := strconv.ParseInt(string(v.Val), 10, 64)
if err != nil {
return nil, err
// In case of first row, set v to a.
if isFirstRow {
intA, okI := a.ToInt()
if okI {
v.setInt(intA)
return nil
}
return NewInt(i64), nil
case sqlparser.FloatVal:
f64, err := strconv.ParseFloat(string(v.Val), 64)
if err != nil {
return nil, err
}
return NewFloat(f64), nil
case sqlparser.HexNum: // represented as 0xDD
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.HexVal: // represented as X'0DD'
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.BitVal: // represented as B'00'
i64, err := strconv.ParseInt(string(v.Val), 2, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.ValArg:
// FIXME: the format is unknown and not sure how to handle it.
floatA, _ := a.ToFloat()
v.setFloat(floatA)
return nil
}
return nil, fmt.Errorf("unknown SQL value %v; %v ", v, v.Type)
intV, ok1i := v.ToInt()
intA, ok2i := a.ToInt()
if ok1i && ok2i {
result := intV
if !isMax {
if intA < result {
result = intA
}
} else {
if intA > result {
result = intA
}
}
v.setInt(result)
return nil
}
floatV, _ := v.ToFloat()
floatA, _ := a.ToFloat()
var result float64
if !isMax {
result = math.Min(floatV, floatA)
} else {
result = math.Max(floatV, floatA)
}
v.setFloat(result)
return nil
}
// inferTypeAsString is used to convert untyped values to string - it
// is called when the caller requires a string context to proceed.
func inferTypeAsString(v *Value) {
b, ok := v.ToBytes()
if !ok {
return
}
v.setString(string(b))
}
func isValidComparisonOperator(op string) bool {
switch op {
case opLt:
case opLte:
case opGt:
case opGte:
case opEq:
case opIneq:
default:
return false
}
return true
}
func intCompare(op string, left, right int64) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func floatCompare(op string, left, right float64) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func stringCompare(op string, left, right string) bool {
switch op {
case opLt:
return left < right
case opLte:
return left <= right
case opGt:
return left > right
case opGte:
return left >= right
case opEq:
return left == right
case opIneq:
return left != right
}
// This case does not happen
return false
}
func boolCompare(op string, left, right bool) (bool, error) {
switch op {
case opEq:
return left == right, nil
case opIneq:
return left != right, nil
default:
return false, errCmpInvalidBoolOperator
}
}
func isValidArithOperator(op string) bool {
switch op {
case opPlus:
case opMinus:
case opDivide:
case opMultiply:
case opModulo:
default:
return false
}
return true
}
// Overflow errors are ignored.
func intArithOp(op string, left, right int64) (int64, error) {
switch op {
case opPlus:
return left + right, nil
case opMinus:
return left - right, nil
case opDivide:
if right == 0 {
return 0, errArithDivideByZero
}
return left / right, nil
case opMultiply:
return left * right, nil
case opModulo:
if right == 0 {
return 0, errArithDivideByZero
}
return left % right, nil
}
// This does not happen
return 0, nil
}
// Overflow errors are ignored.
func floatArithOp(op string, left, right float64) (float64, error) {
switch op {
case opPlus:
return left + right, nil
case opMinus:
return left - right, nil
case opDivide:
if right == 0 {
return 0, errArithDivideByZero
}
return left / right, nil
case opMultiply:
return left * right, nil
case opModulo:
if right == 0 {
return 0, errArithDivideByZero
}
return math.Mod(left, right), nil
}
// This does not happen
return 0, nil
}
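A short sketch of how untyped (byte-slice) values interact with the helpers above: compareOp and arithOp are unexported, so this is written as if it sat inside package sql, and it is illustrative rather than part of the commit.

```go
package sql

import "fmt"

// illustrateValueOps shows type inference for an untyped CSV field and
// int-to-float promotion during arithmetic.
func illustrateValueOps() {
	// "42" arrives as bytes (e.g. from CSV); inferTypesForCmp converts it
	// to an int before comparing it with a typed int value.
	a := FromBytes([]byte("42"))
	b := FromInt(7)
	gt, _ := a.compareOp(">", b)
	fmt.Println(gt) // true

	// Mixing an int with a float promotes the arithmetic to float,
	// so 3 / 2.0 evaluates to 1.5.
	x := FromInt(3)
	_ = x.arithOp("/", FromFloat(2))
	f, _ := x.ToFloat()
	fmt.Println(f) // 1.5
}
```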

19
vendor/github.com/alecthomas/participle/COPYING generated vendored Normal file
View File

@ -0,0 +1,19 @@
Copyright (C) 2017 Alec Thomas
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

345
vendor/github.com/alecthomas/participle/README.md generated vendored Normal file
View File

@ -0,0 +1,345 @@
# A dead simple parser package for Go
[![Godoc](https://godoc.org/github.com/alecthomas/participle?status.svg)](http://godoc.org/github.com/alecthomas/participle) [![CircleCI](https://img.shields.io/circleci/project/github/alecthomas/participle.svg)](https://circleci.com/gh/alecthomas/participle)
[![Go Report Card](https://goreportcard.com/badge/github.com/alecthomas/participle)](https://goreportcard.com/report/github.com/alecthomas/participle) [![Gitter chat](https://badges.gitter.im/alecthomas.png)](https://gitter.im/alecthomas/Lobby)
<!-- TOC -->
1. [Introduction](#introduction)
2. [Limitations](#limitations)
3. [Tutorial](#tutorial)
4. [Overview](#overview)
5. [Annotation syntax](#annotation-syntax)
6. [Capturing](#capturing)
7. [Streaming](#streaming)
8. [Lexing](#lexing)
9. [Options](#options)
10. [Examples](#examples)
11. [Performance](#performance)
<!-- /TOC -->
<a id="markdown-introduction" name="introduction"></a>
## Introduction
The goal of this package is to provide a simple, idiomatic and elegant way of
defining parsers in Go.
Participle's method of defining grammars should be familiar to any Go
programmer who has used the `encoding/json` package: struct field tags define
what and how input is mapped to those same fields. This is not unusual for Go
encoders, but is unusual for a parser.
<a id="markdown-limitations" name="limitations"></a>
## Limitations
Participle parsers are recursive descent. Among other things, this means that they do not support left recursion.
There is an experimental lookahead option for using precomputed lookahead
tables for disambiguation. You can enable this with the parser option
`participle.UseLookahead()`.
Left recursion must be eliminated by restructuring your grammar.
<a id="markdown-tutorial" name="tutorial"></a>
## Tutorial
A [tutorial](TUTORIAL.md) is available, walking through the creation of an .ini parser.
<a id="markdown-overview" name="overview"></a>
## Overview
A grammar is an annotated Go structure used to both define the parser grammar,
and be the AST output by the parser. As an example, following is the final INI
parser from the tutorial.
```go
type INI struct {
Properties []*Property `{ @@ }`
Sections []*Section `{ @@ }`
}
type Section struct {
Identifier string `"[" @Ident "]"`
Properties []*Property `{ @@ }`
}
type Property struct {
Key string `@Ident "="`
Value *Value `@@`
}
type Value struct {
String *string ` @String`
Number *float64 `| @Float`
}
```
> **Note:** Participle also supports named struct tags (eg. <code>Hello string &#96;parser:"@Ident"&#96;</code>).
A parser is constructed from a grammar and a lexer:
```go
parser, err := participle.Build(&INI{})
```
Once constructed, the parser is applied to input to produce an AST:
```go
ast := &INI{}
err := parser.ParseString("size = 10", ast)
// ast == &INI{
// Properties: []*Property{
// {Key: "size", Value: &Value{Number: &10}},
// },
// }
```
<a id="markdown-annotation-syntax" name="annotation-syntax"></a>
## Annotation syntax
- `@<expr>` Capture expression into the field.
- `@@` Recursively capture using the fields own type.
- `<identifier>` Match named lexer token.
- `( ... )` Group.
- `"..."` Match the literal (note that the lexer must emit tokens matching this literal exactly).
- `"...":<identifier>` Match the literal, specifying the exact lexer token type to match.
- `<expr> <expr> ...` Match expressions.
- `<expr> | <expr>` Match one of the alternatives.
The following modifiers can be used after any expression:
- `*` Expression can match zero or more times.
- `+` Expression must match one or more times.
- `?` Expression can match zero or once.
- `!` Require a non-empty match (this is useful with a sequence of optional matches eg. `("a"? "b"? "c"?)!`).
Supported but deprecated:
- `{ ... }` Match 0 or more times (**DEPRECATED** - prefer `( ... )*`).
- `[ ... ]` Optional (**DEPRECATED** - prefer `( ... )?`).
Notes:
- Each struct is a single production, with each field applied in sequence.
- `@<expr>` is the mechanism for capturing matches into the field.
- if a struct field is not keyed with "parser", the entire struct tag
will be used as the grammar fragment. This allows the grammar syntax to remain
clear and simple to maintain.
<a id="markdown-capturing" name="capturing"></a>
## Capturing
Prefixing any expression in the grammar with `@` will capture matching values
for that expression into the corresponding field.
For example:
```go
// The grammar definition.
type Grammar struct {
Hello string `@Ident`
}
// The source text to parse.
source := "world"
// After parsing, the resulting AST.
result == &Grammar{
Hello: "world",
}
```
For slice and string fields, each instance of `@` will accumulate into the
field (including repeated patterns). Accumulation into other types is not
supported.
A successful capture match into a boolean field will set the field to true.
For integer and floating point types, a successful capture will be parsed
with `strconv.ParseInt()` and `strconv.ParseBool()` respectively.
Custom control of how values are captured into fields can be achieved by a
field type implementing the `Capture` interface (`Capture(values []string)
error`).
<a id="markdown-streaming" name="streaming"></a>
## Streaming
Participle supports streaming parsing. Simply pass a channel of your grammar into
`Parse*()`. The grammar will be repeatedly parsed and sent to the channel. Note that
the `Parse*()` call will not return until parsing completes, so it should generally be
started in a goroutine.
```go
type token struct {
Str string ` @Ident`
Num int `| @Int`
}
parser, err := participle.Build(&token{})
tokens := make(chan *token, 128)
err := parser.ParseString(`hello 10 11 12 world`, tokens)
for token := range tokens {
fmt.Printf("%#v\n", token)
}
```
<a id="markdown-lexing" name="lexing"></a>
## Lexing
Participle operates on tokens and thus relies on a lexer to convert character
streams to tokens.
Three lexers are provided, varying in speed and flexibility. The fastest lexer
is based on the [text/scanner](https://golang.org/pkg/text/scanner/) package
but only allows tokens provided by that package. Next fastest is the regexp
lexer (`lexer.Regexp()`). The slowest is currently the EBNF based lexer, but it has a large potential for optimisation through code generation.
To use your own Lexer you will need to implement two interfaces:
[Definition](https://godoc.org/github.com/alecthomas/participle/lexer#Definition)
and [Lexer](https://godoc.org/github.com/alecthomas/participle/lexer#Lexer).
<a id="markdown-options" name="options"></a>
## Options
The Parser's behaviour can be configured via [Options](https://godoc.org/github.com/alecthomas/participle#Option).
<a id="markdown-examples" name="examples"></a>
## Examples
There are several [examples](https://github.com/alecthomas/participle/tree/master/_examples) included:
Example | Description
--------|---------------
[BASIC](https://github.com/alecthomas/participle/tree/master/_examples/basic) | A lexer, parser and interpreter for a [rudimentary dialect](https://caml.inria.fr/pub/docs/oreilly-book/html/book-ora058.html) of BASIC.
[EBNF](https://github.com/alecthomas/participle/tree/master/_examples/ebnf) | Parser for the form of EBNF used by Go.
[Expr](https://github.com/alecthomas/participle/tree/master/_examples/expr) | A basic mathematical expression parser and evaluator.
[GraphQL](https://github.com/alecthomas/participle/tree/master/_examples/graphql) | Lexer+parser for GraphQL schemas
[HCL](https://github.com/alecthomas/participle/tree/master/_examples/hcl) | A parser for the [HashiCorp Configuration Language](https://github.com/hashicorp/hcl).
[INI](https://github.com/alecthomas/participle/tree/master/_examples/ini) | An INI file parser.
[Protobuf](https://github.com/alecthomas/participle/tree/master/_examples/protobuf) | A full [Protobuf](https://developers.google.com/protocol-buffers/) version 2 and 3 parser.
[SQL](https://github.com/alecthomas/participle/tree/master/_examples/sql) | A *very* rudimentary SQL SELECT parser.
[Thrift](https://github.com/alecthomas/participle/tree/master/_examples/thrift) | A full [Thrift](https://thrift.apache.org/docs/idl) parser.
[TOML](https://github.com/alecthomas/participle/blob/master/_examples/toml/main.go) | A [TOML](https://github.com/toml-lang/toml) parser.
Included below is a full GraphQL lexer and parser:
```go
package main
import (
"os"
"github.com/alecthomas/kong"
"github.com/alecthomas/repr"
"github.com/alecthomas/participle"
"github.com/alecthomas/participle/lexer"
"github.com/alecthomas/participle/lexer/ebnf"
)
type File struct {
Entries []*Entry `{ @@ }`
}
type Entry struct {
Type *Type ` @@`
Schema *Schema `| @@`
Enum *Enum `| @@`
Scalar string `| "scalar" @Ident`
}
type Enum struct {
Name string `"enum" @Ident`
Cases []string `"{" { @Ident } "}"`
}
type Schema struct {
Fields []*Field `"schema" "{" { @@ } "}"`
}
type Type struct {
Name string `"type" @Ident`
Implements string `[ "implements" @Ident ]`
Fields []*Field `"{" { @@ } "}"`
}
type Field struct {
Name string `@Ident`
Arguments []*Argument `[ "(" [ @@ { "," @@ } ] ")" ]`
Type *TypeRef `":" @@`
Annotation string `[ "@" @Ident ]`
}
type Argument struct {
Name string `@Ident`
Type *TypeRef `":" @@`
Default *Value `[ "=" @@ ]`
}
type TypeRef struct {
Array *TypeRef `( "[" @@ "]"`
Type string ` | @Ident )`
NonNullable bool `[ @"!" ]`
}
type Value struct {
Symbol string `@Ident`
}
var (
graphQLLexer = lexer.Must(ebnf.New(`
Comment = ("#" | "//") { "\u0000"…"\uffff"-"\n" } .
Ident = (alpha | "_") { "_" | alpha | digit } .
Number = ("." | digit) {"." | digit} .
Whitespace = " " | "\t" | "\n" | "\r" .
Punct = "!"…"/" | ":"…"@" | "["…`+"\"`\""+` | "{"…"~" .
alpha = "a"…"z" | "A"…"Z" .
digit = "0"…"9" .
`))
parser = participle.MustBuild(&File{},
participle.Lexer(graphQLLexer),
participle.Elide("Comment", "Whitespace"),
)
cli struct {
Files []string `arg:"" type:"existingfile" required:"" help:"GraphQL schema files to parse."`
}
)
func main() {
ctx := kong.Parse(&cli)
for _, file := range cli.Files {
ast := &File{}
r, err := os.Open(file)
ctx.FatalIfErrorf(err)
err = parser.Parse(r, ast)
r.Close()
repr.Println(ast)
ctx.FatalIfErrorf(err)
}
}
```
<a id="markdown-performance" name="performance"></a>
## Performance
One of the included examples is a complete Thrift parser
(shell-style comments are not supported). This gives
a convenient baseline for comparing against the PEG-based
[pigeon](https://github.com/PuerkitoBio/pigeon), which is the parser used by
[go-thrift](https://github.com/samuel/go-thrift). Additionally, the pigeon
parser is generated ahead of time, while the participle parser is built at
run time.
You can run the benchmarks yourself, but here's the output on my machine:
BenchmarkParticipleThrift-4 10000 221818 ns/op 48880 B/op 1240 allocs/op
BenchmarkGoThriftParser-4 2000 804709 ns/op 170301 B/op 3086 allocs/op
On a real-life codebase of 47K lines of Thrift, Participle takes 200ms and go-thrift takes 630ms, which aligns quite closely with the benchmarks.

255
vendor/github.com/alecthomas/participle/TUTORIAL.md generated vendored Normal file
View File

@ -0,0 +1,255 @@
# Participle parser tutorial
<!-- MarkdownTOC -->
1. [Introduction](#introduction)
1. [The complete grammar](#the-complete-grammar)
1. [Root of the .ini AST \(structure, fields\)](#root-of-the-ini-ast-structure-fields)
1. [.ini properties \(named tokens, capturing, literals\)](#ini-properties-named-tokens-capturing-literals)
1. [.ini property values \(alternates, recursive structs, sequences\)](#ini-property-values-alternates-recursive-structs-sequences)
1. [Complete, but limited, .ini grammar \(top-level properties only\)](#complete-but-limited-ini-grammar-top-level-properties-only)
1. [Extending our grammar to support sections](#extending-our-grammar-to-support-sections)
1. [\(Optional\) Source positional information](#optional-source-positional-information)
1. [Parsing using our grammar](#parsing-using-our-grammar)
<!-- /MarkdownTOC -->
## Introduction
Writing a parser in Participle typically involves starting from the "root" of
the AST, annotating fields with the grammar, then recursively expanding until
it is complete. The AST is expressed via Go data types and the grammar is
expressed through struct field tags, as a form of EBNF.
The parser we're going to create for this tutorial parses .ini files
like this:
```ini
age = 21
name = "Bob Smith"
[address]
city = "Beverly Hills"
postal_code = 90210
```
## The complete grammar
I think it's useful to see the complete grammar first, to see what we're
working towards. Read on below for details.
```go
type INI struct {
Properties []*Property `@@*`
Sections []*Section `@@*`
}
type Section struct {
Identifier string `"[" @Ident "]"`
Properties []*Property `@@*`
}
type Property struct {
Key string `@Ident "="`
Value *Value `@@`
}
type Value struct {
String *string ` @String`
Number *float64 `| @Float`
}
```
## Root of the .ini AST (structure, fields)
The first step is to create a root struct for our grammar. In the case of our
.ini parser, this struct will contain a sequence of properties:
```go
type INI struct {
Properties []*Property
}
type Property struct {
}
```
## .ini properties (named tokens, capturing, literals)
Each property in an .ini file has an identifier key:
```go
type Property struct {
Key string
}
```
The default lexer tokenises Go source code, and includes an `Ident` token type
that matches identifiers. To match this token we simply use the token type
name:
```go
type Property struct {
Key string `Ident`
}
```
This will *match* identifiers, but not *capture* them into the `Key` field. To
capture input tokens into AST fields, prefix any grammar node with `@`:
```go
type Property struct {
Key string `@Ident`
}
```
In .ini files, each key is separated from its value with a literal `=`. To
match a literal, enclose the literal in double quotes:
```go
type Property struct {
Key string `@Ident "="`
}
```
> Note: literals in the grammar must match tokens from the lexer *exactly*. In
> this example if the lexer does not output `=` as a distinct token the
> grammar will not match.
## .ini property values (alternates, recursive structs, sequences)
For the purposes of our example we are only going to support quoted string
and numeric property values. As each value can be *either* a string or a float
we'll need something akin to a sum type. Go's type system cannot express this
directly, so we'll use the common approach of making each element a pointer.
The selected "case" will *not* be nil.
```go
type Value struct {
String *string
Number *float64
}
```
> Note: Participle will hydrate pointers as necessary.
To express matching a set of alternatives we use the `|` operator:
```go
type Value struct {
String *string ` @String`
Number *float64 `| @Float`
}
```
> Note: the grammar can cross fields.
Next, we'll match values and capture them into the `Property`. To recursively
capture structs use `@@` (capture self):
```go
type Property struct {
Key string `@Ident "="`
Value *Value `@@`
}
```
Now that we can parse a `Property` we need to go back to the root of the
grammar. We want to parse 0 or more properties. To do this, we use `<expr>*`.
Participle will accumulate each match into the slice until matching fails,
then move to the next node in the grammar.
```go
type INI struct {
Properties []*Property `@@*`
}
```
> Note: tokens can also be accumulated into strings, appending each match.
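As a quick illustration of that note (a hypothetical grammar fragment, not part of the .ini example), every *captured* token is appended to a `string` field, so a dotted path can be collected into a single value:

```go
// Hypothetical: parsing "a.b.c" with the default lexer yields Name == "a.b.c",
// because each captured token (@Ident and @".") is appended to the field.
type Path struct {
	Name string `@Ident ( @"." @Ident )*`
}
```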
## Complete, but limited, .ini grammar (top-level properties only)
We now have a functional, but limited, .ini parser!
```go
type INI struct {
Properties []*Property `@@*`
}
type Property struct {
Key string `@Ident "="`
Value *Value `@@`
}
type Value struct {
String *string ` @String`
Number *float64 `| @Float`
}
```
## Extending our grammar to support sections
Adding support for sections is simply a matter of utilising the constructs
we've just learnt. A section consists of a header identifier, and a sequence
of properties:
```go
type Section struct {
Identifier string `"[" @Ident "]"`
Properties []*Property `@@*`
}
```
Simple!
Now we just add a sequence of `Section`s to our root node:
```go
type INI struct {
Properties []*Property `@@*`
Sections []*Section `@@*`
}
```
And we're done!
## (Optional) Source positional information
If a grammar node includes a field with the name `Pos` and type `lexer.Position`, it will be automatically populated by positional information. eg.
```go
type Value struct {
Pos lexer.Position
String *string ` @String`
Number *float64 `| @Float`
}
```
This is useful for error reporting.
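For example (a minimal sketch, assuming the `Pos` field added above and an `ini` value parsed as shown in the next section), the position can be woven directly into your own diagnostics; `lexer.Position` formats as `file:line:column` (or `<source>:line:column` when no filename is known):

```go
// Complain about a property whose value is not a quoted string, pointing at
// the exact place it appeared in the input.
for _, prop := range ini.Properties {
	if prop.Key == "name" && prop.Value.String == nil {
		fmt.Printf("%s: name must be a quoted string\n", prop.Value.Pos)
	}
}
```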
## Parsing using our grammar
To parse with this grammar we first construct the parser (we'll use the
default lexer for now):
```go
parser, err := participle.Build(&INI{})
```
Then create a root node and parse into it with `parser.Parse{,String,Bytes}()`:
```go
ini := &INI{}
err = parser.ParseString(`
age = 21
name = "Bob Smith"
[address]
city = "Beverly Hills"
postal_code = 90210
`, ini)
```
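Putting it all together, here is a minimal, self-contained sketch (this is not the upstream `_examples/ini` program, which uses a custom lexer). One small adjustment to the grammar above: the default lexer emits integer literals such as `21` as `Int` tokens rather than `Float`, so the `Number` alternative here also accepts `@Int`:

```go
package main

import (
	"fmt"

	"github.com/alecthomas/participle"
)

type INI struct {
	Properties []*Property `@@*`
	Sections   []*Section  `@@*`
}

type Section struct {
	Identifier string      `"[" @Ident "]"`
	Properties []*Property `@@*`
}

type Property struct {
	Key   string `@Ident "="`
	Value *Value `@@`
}

type Value struct {
	String *string  ` @String`
	Number *float64 `| @Float | @Int`
}

func main() {
	parser, err := participle.Build(&INI{})
	if err != nil {
		panic(err)
	}

	ini := &INI{}
	err = parser.ParseString(`
age = 21
name = "Bob Smith"
[address]
city = "Beverly Hills"
postal_code = 90210
`, ini)
	if err != nil {
		panic(err)
	}

	// Dump the top-level properties, then the section names.
	for _, p := range ini.Properties {
		if p.Value.String != nil {
			fmt.Printf("%s = %q\n", p.Key, *p.Value.String)
		} else {
			fmt.Printf("%s = %v\n", p.Key, *p.Value.Number)
		}
	}
	for _, s := range ini.Sections {
		fmt.Printf("[%s] has %d properties\n", s.Identifier, len(s.Properties))
	}
}
```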
You can find the full example [here](_examples/ini/main.go), alongside
other examples including an SQL `SELECT` parser and a full
[Thrift](https://thrift.apache.org/) parser.

19
vendor/github.com/alecthomas/participle/api.go generated vendored Normal file
View File

@ -0,0 +1,19 @@
package participle
import (
"github.com/alecthomas/participle/lexer"
)
// Capture can be implemented by fields in order to transform captured tokens into field values.
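//
// An illustrative sketch (not part of the original documentation): a field
// type that converts the captured token into a Go bool.
//
//	type Boolean bool
//
//	func (b *Boolean) Capture(values []string) error {
//		*b = values[0] == "TRUE"
//		return nil
//	}
//
// A grammar field of type Boolean annotated with `@("TRUE" | "FALSE")` would
// then hold true or false rather than the raw token text.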
type Capture interface {
Capture(values []string) error
}
// The Parseable interface can be implemented by any element in the grammar to provide custom parsing.
type Parseable interface {
// Parse into the receiver.
//
// Should return NextMatch if no tokens matched and parsing should continue.
// Nil should be returned if parsing was successful.
Parse(lex lexer.PeekingLexer) error
}

123
vendor/github.com/alecthomas/participle/context.go generated vendored Normal file
View File

@ -0,0 +1,123 @@
package participle
import (
"reflect"
"github.com/alecthomas/participle/lexer"
)
type contextFieldSet struct {
pos lexer.Position
strct reflect.Value
field structLexerField
fieldValue []reflect.Value
}
// Context for a single parse.
type parseContext struct {
*rewinder
lookahead int
caseInsensitive map[rune]bool
apply []*contextFieldSet
}
func newParseContext(lex lexer.Lexer, lookahead int, caseInsensitive map[rune]bool) (*parseContext, error) {
rew, err := newRewinder(lex)
if err != nil {
return nil, err
}
return &parseContext{
rewinder: rew,
caseInsensitive: caseInsensitive,
lookahead: lookahead,
}, nil
}
// Defer adds a function to be applied once a branch has been picked.
func (p *parseContext) Defer(pos lexer.Position, strct reflect.Value, field structLexerField, fieldValue []reflect.Value) {
p.apply = append(p.apply, &contextFieldSet{pos, strct, field, fieldValue})
}
// Apply deferred functions.
func (p *parseContext) Apply() error {
for _, apply := range p.apply {
if err := setField(apply.pos, apply.strct, apply.field, apply.fieldValue); err != nil {
return err
}
}
p.apply = nil
return nil
}
// Branch accepts the branch as the correct branch.
func (p *parseContext) Accept(branch *parseContext) {
p.apply = append(p.apply, branch.apply...)
p.rewinder = branch.rewinder
}
// Branch starts a new lookahead branch.
func (p *parseContext) Branch() *parseContext {
branch := &parseContext{}
*branch = *p
branch.apply = nil
branch.rewinder = p.rewinder.Lookahead()
return branch
}
// Stop returns true if parsing should terminate after the given "branch" failed to match.
func (p *parseContext) Stop(branch *parseContext) bool {
if branch.cursor > p.cursor+p.lookahead {
p.Accept(branch)
return true
}
return false
}
type rewinder struct {
cursor, limit int
tokens []lexer.Token
}
func newRewinder(lex lexer.Lexer) (*rewinder, error) {
r := &rewinder{}
for {
t, err := lex.Next()
if err != nil {
return nil, err
}
if t.EOF() {
break
}
r.tokens = append(r.tokens, t)
}
return r, nil
}
func (r *rewinder) Next() (lexer.Token, error) {
if r.cursor >= len(r.tokens) {
return lexer.EOFToken(lexer.Position{}), nil
}
r.cursor++
return r.tokens[r.cursor-1], nil
}
func (r *rewinder) Peek(n int) (lexer.Token, error) {
i := r.cursor + n
if i >= len(r.tokens) {
return lexer.EOFToken(lexer.Position{}), nil
}
return r.tokens[i], nil
}
// Lookahead returns a new rewinder usable for lookahead.
func (r *rewinder) Lookahead() *rewinder {
clone := &rewinder{}
*clone = *r
clone.limit = clone.cursor
return clone
}
// Keep this lookahead rewinder.
func (r *rewinder) Keep() {
r.limit = 0
}

73
vendor/github.com/alecthomas/participle/doc.go generated vendored Normal file
View File

@ -0,0 +1,73 @@
// Package participle constructs parsers from definitions in struct tags and parses directly into
// those structs. The approach is philosophically similar to how other marshallers work in Go,
// "unmarshalling" an instance of a grammar into a struct.
//
// The supported annotation syntax is:
//
// - `@<expr>` Capture expression into the field.
// - `@@` Recursively capture using the fields own type.
// - `<identifier>` Match named lexer token.
// - `( ... )` Group.
// - `"..."` Match the literal (note that the lexer must emit tokens matching this literal exactly).
// - `"...":<identifier>` Match the literal, specifying the exact lexer token type to match.
// - `<expr> <expr> ...` Match expressions.
// - `<expr> | <expr>` Match one of the alternatives.
//
// The following modifiers can be used after any expression:
//
// - `*` Expression can match zero or more times.
// - `+` Expression must match one or more times.
// - `?` Expression can match zero or once.
// - `!` Require a non-empty match (this is useful with a sequence of optional matches eg. `("a"? "b"? "c"?)!`).
//
// Supported but deprecated:
//
// - `{ ... }` Match 0 or more times (**DEPRECATED** - prefer `( ... )*`).
// - `[ ... ]` Optional (**DEPRECATED** - prefer `( ... )?`).
//
// Here's an example of an EBNF grammar.
//
// type Group struct {
// Expression *Expression `"(" @@ ")"`
// }
//
// type Option struct {
// Expression *Expression `"[" @@ "]"`
// }
//
// type Repetition struct {
// Expression *Expression `"{" @@ "}"`
// }
//
// type Literal struct {
// Start string `@String` // lexer.Lexer token "String"
// End string `("…" @String)?`
// }
//
// type Term struct {
// Name string ` @Ident`
// Literal *Literal `| @@`
// Group *Group `| @@`
// Option *Option `| @@`
// Repetition *Repetition `| @@`
// }
//
// type Sequence struct {
// Terms []*Term `@@+`
// }
//
// type Expression struct {
// Alternatives []*Sequence `@@ ("|" @@)*`
// }
//
// type Expressions []*Expression
//
// type Production struct {
// Name string `@Ident "="`
// Expressions Expressions `@@+ "."`
// }
//
// type EBNF struct {
// Productions []*Production `@@*`
// }
package participle

7
vendor/github.com/alecthomas/participle/go.mod generated vendored Normal file
View File

@ -0,0 +1,7 @@
module github.com/alecthomas/participle
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
)

6
vendor/github.com/alecthomas/participle/go.sum generated vendored Normal file
View File

@ -0,0 +1,6 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=

324
vendor/github.com/alecthomas/participle/grammar.go generated vendored Normal file
View File

@ -0,0 +1,324 @@
package participle
import (
"fmt"
"reflect"
"text/scanner"
"github.com/alecthomas/participle/lexer"
)
type generatorContext struct {
lexer.Definition
typeNodes map[reflect.Type]node
symbolsToIDs map[rune]string
}
func newGeneratorContext(lex lexer.Definition) *generatorContext {
return &generatorContext{
Definition: lex,
typeNodes: map[reflect.Type]node{},
symbolsToIDs: lexer.SymbolsByRune(lex),
}
}
// Takes a type and builds a tree of nodes out of it.
func (g *generatorContext) parseType(t reflect.Type) (_ node, returnedError error) {
rt := t
t = indirectType(t)
if n, ok := g.typeNodes[t]; ok {
return n, nil
}
if rt.Implements(parseableType) {
return &parseable{rt.Elem()}, nil
}
if reflect.PtrTo(rt).Implements(parseableType) {
return &parseable{rt}, nil
}
switch t.Kind() {
case reflect.Slice, reflect.Ptr:
t = indirectType(t.Elem())
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("expected a struct but got %T", t)
}
fallthrough
case reflect.Struct:
slexer, err := lexStruct(t)
if err != nil {
return nil, err
}
out := &strct{typ: t}
g.typeNodes[t] = out // Ensure we avoid infinite recursion.
if slexer.NumField() == 0 {
return nil, fmt.Errorf("can not parse into empty struct %s", t)
}
defer decorate(&returnedError, func() string { return slexer.Field().Name })
e, err := g.parseDisjunction(slexer)
if err != nil {
return nil, err
}
if e == nil {
return nil, fmt.Errorf("no grammar found in %s", t)
}
if token, _ := slexer.Peek(); !token.EOF() {
return nil, fmt.Errorf("unexpected input %q", token.Value)
}
out.expr = e
return out, nil
}
return nil, fmt.Errorf("%s should be a struct or should implement the Parseable interface", t)
}
func (g *generatorContext) parseDisjunction(slexer *structLexer) (node, error) {
out := &disjunction{}
for {
n, err := g.parseSequence(slexer)
if err != nil {
return nil, err
}
out.nodes = append(out.nodes, n)
if token, _ := slexer.Peek(); token.Type != '|' {
break
}
_, err = slexer.Next() // |
if err != nil {
return nil, err
}
}
if len(out.nodes) == 1 {
return out.nodes[0], nil
}
return out, nil
}
func (g *generatorContext) parseSequence(slexer *structLexer) (node, error) {
head := &sequence{}
cursor := head
loop:
for {
if token, err := slexer.Peek(); err != nil {
return nil, err
} else if token.Type == lexer.EOF {
break loop
}
term, err := g.parseTerm(slexer)
if err != nil {
return nil, err
}
if term == nil {
break loop
}
if cursor.node == nil {
cursor.head = true
cursor.node = term
} else {
cursor.next = &sequence{node: term}
cursor = cursor.next
}
}
if head.node == nil {
return nil, nil
}
if head.next == nil {
return head.node, nil
}
return head, nil
}
func (g *generatorContext) parseTermNoModifiers(slexer *structLexer) (node, error) {
t, err := slexer.Peek()
if err != nil {
return nil, err
}
var out node
switch t.Type {
case '@':
out, err = g.parseCapture(slexer)
case scanner.String, scanner.RawString, scanner.Char:
out, err = g.parseLiteral(slexer)
case '[':
return g.parseOptional(slexer)
case '{':
return g.parseRepetition(slexer)
case '(':
out, err = g.parseGroup(slexer)
case scanner.Ident:
out, err = g.parseReference(slexer)
case lexer.EOF:
_, _ = slexer.Next()
return nil, nil
default:
return nil, nil
}
return out, err
}
func (g *generatorContext) parseTerm(slexer *structLexer) (node, error) {
out, err := g.parseTermNoModifiers(slexer)
if err != nil {
return nil, err
}
return g.parseModifier(slexer, out)
}
// Parse modifiers: ?, *, + and/or !
func (g *generatorContext) parseModifier(slexer *structLexer, expr node) (node, error) {
out := &group{expr: expr}
t, err := slexer.Peek()
if err != nil {
return nil, err
}
switch t.Type {
case '!':
out.mode = groupMatchNonEmpty
case '+':
out.mode = groupMatchOneOrMore
case '*':
out.mode = groupMatchZeroOrMore
case '?':
out.mode = groupMatchZeroOrOne
default:
return expr, nil
}
_, _ = slexer.Next()
return out, nil
}
// @<expression> captures <expression> into the current field.
func (g *generatorContext) parseCapture(slexer *structLexer) (node, error) {
_, _ = slexer.Next()
token, err := slexer.Peek()
if err != nil {
return nil, err
}
field := slexer.Field()
if token.Type == '@' {
_, _ = slexer.Next()
n, err := g.parseType(field.Type)
if err != nil {
return nil, err
}
return &capture{field, n}, nil
}
if indirectType(field.Type).Kind() == reflect.Struct && !field.Type.Implements(captureType) {
return nil, fmt.Errorf("structs can only be parsed with @@ or by implementing the Capture interface")
}
n, err := g.parseTermNoModifiers(slexer)
if err != nil {
return nil, err
}
return &capture{field, n}, nil
}
// A reference in the form <identifier> refers to a named token from the lexer.
func (g *generatorContext) parseReference(slexer *structLexer) (node, error) { // nolint: interfacer
token, err := slexer.Next()
if err != nil {
return nil, err
}
if token.Type != scanner.Ident {
return nil, fmt.Errorf("expected identifier but got %q", token)
}
typ, ok := g.Symbols()[token.Value]
if !ok {
return nil, fmt.Errorf("unknown token type %q", token)
}
return &reference{typ: typ, identifier: token.Value}, nil
}
// [ <expression> ] optionally matches <expression>.
func (g *generatorContext) parseOptional(slexer *structLexer) (node, error) {
_, _ = slexer.Next() // [
disj, err := g.parseDisjunction(slexer)
if err != nil {
return nil, err
}
n := &group{expr: disj, mode: groupMatchZeroOrOne}
next, err := slexer.Next()
if err != nil {
return nil, err
}
if next.Type != ']' {
return nil, fmt.Errorf("expected ] but got %q", next)
}
return n, nil
}
// { <expression> } matches 0 or more repetitions of <expression>.
func (g *generatorContext) parseRepetition(slexer *structLexer) (node, error) {
_, _ = slexer.Next() // {
disj, err := g.parseDisjunction(slexer)
if err != nil {
return nil, err
}
n := &group{expr: disj, mode: groupMatchZeroOrMore}
next, err := slexer.Next()
if err != nil {
return nil, err
}
if next.Type != '}' {
return nil, fmt.Errorf("expected } but got %q", next)
}
return n, nil
}
// ( <expression> ) groups a sub-expression
func (g *generatorContext) parseGroup(slexer *structLexer) (node, error) {
_, _ = slexer.Next() // (
disj, err := g.parseDisjunction(slexer)
if err != nil {
return nil, err
}
next, err := slexer.Next() // )
if err != nil {
return nil, err
}
if next.Type != ')' {
return nil, fmt.Errorf("expected ) but got %q", next)
}
return &group{expr: disj}, nil
}
// A literal string.
//
// Note that for this to match, the tokeniser must be able to produce this string. For example,
// if the tokeniser only produces individual characters but the literal is "hello" (or vice versa),
// the literal will never match.
func (g *generatorContext) parseLiteral(lex *structLexer) (node, error) { // nolint: interfacer
token, err := lex.Next()
if err != nil {
return nil, err
}
if token.Type != scanner.String && token.Type != scanner.RawString && token.Type != scanner.Char {
return nil, fmt.Errorf("expected quoted string but got %q", token)
}
s := token.Value
t := rune(-1)
token, err = lex.Peek()
if err != nil {
return nil, err
}
if token.Value == ":" && (token.Type == scanner.Char || token.Type == ':') {
_, _ = lex.Next()
token, err = lex.Next()
if err != nil {
return nil, err
}
if token.Type != scanner.Ident {
return nil, fmt.Errorf("expected identifier for literal type constraint but got %q", token)
}
var ok bool
t, ok = g.Symbols()[token.Value]
if !ok {
return nil, fmt.Errorf("unknown token type %q in literal type constraint", token)
}
}
return &literal{s: s, t: t, tt: g.symbolsToIDs[t]}, nil
}
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice {
return indirectType(t.Elem())
}
return t
}

19
vendor/github.com/alecthomas/participle/lexer/doc.go generated vendored Normal file
View File

@ -0,0 +1,19 @@
// Package lexer defines interfaces and implementations used by Participle to perform lexing.
//
// The primary interfaces are Definition and Lexer. There are three implementations of these
// interfaces:
//
// TextScannerLexer is based on text/scanner. This is the fastest, but least flexible, in that
// tokens are restricted to those supported by that package. It can scan about 5M tokens/second on a
// late 2013 15" MacBook Pro.
//
// The second lexer is constructed via the Regexp() function, mapping regexp capture groups
// to tokens. The complete input source is read into memory, so it is unsuitable for large inputs.
//
// The final lexer provided accepts a lexical grammar in EBNF. Each capitalised production is a
// lexical token supported by the resulting Lexer. This is very flexible, but a bit slower, scanning
// around 730K tokens/second on the same machine, though it is currently completely unoptimised.
// This could/should be converted to a table-based lexer.
//
// Lexer implementations must use Panic/Panicf to report errors.
package lexer

View File

@ -0,0 +1,26 @@
package lexer
import "fmt"
// Error represents an error while parsing.
type Error struct {
Message string
Pos Position
}
// Errorf creates a new Error at the given position.
func Errorf(pos Position, format string, args ...interface{}) *Error {
return &Error{
Message: fmt.Sprintf(format, args...),
Pos: pos,
}
}
// Error complies with the error interface and reports the position of an error.
func (e *Error) Error() string {
filename := e.Pos.Filename
if filename == "" {
filename = "<source>"
}
return fmt.Sprintf("%s:%d:%d: %s", filename, e.Pos.Line, e.Pos.Column, e.Message)
}

150
vendor/github.com/alecthomas/participle/lexer/lexer.go generated vendored Normal file
View File

@ -0,0 +1,150 @@
package lexer
import (
"fmt"
"io"
)
const (
// EOF represents an end of file.
EOF rune = -(iota + 1)
)
// EOFToken creates a new EOF token at the given position.
func EOFToken(pos Position) Token {
return Token{Type: EOF, Pos: pos}
}
// Definition provides the parser with metadata for a lexer.
type Definition interface {
// Lex an io.Reader.
Lex(io.Reader) (Lexer, error)
// Symbols returns a map of symbolic names to the corresponding pseudo-runes for those symbols.
// This is the same approach as used by text/scanner. For example, "EOF" might have the rune
// value of -1, "Ident" might be -2, and so on.
Symbols() map[string]rune
}
// A Lexer returns tokens from a source.
type Lexer interface {
// Next consumes and returns the next token.
Next() (Token, error)
}
// A PeekingLexer returns tokens from a source and allows peeking.
type PeekingLexer interface {
Lexer
// Peek at the next token.
Peek(n int) (Token, error)
}
// SymbolsByRune returns a map of lexer symbol names keyed by rune.
func SymbolsByRune(def Definition) map[rune]string {
out := map[rune]string{}
for s, r := range def.Symbols() {
out[r] = s
}
return out
}
// NameOfReader attempts to retrieve the filename of a reader.
func NameOfReader(r interface{}) string {
if nr, ok := r.(interface{ Name() string }); ok {
return nr.Name()
}
return ""
}
// Must takes the result of a Definition constructor call and returns the definition, but panics if
// it errors.
//
// eg.
//
// lex = lexer.Must(lexer.Build(`Symbol = "symbol" .`))
func Must(def Definition, err error) Definition {
if err != nil {
panic(err)
}
return def
}
// ConsumeAll reads all tokens from a Lexer.
func ConsumeAll(lexer Lexer) ([]Token, error) {
tokens := []Token{}
for {
token, err := lexer.Next()
if err != nil {
return nil, err
}
tokens = append(tokens, token)
if token.Type == EOF {
return tokens, nil
}
}
}
// Position of a token.
type Position struct {
Filename string
Offset int
Line int
Column int
}
func (p Position) GoString() string {
return fmt.Sprintf("Position{Filename: %q, Offset: %d, Line: %d, Column: %d}",
p.Filename, p.Offset, p.Line, p.Column)
}
func (p Position) String() string {
filename := p.Filename
if filename == "" {
filename = "<source>"
}
return fmt.Sprintf("%s:%d:%d", filename, p.Line, p.Column)
}
// A Token returned by a Lexer.
type Token struct {
// Type of token. This is the value keyed by symbol as returned by Definition.Symbols().
Type rune
Value string
Pos Position
}
// RuneToken represents a rune as a Token.
func RuneToken(r rune) Token {
return Token{Type: r, Value: string(r)}
}
// EOF returns true if this Token is an EOF token.
func (t Token) EOF() bool {
return t.Type == EOF
}
func (t Token) String() string {
if t.EOF() {
return "<EOF>"
}
return t.Value
}
func (t Token) GoString() string {
return fmt.Sprintf("Token{%d, %q}", t.Type, t.Value)
}
// MakeSymbolTable builds a lookup table for checking token ID existence.
//
// For each symbolic name in "types", the returned map will contain the corresponding token ID as a key.
func MakeSymbolTable(def Definition, types ...string) (map[rune]bool, error) {
symbols := def.Symbols()
table := map[rune]bool{}
for _, symbol := range types {
rn, ok := symbols[symbol]
if !ok {
return nil, fmt.Errorf("lexer does not support symbol %q", symbol)
}
table[rn] = true
}
return table, nil
}

37
vendor/github.com/alecthomas/participle/lexer/peek.go generated vendored Normal file
View File

@ -0,0 +1,37 @@
package lexer
// Upgrade a Lexer to a PeekingLexer with arbitrary lookahead.
func Upgrade(lexer Lexer) PeekingLexer {
if peeking, ok := lexer.(PeekingLexer); ok {
return peeking
}
return &lookaheadLexer{Lexer: lexer}
}
type lookaheadLexer struct {
Lexer
peeked []Token
}
func (l *lookaheadLexer) Peek(n int) (Token, error) {
for len(l.peeked) <= n {
t, err := l.Lexer.Next()
if err != nil {
return Token{}, err
}
if t.EOF() {
return t, nil
}
l.peeked = append(l.peeked, t)
}
return l.peeked[n], nil
}
func (l *lookaheadLexer) Next() (Token, error) {
if len(l.peeked) > 0 {
t := l.peeked[0]
l.peeked = l.peeked[1:]
return t, nil
}
return l.Lexer.Next()
}

112
vendor/github.com/alecthomas/participle/lexer/regexp.go generated vendored Normal file
View File

@ -0,0 +1,112 @@
package lexer
import (
"bytes"
"io"
"io/ioutil"
"regexp"
"unicode/utf8"
)
var eolBytes = []byte("\n")
type regexpDefinition struct {
re *regexp.Regexp
symbols map[string]rune
}
// Regexp creates a lexer definition from a regular expression.
//
// Each named sub-expression in the regular expression matches a token. Anonymous sub-expressions
// will be matched and discarded.
//
// eg.
//
// def, err := Regexp(`(?P<Ident>[a-z]+)|(\s+)|(?P<Number>\d+)`)
func Regexp(pattern string) (Definition, error) {
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
symbols := map[string]rune{
"EOF": EOF,
}
for i, sym := range re.SubexpNames()[1:] {
if sym != "" {
symbols[sym] = EOF - 1 - rune(i)
}
}
return &regexpDefinition{re: re, symbols: symbols}, nil
}
func (d *regexpDefinition) Lex(r io.Reader) (Lexer, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
return &regexpLexer{
pos: Position{
Filename: NameOfReader(r),
Line: 1,
Column: 1,
},
b: b,
re: d.re,
names: d.re.SubexpNames(),
}, nil
}
func (d *regexpDefinition) Symbols() map[string]rune {
return d.symbols
}
type regexpLexer struct {
pos Position
b []byte
re *regexp.Regexp
names []string
}
func (r *regexpLexer) Next() (Token, error) {
nextToken:
for len(r.b) != 0 {
matches := r.re.FindSubmatchIndex(r.b)
if matches == nil || matches[0] != 0 {
rn, _ := utf8.DecodeRune(r.b)
return Token{}, Errorf(r.pos, "invalid token %q", rn)
}
match := r.b[:matches[1]]
token := Token{
Pos: r.pos,
Value: string(match),
}
// Update lexer state.
r.pos.Offset += matches[1]
lines := bytes.Count(match, eolBytes)
r.pos.Line += lines
// Update column.
if lines == 0 {
r.pos.Column += utf8.RuneCount(match)
} else {
r.pos.Column = utf8.RuneCount(match[bytes.LastIndex(match, eolBytes):])
}
// Move slice along.
r.b = r.b[matches[1]:]
// Finally, assign token type. If it is not a named group, we continue to the next token.
for i := 2; i < len(matches); i += 2 {
if matches[i] != -1 {
if r.names[i/2] == "" {
continue nextToken
}
token.Type = EOF - rune(i/2)
break
}
}
return token, nil
}
return EOFToken(r.pos), nil
}

View File

@ -0,0 +1,125 @@
package lexer
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
"text/scanner"
"unicode/utf8"
)
// TextScannerLexer is a lexer that uses the text/scanner module.
var (
TextScannerLexer Definition = &defaultDefinition{}
// DefaultDefinition defines properties for the default lexer.
DefaultDefinition = TextScannerLexer
)
type defaultDefinition struct{}
func (d *defaultDefinition) Lex(r io.Reader) (Lexer, error) {
return Lex(r), nil
}
func (d *defaultDefinition) Symbols() map[string]rune {
return map[string]rune{
"EOF": scanner.EOF,
"Char": scanner.Char,
"Ident": scanner.Ident,
"Int": scanner.Int,
"Float": scanner.Float,
"String": scanner.String,
"RawString": scanner.RawString,
"Comment": scanner.Comment,
}
}
// textScannerLexer is a Lexer based on text/scanner.Scanner
type textScannerLexer struct {
scanner *scanner.Scanner
filename string
err error
}
// Lex an io.Reader with text/scanner.Scanner.
//
// This provides very fast lexing of source code compatible with Go tokens.
//
// Note that this differs from text/scanner.Scanner in that string tokens will be unquoted.
func Lex(r io.Reader) Lexer {
lexer := lexWithScanner(r, &scanner.Scanner{})
lexer.scanner.Error = func(s *scanner.Scanner, msg string) {
// This is to support single quoted strings. Hacky.
if msg != "illegal char literal" {
lexer.err = Errorf(Position(lexer.scanner.Pos()), msg)
}
}
return lexer
}
// LexWithScanner creates a Lexer from a user-provided scanner.Scanner.
//
// Useful if you need to customise the Scanner.
func LexWithScanner(r io.Reader, scan *scanner.Scanner) Lexer {
return lexWithScanner(r, scan)
}
func lexWithScanner(r io.Reader, scan *scanner.Scanner) *textScannerLexer {
lexer := &textScannerLexer{
filename: NameOfReader(r),
scanner: scan,
}
lexer.scanner.Init(r)
return lexer
}
// LexBytes returns a new default lexer over bytes.
func LexBytes(b []byte) Lexer {
return Lex(bytes.NewReader(b))
}
// LexString returns a new default lexer over a string.
func LexString(s string) Lexer {
return Lex(strings.NewReader(s))
}
func (t *textScannerLexer) Next() (Token, error) {
typ := t.scanner.Scan()
text := t.scanner.TokenText()
pos := Position(t.scanner.Position)
pos.Filename = t.filename
if t.err != nil {
return Token{}, t.err
}
return textScannerTransform(Token{
Type: typ,
Value: text,
Pos: pos,
})
}
func textScannerTransform(token Token) (Token, error) {
// Unquote strings.
switch token.Type {
case scanner.Char:
// FIXME(alec): This is pretty hacky...we convert a single quoted char into a double
// quoted string in order to support single quoted strings.
token.Value = fmt.Sprintf("\"%s\"", token.Value[1:len(token.Value)-1])
fallthrough
case scanner.String:
s, err := strconv.Unquote(token.Value)
if err != nil {
return Token{}, Errorf(token.Pos, "%s: %q", err.Error(), token.Value)
}
token.Value = s
if token.Type == scanner.Char && utf8.RuneCountInString(s) > 1 {
token.Type = scanner.String
}
case scanner.RawString:
token.Value = token.Value[1 : len(token.Value)-1]
}
return token, nil
}

118
vendor/github.com/alecthomas/participle/map.go generated vendored Normal file
View File

@ -0,0 +1,118 @@
package participle
import (
"errors"
"io"
"strconv"
"strings"
"github.com/alecthomas/participle/lexer"
)
type mapperByToken struct {
symbols []string
mapper Mapper
}
// DropToken can be returned by a Mapper to remove a token from the stream.
var DropToken = errors.New("drop token") // nolint: golint
// Mapper function for mutating tokens before being applied to the AST.
//
// If the Mapper func returns an error of DropToken, the token will be removed from the stream.
type Mapper func(token lexer.Token) (lexer.Token, error)
// Map is an Option that configures the Parser to apply a mapping function to each Token from the lexer.
//
// This can be useful to eg. upper-case all tokens of a certain type, or dequote strings.
//
// "symbols" specifies the token symbols that the Mapper will be applied to. If empty, all tokens will be mapped.
func Map(mapper Mapper, symbols ...string) Option {
return func(p *Parser) error {
p.mappers = append(p.mappers, mapperByToken{
mapper: mapper,
symbols: symbols,
})
return nil
}
}
// Unquote applies strconv.Unquote() to tokens of the given types.
//
// Tokens of type "String" will be unquoted if no other types are provided.
func Unquote(types ...string) Option {
if len(types) == 0 {
types = []string{"String"}
}
return Map(func(t lexer.Token) (lexer.Token, error) {
value, err := unquote(t.Value)
if err != nil {
return t, lexer.Errorf(t.Pos, "invalid quoted string %q: %s", t.Value, err.Error())
}
t.Value = value
return t, nil
}, types...)
}
func unquote(s string) (string, error) {
quote := s[0]
s = s[1 : len(s)-1]
out := ""
for s != "" {
value, _, tail, err := strconv.UnquoteChar(s, quote)
if err != nil {
return "", err
}
s = tail
out += string(value)
}
return out, nil
}
// Upper is an Option that upper-cases all tokens of the given type. Useful for case normalisation.
func Upper(types ...string) Option {
return Map(func(token lexer.Token) (lexer.Token, error) {
token.Value = strings.ToUpper(token.Value)
return token, nil
}, types...)
}
// Elide drops tokens of the specified types.
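//
// For example (illustrative; Grammar and myLexer are placeholders):
//
//	parser := participle.MustBuild(&Grammar{},
//		participle.Lexer(myLexer),
//		participle.Elide("Whitespace", "Comment"))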
func Elide(types ...string) Option {
return Map(func(token lexer.Token) (lexer.Token, error) {
return lexer.Token{}, DropToken
}, types...)
}
// Apply a Mapping to all tokens coming out of a Lexer.
type mappingLexerDef struct {
lexer.Definition
mapper Mapper
}
func (m *mappingLexerDef) Lex(r io.Reader) (lexer.Lexer, error) {
lexer, err := m.Definition.Lex(r)
if err != nil {
return nil, err
}
return &mappingLexer{lexer, m.mapper}, nil
}
type mappingLexer struct {
lexer.Lexer
mapper Mapper
}
func (m *mappingLexer) Next() (lexer.Token, error) {
for {
t, err := m.Lexer.Next()
if err != nil {
return t, err
}
t, err = m.mapper(t)
if err == DropToken {
continue
}
return t, err
}
}

575
vendor/github.com/alecthomas/participle/nodes.go generated vendored Normal file
View File

@ -0,0 +1,575 @@
package participle
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/alecthomas/participle/lexer"
)
var (
// MaxIterations limits the number of elements capturable by {}.
MaxIterations = 1000000
positionType = reflect.TypeOf(lexer.Position{})
captureType = reflect.TypeOf((*Capture)(nil)).Elem()
parseableType = reflect.TypeOf((*Parseable)(nil)).Elem()
// NextMatch should be returned by Parseable.Parse() method implementations to indicate
// that the node did not match and that other matches should be attempted, if appropriate.
NextMatch = errors.New("no match") // nolint: golint
)
// A node in the grammar.
type node interface {
// Parse from scanner into value.
//
// Returned slice will be nil if the node does not match.
Parse(ctx *parseContext, parent reflect.Value) ([]reflect.Value, error)
// Return a decent string representation of the Node.
String() string
}
func decorate(err *error, name func() string) {
if *err == nil {
return
}
switch realError := (*err).(type) {
case *lexer.Error:
*err = &lexer.Error{Message: name() + ": " + realError.Message, Pos: realError.Pos}
default:
*err = fmt.Errorf("%s: %s", name(), realError)
}
}
// A node that proxies to an implementation that implements the Parseable interface.
type parseable struct {
t reflect.Type
}
func (p *parseable) String() string { return stringer(p) }
func (p *parseable) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
rv := reflect.New(p.t)
v := rv.Interface().(Parseable)
err = v.Parse(ctx)
if err != nil {
if err == NextMatch {
return nil, nil
}
return nil, err
}
return []reflect.Value{rv.Elem()}, nil
}
type strct struct {
typ reflect.Type
expr node
}
func (s *strct) String() string { return stringer(s) }
func (s *strct) maybeInjectPos(pos lexer.Position, v reflect.Value) {
if f := v.FieldByName("Pos"); f.IsValid() && f.Type() == positionType {
f.Set(reflect.ValueOf(pos))
}
}
func (s *strct) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
sv := reflect.New(s.typ).Elem()
t, err := ctx.Peek(0)
if err != nil {
return nil, err
}
s.maybeInjectPos(t.Pos, sv)
if out, err = s.expr.Parse(ctx, sv); err != nil {
_ = ctx.Apply()
return []reflect.Value{sv}, err
} else if out == nil {
return nil, nil
}
return []reflect.Value{sv}, ctx.Apply()
}
type groupMatchMode int
const (
groupMatchOnce groupMatchMode = iota
groupMatchZeroOrOne = iota
groupMatchZeroOrMore = iota
groupMatchOneOrMore = iota
groupMatchNonEmpty = iota
)
// ( <expr> ) - match once
// ( <expr> )* - match zero or more times
// ( <expr> )+ - match one or more times
// ( <expr> )? - match zero or once
// ( <expr> )! - must be a non-empty match
//
// The additional modifier "!" forces the content of the group to be non-empty if it does match.
type group struct {
expr node
mode groupMatchMode
}
func (g *group) String() string { return stringer(g) }
func (g *group) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
// Configure min/max matches.
min := 1
max := 1
switch g.mode {
case groupMatchNonEmpty:
out, err = g.expr.Parse(ctx, parent)
if err != nil {
return out, err
}
if len(out) == 0 {
t, _ := ctx.Peek(0)
return out, lexer.Errorf(t.Pos, "sub-expression %s cannot be empty", g)
}
return out, nil
case groupMatchOnce:
return g.expr.Parse(ctx, parent)
case groupMatchZeroOrOne:
min = 0
case groupMatchZeroOrMore:
min = 0
max = MaxIterations
case groupMatchOneOrMore:
min = 1
max = MaxIterations
}
matches := 0
for ; matches < max; matches++ {
branch := ctx.Branch()
v, err := g.expr.Parse(branch, parent)
out = append(out, v...)
if err != nil {
// Optional part failed to match.
if ctx.Stop(branch) {
return out, err
}
break
} else {
ctx.Accept(branch)
}
if v == nil {
break
}
}
// fmt.Printf("%d < %d < %d: out == nil? %v\n", min, matches, max, out == nil)
t, _ := ctx.Peek(0)
if matches >= MaxIterations {
panic(lexer.Errorf(t.Pos, "too many iterations of %s (> %d)", g, MaxIterations))
}
if matches < min {
return out, lexer.Errorf(t.Pos, "sub-expression %s must match at least once", g)
}
// The idea here is that something like "a"? is a successful match and that parsing should proceed.
if min == 0 && out == nil {
out = []reflect.Value{}
}
return out, nil
}
// <expr> {"|" <expr>}
type disjunction struct {
nodes []node
}
func (d *disjunction) String() string { return stringer(d) }
func (d *disjunction) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
var (
deepestError = 0
firstError error
firstValues []reflect.Value
)
for _, a := range d.nodes {
branch := ctx.Branch()
if value, err := a.Parse(branch, parent); err != nil {
// If this branch progressed too far and still didn't match, error out.
if ctx.Stop(branch) {
return value, err
}
// Show the closest error returned. The idea here is that the further the parser progresses
// without error, the more difficult it is to trace the error back to its root.
if err != nil && branch.cursor >= deepestError {
firstError = err
firstValues = value
deepestError = branch.cursor
}
} else if value != nil {
ctx.Accept(branch)
return value, nil
}
}
if firstError != nil {
return firstValues, firstError
}
return nil, nil
}
// <node> ...
type sequence struct {
head bool
node node
next *sequence
}
func (s *sequence) String() string { return stringer(s) }
func (s *sequence) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
for n := s; n != nil; n = n.next {
child, err := n.node.Parse(ctx, parent)
out = append(out, child...)
if err != nil {
return out, err
}
if child == nil {
// Early exit if first value doesn't match, otherwise all values must match.
if n == s {
return nil, nil
}
token, err := ctx.Peek(0)
if err != nil {
return nil, err
}
return out, lexer.Errorf(token.Pos, "unexpected %q (expected %s)", token, n)
}
}
return out, nil
}
// @<expr>
type capture struct {
field structLexerField
node node
}
func (c *capture) String() string { return stringer(c) }
func (c *capture) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
token, err := ctx.Peek(0)
if err != nil {
return nil, err
}
pos := token.Pos
v, err := c.node.Parse(ctx, parent)
if err != nil {
if v != nil {
ctx.Defer(pos, parent, c.field, v)
}
return []reflect.Value{parent}, err
}
if v == nil {
return nil, nil
}
ctx.Defer(pos, parent, c.field, v)
return []reflect.Value{parent}, nil
}
// <identifier> - named lexer token reference
type reference struct {
typ rune
identifier string // Used for informational purposes.
}
func (r *reference) String() string { return stringer(r) }
func (r *reference) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
token, err := ctx.Peek(0)
if err != nil {
return nil, err
}
if token.Type != r.typ {
return nil, nil
}
_, _ = ctx.Next()
return []reflect.Value{reflect.ValueOf(token.Value)}, nil
}
// [ <expr> ] <sequence>
type optional struct {
node node
}
func (o *optional) String() string { return stringer(o) }
func (o *optional) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
branch := ctx.Branch()
out, err = o.node.Parse(branch, parent)
if err != nil {
// Optional part failed to match.
if ctx.Stop(branch) {
return out, err
}
} else {
ctx.Accept(branch)
}
if out == nil {
out = []reflect.Value{}
}
return out, nil
}
// { <expr> } <sequence>
type repetition struct {
node node
}
func (r *repetition) String() string { return stringer(r) }
// Parse a repetition. Once a repetition is encountered it will always match, so grammars
// should ensure that branches are differentiated prior to the repetition.
func (r *repetition) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
i := 0
for ; i < MaxIterations; i++ {
branch := ctx.Branch()
v, err := r.node.Parse(branch, parent)
out = append(out, v...)
if err != nil {
// Optional part failed to match.
if ctx.Stop(branch) {
return out, err
}
break
} else {
ctx.Accept(branch)
}
if v == nil {
break
}
}
if i >= MaxIterations {
t, _ := ctx.Peek(0)
panic(lexer.Errorf(t.Pos, "too many iterations of %s (> %d)", r, MaxIterations))
}
if out == nil {
out = []reflect.Value{}
}
return out, nil
}
// Match a token literal exactly "..."[:<type>].
type literal struct {
s string
t rune
tt string // Used for display purposes - symbolic name of t.
}
func (l *literal) String() string { return stringer(l) }
func (l *literal) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
token, err := ctx.Peek(0)
if err != nil {
return nil, err
}
equal := false // nolint: ineffassign
if ctx.caseInsensitive[token.Type] {
equal = strings.EqualFold(token.Value, l.s)
} else {
equal = token.Value == l.s
}
if equal && (l.t == -1 || l.t == token.Type) {
next, err := ctx.Next()
if err != nil {
return nil, err
}
return []reflect.Value{reflect.ValueOf(next.Value)}, nil
}
return nil, nil
}
// Attempt to transform values to given type.
//
// This will dereference pointers, and attempt to parse strings into integer values, floats, etc.
func conform(t reflect.Type, values []reflect.Value) (out []reflect.Value, err error) {
for _, v := range values {
for t != v.Type() && t.Kind() == reflect.Ptr && v.Kind() != reflect.Ptr {
// This can occur during partial failure.
if !v.CanAddr() {
return
}
v = v.Addr()
}
// Already of the right kind, don't bother converting.
if v.Kind() == t.Kind() {
out = append(out, v)
continue
}
kind := t.Kind()
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(v.String(), 0, sizeOfKind(kind))
if err != nil {
return nil, fmt.Errorf("invalid integer %q: %s", v.String(), err)
}
v = reflect.New(t).Elem()
v.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n, err := strconv.ParseUint(v.String(), 0, sizeOfKind(kind))
if err != nil {
return nil, fmt.Errorf("invalid integer %q: %s", v.String(), err)
}
v = reflect.New(t).Elem()
v.SetUint(n)
case reflect.Bool:
v = reflect.ValueOf(true)
case reflect.Float32, reflect.Float64:
n, err := strconv.ParseFloat(v.String(), sizeOfKind(kind))
if err != nil {
return nil, fmt.Errorf("invalid integer %q: %s", v.String(), err)
}
v = reflect.New(t).Elem()
v.SetFloat(n)
}
out = append(out, v)
}
return out, nil
}
func sizeOfKind(kind reflect.Kind) int {
switch kind {
case reflect.Int8, reflect.Uint8:
return 8
case reflect.Int16, reflect.Uint16:
return 16
case reflect.Int32, reflect.Uint32, reflect.Float32:
return 32
case reflect.Int64, reflect.Uint64, reflect.Float64:
return 64
case reflect.Int, reflect.Uint:
return strconv.IntSize
}
panic("unsupported kind " + kind.String())
}
// Set field.
//
// If field is a pointer the pointer will be set to the value. If field is a string, value will be
// appended. If field is a slice, value will be appended to slice.
//
// For all other types, an attempt will be made to convert the string to the corresponding
// type (int, float32, etc.).
func setField(pos lexer.Position, strct reflect.Value, field structLexerField, fieldValue []reflect.Value) (err error) { // nolint: gocyclo
defer decorate(&err, func() string { return pos.String() + ": " + strct.Type().String() + "." + field.Name })
f := strct.FieldByIndex(field.Index)
switch f.Kind() {
case reflect.Slice:
fieldValue, err = conform(f.Type().Elem(), fieldValue)
if err != nil {
return err
}
f.Set(reflect.Append(f, fieldValue...))
return nil
case reflect.Ptr:
if f.IsNil() {
fv := reflect.New(f.Type().Elem()).Elem()
f.Set(fv.Addr())
f = fv
} else {
f = f.Elem()
}
}
if f.Kind() == reflect.Struct {
if pf := f.FieldByName("Pos"); pf.IsValid() && pf.Type() == positionType {
pf.Set(reflect.ValueOf(pos))
}
}
if f.CanAddr() {
if d, ok := f.Addr().Interface().(Capture); ok {
ifv := []string{}
for _, v := range fieldValue {
ifv = append(ifv, v.Interface().(string))
}
err := d.Capture(ifv)
if err != nil {
return err
}
return nil
}
}
// Strings concatenate all captured tokens.
if f.Kind() == reflect.String {
fieldValue, err = conform(f.Type(), fieldValue)
if err != nil {
return err
}
for _, v := range fieldValue {
f.Set(reflect.ValueOf(f.String() + v.String()).Convert(f.Type()))
}
return nil
}
// Coalesce multiple tokens into one. This allows eg. ["-", "10"] to be captured as separate tokens but
// parsed as a single string "-10".
if len(fieldValue) > 1 {
out := []string{}
for _, v := range fieldValue {
out = append(out, v.String())
}
fieldValue = []reflect.Value{reflect.ValueOf(strings.Join(out, ""))}
}
fieldValue, err = conform(f.Type(), fieldValue)
if err != nil {
return err
}
fv := fieldValue[0]
switch f.Kind() {
// Numeric types will increment if the token can not be coerced.
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if fv.Type() != f.Type() {
f.SetInt(f.Int() + 1)
} else {
f.Set(fv)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if fv.Type() != f.Type() {
f.SetUint(f.Uint() + 1)
} else {
f.Set(fv)
}
case reflect.Float32, reflect.Float64:
if fv.Type() != f.Type() {
f.SetFloat(f.Float() + 1)
} else {
f.Set(fv)
}
case reflect.Bool, reflect.Struct:
if fv.Type() != f.Type() {
return fmt.Errorf("value %q is not correct type %s", fv, f.Type())
}
f.Set(fv)
default:
return fmt.Errorf("unsupported field type %s for field %s", f.Type(), field.Name)
}
return nil
}
// Error is an error returned by the parser internally to differentiate from non-Participle errors.
type Error string
func (e Error) Error() string { return string(e) }

39
vendor/github.com/alecthomas/participle/options.go generated vendored Normal file
View File

@ -0,0 +1,39 @@
package participle
import (
"github.com/alecthomas/participle/lexer"
)
// An Option to modify the behaviour of the Parser.
type Option func(p *Parser) error
// Lexer is an Option that sets the lexer to use with the given grammar.
func Lexer(def lexer.Definition) Option {
return func(p *Parser) error {
p.lex = def
return nil
}
}
// UseLookahead allows branch lookahead up to "n" tokens.
//
// If parsing cannot be disambiguated before "n" tokens of lookahead, parsing will fail.
//
// Note that increasing lookahead has a minor performance impact, but also
// reduces the accuracy of error reporting.
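//
// For example (illustrative; Grammar is a placeholder for your root grammar struct):
//
//	parser := participle.MustBuild(&Grammar{}, participle.UseLookahead(2))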
func UseLookahead(n int) Option {
return func(p *Parser) error {
p.useLookahead = n
return nil
}
}
// CaseInsensitive allows the specified token types to be matched case-insensitively.
func CaseInsensitive(tokens ...string) Option {
return func(p *Parser) error {
for _, token := range tokens {
p.caseInsensitive[token] = true
}
return nil
}
}

229
vendor/github.com/alecthomas/participle/parser.go generated vendored Normal file
View File

@ -0,0 +1,229 @@
package participle
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
"github.com/alecthomas/participle/lexer"
)
// A Parser for a particular grammar and lexer.
type Parser struct {
root node
lex lexer.Definition
typ reflect.Type
useLookahead int
caseInsensitive map[string]bool
mappers []mapperByToken
}
// MustBuild calls Build(grammar, options...) and panics if an error occurs.
func MustBuild(grammar interface{}, options ...Option) *Parser {
parser, err := Build(grammar, options...)
if err != nil {
panic(err)
}
return parser
}
// Build constructs a parser for the given grammar.
//
// If "Lexer()" is not provided as an option, a default lexer based on text/scanner will be used. This scans typical Go-
// like tokens.
//
// See documentation for details
func Build(grammar interface{}, options ...Option) (parser *Parser, err error) {
// Configure Parser struct with defaults + options.
p := &Parser{
lex: lexer.TextScannerLexer,
caseInsensitive: map[string]bool{},
useLookahead: 1,
}
for _, option := range options {
if option == nil {
return nil, fmt.Errorf("nil Option passed, signature has changed; " +
"if you intended to provide a custom Lexer, try participle.Build(grammar, participle.Lexer(lexer))")
}
if err = option(p); err != nil {
return nil, err
}
}
if len(p.mappers) > 0 {
mappers := map[rune][]Mapper{}
symbols := p.lex.Symbols()
for _, mapper := range p.mappers {
if len(mapper.symbols) == 0 {
mappers[lexer.EOF] = append(mappers[lexer.EOF], mapper.mapper)
} else {
for _, symbol := range mapper.symbols {
if rn, ok := symbols[symbol]; !ok {
return nil, fmt.Errorf("mapper %#v uses unknown token %q", mapper, symbol)
} else { // nolint: golint
mappers[rn] = append(mappers[rn], mapper.mapper)
}
}
}
}
p.lex = &mappingLexerDef{p.lex, func(t lexer.Token) (lexer.Token, error) {
combined := make([]Mapper, 0, len(mappers[t.Type])+len(mappers[lexer.EOF]))
combined = append(combined, mappers[lexer.EOF]...)
combined = append(combined, mappers[t.Type]...)
var err error
for _, m := range combined {
t, err = m(t)
if err != nil {
return t, err
}
}
return t, nil
}}
}
context := newGeneratorContext(p.lex)
v := reflect.ValueOf(grammar)
if v.Kind() == reflect.Interface {
v = v.Elem()
}
p.typ = v.Type()
p.root, err = context.parseType(p.typ)
if err != nil {
return nil, err
}
return p, nil
}
// Lex uses the parser's lexer to tokenise input.
func (p *Parser) Lex(r io.Reader) ([]lexer.Token, error) {
lex, err := p.lex.Lex(r)
if err != nil {
return nil, err
}
return lexer.ConsumeAll(lex)
}
// Parse from r into grammar v which must be of the same type as the grammar passed to
// participle.Build().
func (p *Parser) Parse(r io.Reader, v interface{}) (err error) {
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Interface {
rv = rv.Elem()
}
var stream reflect.Value
if rv.Kind() == reflect.Chan {
stream = rv
rt := rv.Type().Elem()
rv = reflect.New(rt).Elem()
}
rt := rv.Type()
if rt != p.typ {
return fmt.Errorf("must parse into value of type %s not %T", p.typ, v)
}
baseLexer, err := p.lex.Lex(r)
if err != nil {
return err
}
lex := lexer.Upgrade(baseLexer)
caseInsensitive := map[rune]bool{}
for sym, rn := range p.lex.Symbols() {
if p.caseInsensitive[sym] {
caseInsensitive[rn] = true
}
}
ctx, err := newParseContext(lex, p.useLookahead, caseInsensitive)
if err != nil {
return err
}
// If the grammar implements Parseable, use it.
if parseable, ok := v.(Parseable); ok {
return p.rootParseable(ctx, parseable)
}
if rt.Kind() != reflect.Ptr || rt.Elem().Kind() != reflect.Struct {
return fmt.Errorf("target must be a pointer to a struct, not %s", rt)
}
if stream.IsValid() {
return p.parseStreaming(ctx, stream)
}
return p.parseOne(ctx, rv)
}
func (p *Parser) parseStreaming(ctx *parseContext, rv reflect.Value) error {
t := rv.Type().Elem().Elem()
for {
if token, _ := ctx.Peek(0); token.EOF() {
rv.Close()
return nil
}
v := reflect.New(t)
if err := p.parseInto(ctx, v); err != nil {
return err
}
rv.Send(v)
}
}
func (p *Parser) parseOne(ctx *parseContext, rv reflect.Value) error {
err := p.parseInto(ctx, rv)
if err != nil {
return err
}
token, err := ctx.Peek(0)
if err != nil {
return err
} else if !token.EOF() {
return lexer.Errorf(token.Pos, "unexpected trailing token %q", token)
}
return nil
}
func (p *Parser) parseInto(ctx *parseContext, rv reflect.Value) error {
if rv.IsNil() {
return fmt.Errorf("target must be a non-nil pointer to a struct, but is a nil %s", rv.Type())
}
pv, err := p.root.Parse(ctx, rv.Elem())
if len(pv) > 0 && pv[0].Type() == rv.Elem().Type() {
rv.Elem().Set(reflect.Indirect(pv[0]))
}
if err != nil {
return err
}
if pv == nil {
token, _ := ctx.Peek(0)
return lexer.Errorf(token.Pos, "invalid syntax")
}
return nil
}
func (p *Parser) rootParseable(lex lexer.PeekingLexer, parseable Parseable) error {
peek, err := lex.Peek(0)
if err != nil {
return err
}
err = parseable.Parse(lex)
if err == NextMatch {
return lexer.Errorf(peek.Pos, "invalid syntax")
}
if err == nil && !peek.EOF() {
return lexer.Errorf(peek.Pos, "unexpected token %q", peek)
}
return err
}
// ParseString is a convenience around Parse().
func (p *Parser) ParseString(s string, v interface{}) error {
return p.Parse(strings.NewReader(s), v)
}
// ParseBytes is a convenience around Parse().
func (p *Parser) ParseBytes(b []byte, v interface{}) error {
return p.Parse(bytes.NewReader(b), v)
}
// String representation of the grammar.
func (p *Parser) String() string {
return stringern(p.root, 128)
}

118
vendor/github.com/alecthomas/participle/stringer.go generated vendored Normal file
View File

@ -0,0 +1,118 @@
package participle
import (
"bytes"
"fmt"
"strings"
"github.com/alecthomas/participle/lexer"
)
type stringerVisitor struct {
bytes.Buffer
seen map[node]bool
}
func stringern(n node, depth int) string {
v := &stringerVisitor{seen: map[node]bool{}}
v.visit(n, depth, false)
return v.String()
}
func stringer(n node) string {
return stringern(n, 1)
}
func (s *stringerVisitor) visit(n node, depth int, disjunctions bool) {
if s.seen[n] || depth <= 0 {
fmt.Fprintf(s, "...")
return
}
s.seen[n] = true
switch n := n.(type) {
case *disjunction:
for i, c := range n.nodes {
if i > 0 {
fmt.Fprint(s, " | ")
}
s.visit(c, depth, disjunctions || len(n.nodes) > 1)
}
case *strct:
s.visit(n.expr, depth, disjunctions)
case *sequence:
c := n
for i := 0; c != nil && depth-i > 0; c, i = c.next, i+1 {
if c != n {
fmt.Fprint(s, " ")
}
s.visit(c.node, depth-i, disjunctions)
}
if c != nil {
fmt.Fprint(s, " ...")
}
case *parseable:
fmt.Fprintf(s, "<%s>", strings.ToLower(n.t.Name()))
case *capture:
if _, ok := n.node.(*parseable); ok {
fmt.Fprintf(s, "<%s>", strings.ToLower(n.field.Name))
} else {
if n.node == nil {
fmt.Fprintf(s, "<%s>", strings.ToLower(n.field.Name))
} else {
s.visit(n.node, depth, disjunctions)
}
}
case *reference:
fmt.Fprintf(s, "<%s>", strings.ToLower(n.identifier))
case *optional:
fmt.Fprint(s, "[ ")
s.visit(n.node, depth, disjunctions)
fmt.Fprint(s, " ]")
case *repetition:
fmt.Fprint(s, "{ ")
s.visit(n.node, depth, disjunctions)
fmt.Fprint(s, " }")
case *literal:
fmt.Fprintf(s, "%q", n.s)
if n.t != lexer.EOF && n.s == "" {
fmt.Fprintf(s, ":%s", n.tt)
}
case *group:
fmt.Fprint(s, "(")
if child, ok := n.expr.(*group); ok && child.mode == groupMatchOnce {
s.visit(child.expr, depth, disjunctions)
} else if child, ok := n.expr.(*capture); ok {
if grandchild, ok := child.node.(*group); ok && grandchild.mode == groupMatchOnce {
s.visit(grandchild.expr, depth, disjunctions)
} else {
s.visit(n.expr, depth, disjunctions)
}
} else {
s.visit(n.expr, depth, disjunctions)
}
fmt.Fprint(s, ")")
switch n.mode {
case groupMatchNonEmpty:
fmt.Fprintf(s, "!")
case groupMatchZeroOrOne:
fmt.Fprintf(s, "?")
case groupMatchZeroOrMore:
fmt.Fprintf(s, "*")
case groupMatchOneOrMore:
fmt.Fprintf(s, "+")
}
default:
panic("unsupported")
}
}

126
vendor/github.com/alecthomas/participle/struct.go generated vendored Normal file
View File

@ -0,0 +1,126 @@
package participle
import (
"fmt"
"reflect"
"github.com/alecthomas/participle/lexer"
)
// A structLexer lexes over the tags of struct fields while tracking the current field.
type structLexer struct {
s reflect.Type
field int
indexes [][]int
lexer lexer.PeekingLexer
}
func lexStruct(s reflect.Type) (*structLexer, error) {
indexes, err := collectFieldIndexes(s)
if err != nil {
return nil, err
}
slex := &structLexer{
s: s,
indexes: indexes,
}
if len(slex.indexes) > 0 {
tag := fieldLexerTag(slex.Field().StructField)
slex.lexer = lexer.Upgrade(lexer.LexString(tag))
}
return slex, nil
}
// NumField returns the number of fields in the struct associated with this structLexer.
func (s *structLexer) NumField() int {
return len(s.indexes)
}
type structLexerField struct {
reflect.StructField
Index []int
}
// Field returns the field associated with the current token.
func (s *structLexer) Field() structLexerField {
return s.GetField(s.field)
}
func (s *structLexer) GetField(field int) structLexerField {
if field >= len(s.indexes) {
field = len(s.indexes) - 1
}
return structLexerField{
StructField: s.s.FieldByIndex(s.indexes[field]),
Index: s.indexes[field],
}
}
func (s *structLexer) Peek() (lexer.Token, error) {
field := s.field
lex := s.lexer
for {
token, err := lex.Peek(0)
if err != nil {
return token, err
}
if !token.EOF() {
token.Pos.Line = field + 1
return token, nil
}
field++
if field >= s.NumField() {
return lexer.EOFToken(token.Pos), nil
}
tag := fieldLexerTag(s.GetField(field).StructField)
lex = lexer.Upgrade(lexer.LexString(tag))
}
}
func (s *structLexer) Next() (lexer.Token, error) {
token, err := s.lexer.Next()
if err != nil {
return token, err
}
if !token.EOF() {
token.Pos.Line = s.field + 1
return token, nil
}
if s.field+1 >= s.NumField() {
return lexer.EOFToken(token.Pos), nil
}
s.field++
tag := fieldLexerTag(s.Field().StructField)
s.lexer = lexer.Upgrade(lexer.LexString(tag))
return s.Next()
}
func fieldLexerTag(field reflect.StructField) string {
if tag, ok := field.Tag.Lookup("parser"); ok {
return tag
}
return string(field.Tag)
}
// Recursively collect flattened indices for top-level fields and embedded fields.
func collectFieldIndexes(s reflect.Type) (out [][]int, err error) {
if s.Kind() != reflect.Struct {
return nil, fmt.Errorf("expected a struct but got %q", s)
}
defer decorate(&err, s.String)
for i := 0; i < s.NumField(); i++ {
f := s.Field(i)
if f.Anonymous {
children, err := collectFieldIndexes(f.Type)
if err != nil {
return nil, err
}
for _, idx := range children {
out = append(out, append(f.Index, idx...))
}
} else if fieldLexerTag(f) != "" {
out = append(out, f.Index)
}
}
return
}
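
As the fieldLexerTag fallback above implies, a grammar fragment can be supplied either under a conventional `parser:"..."` key or as the raw struct tag; a small illustrative sketch:

```go
package example

// Both declarations produce the same grammar: fieldLexerTag first looks for
// a conventional `parser:"..."` key and otherwise falls back to treating the
// entire raw struct tag as the grammar expression.
type WithParserKey struct {
	Name string `parser:"@Ident"`
}

type WithRawTag struct {
	Name string `@Ident`
}
```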

View File

@ -88,6 +88,7 @@ type File struct {
rowGroups []*parquet.RowGroup
rowGroupIndex int
nameList []string
columnNames set.StringSet
columns map[string]*column
rowIndex int64
@ -100,16 +101,23 @@ func Open(getReaderFunc GetReaderFunc, columnNames set.StringSet) (*File, error)
return nil, err
}
nameList := []string{}
schemaElements := fileMeta.GetSchema()
for _, element := range schemaElements {
nameList = append(nameList, element.Name)
}
return &File{
getReaderFunc: getReaderFunc,
rowGroups: fileMeta.GetRowGroups(),
schemaElements: fileMeta.GetSchema(),
schemaElements: schemaElements,
nameList: nameList,
columnNames: columnNames,
}, nil
}
// Read - reads single record.
func (file *File) Read() (record map[string]Value, err error) {
func (file *File) Read() (record *Record, err error) {
if file.rowGroupIndex >= len(file.rowGroups) {
return nil, io.EOF
}
@ -134,10 +142,10 @@ func (file *File) Read() (record map[string]Value, err error) {
return file.Read()
}
record = make(map[string]Value)
record = newRecord(file.nameList)
for name := range file.columns {
value, valueType := file.columns[name].read()
record[name] = Value{value, valueType}
record.set(name, Value{value, valueType})
}
file.rowIndex++

70
vendor/github.com/minio/parquet-go/record.go generated vendored Normal file
View File

@ -0,0 +1,70 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"fmt"
"strings"
)
// Record - ordered parquet record.
type Record struct {
nameList []string
nameValueMap map[string]Value
}
// String - returns string representation of this record.
func (r *Record) String() string {
values := []string{}
r.Range(func(name string, value Value) bool {
values = append(values, fmt.Sprintf("%v:%v", name, value))
return true
})
return "map[" + strings.Join(values, " ") + "]"
}
func (r *Record) set(name string, value Value) {
r.nameValueMap[name] = value
}
// Get - returns Value of name.
func (r *Record) Get(name string) (Value, bool) {
value, ok := r.nameValueMap[name]
return value, ok
}
// Range - calls f sequentially for each name and value present in the record. If f returns false, range stops the iteration.
func (r *Record) Range(f func(name string, value Value) bool) {
for _, name := range r.nameList {
value, ok := r.nameValueMap[name]
if !ok {
continue
}
if !f(name, value) {
break
}
}
}
func newRecord(nameList []string) *Record {
return &Record{
nameList: nameList,
nameValueMap: make(map[string]Value),
}
}
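
For context, a hypothetical consumer of the new ordered Record might iterate it with Range, which walks columns in the schema order captured by nameList; the printRecord helper below is illustrative only:

```go
package example

import (
	"fmt"

	parquet "github.com/minio/parquet-go"
)

// printRecord walks a record in schema order. Range stops as soon as the
// callback returns false; returning true keeps the iteration going.
func printRecord(rec *parquet.Record) {
	rec.Range(func(name string, value parquet.Value) bool {
		fmt.Printf("%s = %v\n", name, value)
		return true
	})
}
```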

View File

@ -1,9 +0,0 @@
This project is originally a fork of [https://github.com/youtube/vitess](https://github.com/youtube/vitess)
Copyright Google Inc
# Contributors
Wenbin Xiao 2015
Started this project and maintained it.
Andrew Brampton 2017
Merged in multiple upstream fixes/changes.

View File

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,22 +0,0 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
MAKEFLAGS = -s
sql.go: sql.y
goyacc -o sql.go sql.y
gofmt -w sql.go
clean:
rm -f y.output sql.go

View File

@ -1,150 +0,0 @@
# sqlparser [![Build Status](https://img.shields.io/travis/xwb1989/sqlparser.svg)](https://travis-ci.org/xwb1989/sqlparser) [![Coverage](https://img.shields.io/coveralls/xwb1989/sqlparser.svg)](https://coveralls.io/github/xwb1989/sqlparser) [![Report card](https://goreportcard.com/badge/github.com/xwb1989/sqlparser)](https://goreportcard.com/report/github.com/xwb1989/sqlparser) [![GoDoc](https://godoc.org/github.com/xwb1989/sqlparser?status.svg)](https://godoc.org/github.com/xwb1989/sqlparser)
Go package for parsing MySQL SQL queries.
## Notice
The backbone of this repo is extracted from [vitessio/vitess](https://github.com/vitessio/vitess).
Inside vitessio/vitess there is a very nicely written sql parser. However, as it's not a self-contained application, I created this one.
It applies the same LICENSE as vitessio/vitess.
## Usage
```go
import (
"github.com/xwb1989/sqlparser"
)
```
Then use:
```go
sql := "SELECT * FROM table WHERE a = 'abc'"
stmt, err := sqlparser.Parse(sql)
if err != nil {
// Do something with the err
}
// Otherwise do something with stmt
switch stmt := stmt.(type) {
case *sqlparser.Select:
_ = stmt
case *sqlparser.Insert:
}
```
Alternatively, to read many queries from an io.Reader:
```go
r := strings.NewReader("INSERT INTO table1 VALUES (1, 'a'); INSERT INTO table2 VALUES (3, 4);")
tokens := sqlparser.NewTokenizer(r)
for {
stmt, err := sqlparser.ParseNext(tokens)
if err == io.EOF {
break
}
// Do something with stmt or err.
}
```
See [parse_test.go](https://github.com/xwb1989/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/xwb1989/sqlparser).
## Porting Instructions
You only need the below if you plan to try and keep this library up to date with [vitessio/vitess](https://github.com/vitessio/vitess).
### Keeping up to date
```bash
shopt -s nullglob
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
# Create patches for everything that changed
LASTIMPORT=1b7879cb91f1dfe1a2dfa06fea96e951e3a7aec5
for path in ${VITESS?}/{vt/sqlparser,sqltypes,bytes2,hack}; do
cd ${path}
git format-patch ${LASTIMPORT?} .
done;
# Apply patches to the dependencies
cd ${XWB1989?}
git am --directory dependency -p2 ${VITESS?}/{sqltypes,bytes2,hack}/*.patch
# Apply the main patches to the repo
cd ${XWB1989?}
git am -p4 ${VITESS?}/vt/sqlparser/*.patch
# If you encounter diff failures, manually fix them with
patch -p4 < .git/rebase-apply/patch
...
git add name_of_files
git am --continue
# Cleanup
rm ${VITESS?}/{sqltypes,bytes2,hack}/*.patch ${VITESS?}/*.patch
# and Finally update the LASTIMPORT in this README.
```
### Fresh install
TODO: Change these instructions to use git to copy the files, that'll make later patching easier.
```bash
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
cd ${XWB1989?}
# Copy all the code
cp -pr ${VITESS?}/vt/sqlparser/ .
cp -pr ${VITESS?}/sqltypes dependency
cp -pr ${VITESS?}/bytes2 dependency
cp -pr ${VITESS?}/hack dependency
# Delete some code we haven't ported
rm dependency/sqltypes/arithmetic.go dependency/sqltypes/arithmetic_test.go dependency/sqltypes/event_token.go dependency/sqltypes/event_token_test.go dependency/sqltypes/proto3.go dependency/sqltypes/proto3_test.go dependency/sqltypes/query_response.go dependency/sqltypes/result.go dependency/sqltypes/result_test.go
# Some automated fixes
# Fix imports
sed -i '.bak' 's_vitess.io/vitess/go/vt/proto/query_github.com/xwb1989/sqlparser/dependency/querypb_g' *.go dependency/sqltypes/*.go
sed -i '.bak' 's_vitess.io/vitess/go/_github.com/xwb1989/sqlparser/dependency/_g' *.go dependency/sqltypes/*.go
# Copy the proto, but basically drop everything we don't want
cp -pr ${VITESS?}/vt/proto/query dependency/querypb
sed -i '.bak' 's_.*Descriptor.*__g' dependency/querypb/*.go
sed -i '.bak' 's_.*ProtoMessage.*__g' dependency/querypb/*.go
sed -i '.bak' 's/proto.CompactTextString(m)/"TODO"/g' dependency/querypb/*.go
sed -i '.bak' 's/proto.EnumName/EnumName/g' dependency/querypb/*.go
sed -i '.bak' 's/proto.Equal/reflect.DeepEqual/g' dependency/sqltypes/*.go
# Remove the error library
sed -i '.bak' 's/vterrors.Errorf([^,]*, /fmt.Errorf(/g' *.go dependency/sqltypes/*.go
sed -i '.bak' 's/vterrors.New([^,]*, /errors.New(/g' *.go dependency/sqltypes/*.go
```
### Testing
```bash
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
cd ${XWB1989?}
# Test, fix and repeat
go test ./...
# Finally make some diffs (for later reference)
diff -u ${VITESS?}/sqltypes/ ${XWB1989?}/dependency/sqltypes/ > ${XWB1989?}/patches/sqltypes.patch
diff -u ${VITESS?}/bytes2/ ${XWB1989?}/dependency/bytes2/ > ${XWB1989?}/patches/bytes2.patch
diff -u ${VITESS?}/vt/proto/query/ ${XWB1989?}/dependency/querypb/ > ${XWB1989?}/patches/querypb.patch
diff -u ${VITESS?}/vt/sqlparser/ ${XWB1989?}/ > ${XWB1989?}/patches/sqlparser.patch
```

View File

@ -1,343 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
// analyzer.go contains utility analysis functions.
import (
"errors"
"fmt"
"strconv"
"strings"
"unicode"
"github.com/xwb1989/sqlparser/dependency/sqltypes"
)
// These constants are used to identify the SQL statement type.
const (
StmtSelect = iota
StmtStream
StmtInsert
StmtReplace
StmtUpdate
StmtDelete
StmtDDL
StmtBegin
StmtCommit
StmtRollback
StmtSet
StmtShow
StmtUse
StmtOther
StmtUnknown
StmtComment
)
// Preview analyzes the beginning of the query using a simpler and faster
// textual comparison to identify the statement type.
func Preview(sql string) int {
trimmed := StripLeadingComments(sql)
firstWord := trimmed
if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 {
firstWord = trimmed[:end]
}
firstWord = strings.TrimLeftFunc(firstWord, func(r rune) bool { return !unicode.IsLetter(r) })
// Comparison is done in order of priority.
loweredFirstWord := strings.ToLower(firstWord)
switch loweredFirstWord {
case "select":
return StmtSelect
case "stream":
return StmtStream
case "insert":
return StmtInsert
case "replace":
return StmtReplace
case "update":
return StmtUpdate
case "delete":
return StmtDelete
}
// For the following statements it is not sufficient to rely
// on loweredFirstWord. This is because they are not statements
// in the grammar and we are relying on Preview to parse them.
// For instance, we don't want: "BEGIN JUNK" to be parsed
// as StmtBegin.
trimmedNoComments, _ := SplitMarginComments(trimmed)
switch strings.ToLower(trimmedNoComments) {
case "begin", "start transaction":
return StmtBegin
case "commit":
return StmtCommit
case "rollback":
return StmtRollback
}
switch loweredFirstWord {
case "create", "alter", "rename", "drop", "truncate":
return StmtDDL
case "set":
return StmtSet
case "show":
return StmtShow
case "use":
return StmtUse
case "analyze", "describe", "desc", "explain", "repair", "optimize":
return StmtOther
}
if strings.Index(trimmed, "/*!") == 0 {
return StmtComment
}
return StmtUnknown
}
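
Illustration only (this package is removed by the commit): Preview classifies a statement from its first keyword without a full parse, and StmtType maps the constant back to a readable name.

```go
package main

import (
	"fmt"

	"github.com/xwb1989/sqlparser"
)

func main() {
	// Classify the statement without building an AST.
	stmtType := sqlparser.Preview("select * from t where a = 1")
	fmt.Println(sqlparser.StmtType(stmtType)) // SELECT
}
```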
// StmtType returns the statement type as a string
func StmtType(stmtType int) string {
switch stmtType {
case StmtSelect:
return "SELECT"
case StmtStream:
return "STREAM"
case StmtInsert:
return "INSERT"
case StmtReplace:
return "REPLACE"
case StmtUpdate:
return "UPDATE"
case StmtDelete:
return "DELETE"
case StmtDDL:
return "DDL"
case StmtBegin:
return "BEGIN"
case StmtCommit:
return "COMMIT"
case StmtRollback:
return "ROLLBACK"
case StmtSet:
return "SET"
case StmtShow:
return "SHOW"
case StmtUse:
return "USE"
case StmtOther:
return "OTHER"
default:
return "UNKNOWN"
}
}
// IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
func IsDML(sql string) bool {
switch Preview(sql) {
case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
return true
}
return false
}
// GetTableName returns the table name from the SimpleTableExpr
// only if it's a simple expression. Otherwise, it returns "".
func GetTableName(node SimpleTableExpr) TableIdent {
if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
return n.Name
}
// sub-select or '.' expression
return NewTableIdent("")
}
// IsColName returns true if the Expr is a *ColName.
func IsColName(node Expr) bool {
_, ok := node.(*ColName)
return ok
}
// IsValue returns true if the Expr is a string, integral or value arg.
// NULL is not considered to be a value.
func IsValue(node Expr) bool {
switch v := node.(type) {
case *SQLVal:
switch v.Type {
case StrVal, HexVal, IntVal, ValArg:
return true
}
}
return false
}
// IsNull returns true if the Expr is SQL NULL
func IsNull(node Expr) bool {
switch node.(type) {
case *NullVal:
return true
}
return false
}
// IsSimpleTuple returns true if the Expr is a ValTuple that
// contains simple values or if it's a list arg.
func IsSimpleTuple(node Expr) bool {
switch vals := node.(type) {
case ValTuple:
for _, n := range vals {
if !IsValue(n) {
return false
}
}
return true
case ListArg:
return true
}
// It's a subquery
return false
}
// NewPlanValue builds a sqltypes.PlanValue from an Expr.
func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
switch node := node.(type) {
case *SQLVal:
switch node.Type {
case ValArg:
return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
case IntVal:
n, err := sqltypes.NewIntegral(string(node.Val))
if err != nil {
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
}
return sqltypes.PlanValue{Value: n}, nil
case StrVal:
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
case HexVal:
v, err := node.HexDecode()
if err != nil {
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
}
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil
}
case ListArg:
return sqltypes.PlanValue{ListKey: string(node[2:])}, nil
case ValTuple:
pv := sqltypes.PlanValue{
Values: make([]sqltypes.PlanValue, 0, len(node)),
}
for _, val := range node {
innerpv, err := NewPlanValue(val)
if err != nil {
return sqltypes.PlanValue{}, err
}
if innerpv.ListKey != "" || innerpv.Values != nil {
return sqltypes.PlanValue{}, errors.New("unsupported: nested lists")
}
pv.Values = append(pv.Values, innerpv)
}
return pv, nil
case *NullVal:
return sqltypes.PlanValue{}, nil
}
return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node))
}
// StringIn is a convenience function that returns
// true if str matches any of the values.
func StringIn(str string, values ...string) bool {
for _, val := range values {
if str == val {
return true
}
}
return false
}
// SetKey is the extracted key from one SetExpr
type SetKey struct {
Key string
Scope string
}
// ExtractSetValues returns a map of key-value pairs
// if the query is a SET statement. Values can be bool, int64 or string.
// Since set variable names are case insensitive, all keys are returned
// as lower case.
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
stmt, err := Parse(sql)
if err != nil {
return nil, "", err
}
setStmt, ok := stmt.(*Set)
if !ok {
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
}
result := make(map[SetKey]interface{})
for _, expr := range setStmt.Exprs {
scope := SessionStr
key := expr.Name.Lowered()
switch {
case strings.HasPrefix(key, "@@global."):
scope = GlobalStr
key = strings.TrimPrefix(key, "@@global.")
case strings.HasPrefix(key, "@@session."):
key = strings.TrimPrefix(key, "@@session.")
case strings.HasPrefix(key, "@@"):
key = strings.TrimPrefix(key, "@@")
}
if strings.HasPrefix(expr.Name.Lowered(), "@@") {
if setStmt.Scope != "" && scope != "" {
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
}
_, out := NewStringTokenizer(key).Scan()
key = string(out)
}
setKey := SetKey{
Key: key,
Scope: scope,
}
switch expr := expr.Expr.(type) {
case *SQLVal:
switch expr.Type {
case StrVal:
result[setKey] = strings.ToLower(string(expr.Val))
case IntVal:
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
if err != nil {
return nil, "", err
}
result[setKey] = num
default:
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
}
case BoolVal:
var val int64
if expr {
val = 1
}
result[setKey] = val
case *ColName:
result[setKey] = expr.Name.String()
case *NullVal:
result[setKey] = nil
case *Default:
result[setKey] = "default"
default:
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
}
}
return result, strings.ToLower(setStmt.Scope), nil
}

File diff suppressed because it is too large

View File

@ -1,293 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"strconv"
"strings"
"unicode"
)
const (
// DirectiveMultiShardAutocommit is the query comment directive to allow
// single round trip autocommit with a multi-shard statement.
DirectiveMultiShardAutocommit = "MULTI_SHARD_AUTOCOMMIT"
// DirectiveSkipQueryPlanCache skips query plan cache when set.
DirectiveSkipQueryPlanCache = "SKIP_QUERY_PLAN_CACHE"
// DirectiveQueryTimeout sets a query timeout in vtgate. Only supported for SELECTS.
DirectiveQueryTimeout = "QUERY_TIMEOUT_MS"
)
func isNonSpace(r rune) bool {
return !unicode.IsSpace(r)
}
// leadingCommentEnd returns the first index after all leading comments, or
// 0 if there are no leading comments.
func leadingCommentEnd(text string) (end int) {
hasComment := false
pos := 0
for pos < len(text) {
// Eat up any whitespace. Trailing whitespace will be considered part of
// the leading comments.
nextVisibleOffset := strings.IndexFunc(text[pos:], isNonSpace)
if nextVisibleOffset < 0 {
break
}
pos += nextVisibleOffset
remainingText := text[pos:]
// Found visible characters. Look for '/*' at the beginning
// and '*/' somewhere after that.
if len(remainingText) < 4 || remainingText[:2] != "/*" {
break
}
commentLength := 4 + strings.Index(remainingText[2:], "*/")
if commentLength < 4 {
// Missing end comment :/
break
}
hasComment = true
pos += commentLength
}
if hasComment {
return pos
}
return 0
}
// trailingCommentStart returns the first index of trailing comments.
// If there are no trailing comments, returns the length of the input string.
func trailingCommentStart(text string) (start int) {
hasComment := false
reducedLen := len(text)
for reducedLen > 0 {
// Eat up any whitespace. Leading whitespace will be considered part of
// the trailing comments.
nextReducedLen := strings.LastIndexFunc(text[:reducedLen], isNonSpace) + 1
if nextReducedLen == 0 {
break
}
reducedLen = nextReducedLen
if reducedLen < 4 || text[reducedLen-2:reducedLen] != "*/" {
break
}
// Find the beginning of the comment
startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*")
if startCommentPos < 0 {
// Badly formatted sql :/
break
}
hasComment = true
reducedLen = startCommentPos
}
if hasComment {
return reducedLen
}
return len(text)
}
// MarginComments holds the leading and trailing comments that surround a query.
type MarginComments struct {
Leading string
Trailing string
}
// SplitMarginComments pulls out any leading or trailing comments from a raw sql query.
// This function also trims leading (if there's a comment) and trailing whitespace.
func SplitMarginComments(sql string) (query string, comments MarginComments) {
trailingStart := trailingCommentStart(sql)
leadingEnd := leadingCommentEnd(sql[:trailingStart])
comments = MarginComments{
Leading: strings.TrimLeftFunc(sql[:leadingEnd], unicode.IsSpace),
Trailing: strings.TrimRightFunc(sql[trailingStart:], unicode.IsSpace),
}
return strings.TrimFunc(sql[leadingEnd:trailingStart], unicode.IsSpace), comments
}
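
Illustration only (removed package): splitting margin comments off a query. Whitespace adjacent to a margin comment stays attached to that comment, as the doc comments above describe.

```go
package main

import (
	"fmt"

	"github.com/xwb1989/sqlparser"
)

func main() {
	query, comments := sqlparser.SplitMarginComments("/* lead */ SELECT 1 /* trail */")
	// query is the trimmed statement; Leading/Trailing hold the margin comments.
	fmt.Printf("query=%q leading=%q trailing=%q\n", query, comments.Leading, comments.Trailing)
}
```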
// StripLeadingComments trims the SQL string and removes any leading comments
func StripLeadingComments(sql string) string {
sql = strings.TrimFunc(sql, unicode.IsSpace)
for hasCommentPrefix(sql) {
switch sql[0] {
case '/':
// Multi line comment
index := strings.Index(sql, "*/")
if index <= 1 {
return sql
}
// don't strip /*! ... */ or /*!50700 ... */
if len(sql) > 2 && sql[2] == '!' {
return sql
}
sql = sql[index+2:]
case '-':
// Single line comment
index := strings.Index(sql, "\n")
if index == -1 {
return sql
}
sql = sql[index+1:]
}
sql = strings.TrimFunc(sql, unicode.IsSpace)
}
return sql
}
func hasCommentPrefix(sql string) bool {
return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-'))
}
// ExtractMysqlComment extracts the version and SQL from a comment-only query
// such as /*!50708 sql here */
func ExtractMysqlComment(sql string) (version string, innerSQL string) {
sql = sql[3 : len(sql)-2]
digitCount := 0
endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
digitCount++
return !unicode.IsDigit(c) || digitCount == 6
})
version = sql[0:endOfVersionIndex]
innerSQL = strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
return version, innerSQL
}
const commentDirectivePreamble = "/*vt+"
// CommentDirectives is the parsed representation for execution directives
// conveyed in query comments
type CommentDirectives map[string]interface{}
// ExtractCommentDirectives parses the comment list for any execution directives
// of the form:
//
// /*vt+ OPTION_ONE=1 OPTION_TWO OPTION_THREE=abcd */
//
// It returns the map of the directive values or nil if there aren't any.
func ExtractCommentDirectives(comments Comments) CommentDirectives {
if comments == nil {
return nil
}
var vals map[string]interface{}
for _, comment := range comments {
commentStr := string(comment)
if commentStr[0:5] != commentDirectivePreamble {
continue
}
if vals == nil {
vals = make(map[string]interface{})
}
// Split on whitespace and ignore the first and last directive
// since they contain the comment start/end
directives := strings.Fields(commentStr)
for i := 1; i < len(directives)-1; i++ {
directive := directives[i]
sep := strings.IndexByte(directive, '=')
// No value is equivalent to a true boolean
if sep == -1 {
vals[directive] = true
continue
}
strVal := directive[sep+1:]
directive = directive[:sep]
intVal, err := strconv.Atoi(strVal)
if err == nil {
vals[directive] = intVal
continue
}
boolVal, err := strconv.ParseBool(strVal)
if err == nil {
vals[directive] = boolVal
continue
}
vals[directive] = strVal
}
}
return vals
}
// IsSet checks the directive map for the named directive and returns
// true if the directive is set and has a true/false or 0/1 value
func (d CommentDirectives) IsSet(key string) bool {
if d == nil {
return false
}
val, ok := d[key]
if !ok {
return false
}
boolVal, ok := val.(bool)
if ok {
return boolVal
}
intVal, ok := val.(int)
if ok {
return intVal == 1
}
return false
}
// SkipQueryPlanCacheDirective returns true if skip query plan cache directive is set to true in query.
func SkipQueryPlanCacheDirective(stmt Statement) bool {
switch stmt := stmt.(type) {
case *Select:
directives := ExtractCommentDirectives(stmt.Comments)
if directives.IsSet(DirectiveSkipQueryPlanCache) {
return true
}
case *Insert:
directives := ExtractCommentDirectives(stmt.Comments)
if directives.IsSet(DirectiveSkipQueryPlanCache) {
return true
}
case *Update:
directives := ExtractCommentDirectives(stmt.Comments)
if directives.IsSet(DirectiveSkipQueryPlanCache) {
return true
}
case *Delete:
directives := ExtractCommentDirectives(stmt.Comments)
if directives.IsSet(DirectiveSkipQueryPlanCache) {
return true
}
default:
return false
}
return false
}

View File

@ -1,99 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"bytes"
"github.com/xwb1989/sqlparser/dependency/sqltypes"
)
// This file contains types that are 'Encodable'.
// Encodable defines the interface for types that can
// be custom-encoded into SQL.
type Encodable interface {
EncodeSQL(buf *bytes.Buffer)
}
// InsertValues is a custom SQL encoder for the values of
// an insert statement.
type InsertValues [][]sqltypes.Value
// EncodeSQL performs the SQL encoding for InsertValues.
func (iv InsertValues) EncodeSQL(buf *bytes.Buffer) {
for i, rows := range iv {
if i != 0 {
buf.WriteString(", ")
}
buf.WriteByte('(')
for j, bv := range rows {
if j != 0 {
buf.WriteString(", ")
}
bv.EncodeSQL(buf)
}
buf.WriteByte(')')
}
}
// TupleEqualityList is for generating equality constraints
// for tables that have composite primary keys.
type TupleEqualityList struct {
Columns []ColIdent
Rows [][]sqltypes.Value
}
// EncodeSQL generates the where clause constraints for the tuple
// equality.
func (tpl *TupleEqualityList) EncodeSQL(buf *bytes.Buffer) {
if len(tpl.Columns) == 1 {
tpl.encodeAsIn(buf)
return
}
tpl.encodeAsEquality(buf)
}
func (tpl *TupleEqualityList) encodeAsIn(buf *bytes.Buffer) {
Append(buf, tpl.Columns[0])
buf.WriteString(" in (")
for i, r := range tpl.Rows {
if i != 0 {
buf.WriteString(", ")
}
r[0].EncodeSQL(buf)
}
buf.WriteByte(')')
}
func (tpl *TupleEqualityList) encodeAsEquality(buf *bytes.Buffer) {
for i, r := range tpl.Rows {
if i != 0 {
buf.WriteString(" or ")
}
buf.WriteString("(")
for j, c := range tpl.Columns {
if j != 0 {
buf.WriteString(" and ")
}
Append(buf, c)
buf.WriteString(" = ")
r[j].EncodeSQL(buf)
}
buf.WriteByte(')')
}
}

View File

@ -1,39 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
// FormatImpossibleQuery creates an impossible query in a TrackedBuffer.
// An impossible query is a modified version of a query where all selects have where clauses that are
// impossible for mysql to resolve. This is used in the vtgate and vttablet:
//
// - In the vtgate it's used for joins: if the first query returns no result, then vtgate uses the impossible
// query just to fetch field info from vttablet
// - In the vttablet, it's just an optimization: the field info is fetched once from MySQL, cached and reused
// for subsequent queries
func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) {
switch node := node.(type) {
case *Select:
buf.Myprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From)
if node.GroupBy != nil {
node.GroupBy.Format(buf)
}
case *Union:
buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right)
default:
node.Format(buf)
}
}

View File

@ -1,224 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"fmt"
"github.com/xwb1989/sqlparser/dependency/sqltypes"
"github.com/xwb1989/sqlparser/dependency/querypb"
)
// Normalize changes the statement to use bind values, and
// updates the bind vars to those values. The supplied prefix
// is used to generate the bind var names. The function ensures
// that there are no collisions with existing bind vars.
// Within Select constructs, bind vars are deduped. This allows
// us to identify vindex equality. Otherwise, every value is
// treated as distinct.
func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) {
nz := newNormalizer(stmt, bindVars, prefix)
_ = Walk(nz.WalkStatement, stmt)
}
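
Illustration only (removed package): normalizing a SELECT replaces literals with bind variables and, within a SELECT, dedupes identical values into a single bind var; the "bv" prefix here is arbitrary.

```go
package main

import (
	"fmt"

	"github.com/xwb1989/sqlparser"
	"github.com/xwb1989/sqlparser/dependency/querypb"
)

func main() {
	stmt, err := sqlparser.Parse("select * from t where a = 1 and b = 1")
	if err != nil {
		panic(err)
	}

	bindVars := map[string]*querypb.BindVariable{}
	sqlparser.Normalize(stmt, bindVars, "bv")

	// Both literals share one bind var, e.g.:
	//   select * from t where a = :bv1 and b = :bv1
	fmt.Println(sqlparser.String(stmt), len(bindVars))
}
```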
type normalizer struct {
stmt Statement
bindVars map[string]*querypb.BindVariable
prefix string
reserved map[string]struct{}
counter int
vals map[string]string
}
func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer {
return &normalizer{
stmt: stmt,
bindVars: bindVars,
prefix: prefix,
reserved: GetBindvars(stmt),
counter: 1,
vals: make(map[string]string),
}
}
// WalkStatement is the top level walk function.
// If it encounters a Select, it switches to a mode
// where variables are deduped.
func (nz *normalizer) WalkStatement(node SQLNode) (bool, error) {
switch node := node.(type) {
case *Select:
_ = Walk(nz.WalkSelect, node)
// Don't continue
return false, nil
case *SQLVal:
nz.convertSQLVal(node)
case *ComparisonExpr:
nz.convertComparison(node)
}
return true, nil
}
// WalkSelect normalizes the AST in Select mode.
func (nz *normalizer) WalkSelect(node SQLNode) (bool, error) {
switch node := node.(type) {
case *SQLVal:
nz.convertSQLValDedup(node)
case *ComparisonExpr:
nz.convertComparison(node)
}
return true, nil
}
func (nz *normalizer) convertSQLValDedup(node *SQLVal) {
// If value is too long, don't dedup.
// Such values are most likely not for vindexes.
// We save a lot of CPU because we avoid building
// the key for them.
if len(node.Val) > 256 {
nz.convertSQLVal(node)
return
}
// Make the bindvar
bval := nz.sqlToBindvar(node)
if bval == nil {
return
}
// Check if there's a bindvar for that value already.
var key string
if bval.Type == sqltypes.VarBinary {
// Prefixing strings with "'" ensures that a string
// and number that have the same representation don't
// collide.
key = "'" + string(node.Val)
} else {
key = string(node.Val)
}
bvname, ok := nz.vals[key]
if !ok {
// If there's no such bindvar, make a new one.
bvname = nz.newName()
nz.vals[key] = bvname
nz.bindVars[bvname] = bval
}
// Modify the AST node to a bindvar.
node.Type = ValArg
node.Val = append([]byte(":"), bvname...)
}
// convertSQLVal converts an SQLVal without the dedup.
func (nz *normalizer) convertSQLVal(node *SQLVal) {
bval := nz.sqlToBindvar(node)
if bval == nil {
return
}
bvname := nz.newName()
nz.bindVars[bvname] = bval
node.Type = ValArg
node.Val = append([]byte(":"), bvname...)
}
// convertComparison attempts to convert IN clauses to
// use the list bind var construct. If it fails, it returns
// with no change made. The walk function will then continue
// and iterate on converting each individual value into separate
// bind vars.
func (nz *normalizer) convertComparison(node *ComparisonExpr) {
if node.Operator != InStr && node.Operator != NotInStr {
return
}
tupleVals, ok := node.Right.(ValTuple)
if !ok {
return
}
// The RHS is a tuple of values.
// Make a list bindvar.
bvals := &querypb.BindVariable{
Type: querypb.Type_TUPLE,
}
for _, val := range tupleVals {
bval := nz.sqlToBindvar(val)
if bval == nil {
return
}
bvals.Values = append(bvals.Values, &querypb.Value{
Type: bval.Type,
Value: bval.Value,
})
}
bvname := nz.newName()
nz.bindVars[bvname] = bvals
// Modify RHS to be a list bindvar.
node.Right = ListArg(append([]byte("::"), bvname...))
}
func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable {
if node, ok := node.(*SQLVal); ok {
var v sqltypes.Value
var err error
switch node.Type {
case StrVal:
v, err = sqltypes.NewValue(sqltypes.VarBinary, node.Val)
case IntVal:
v, err = sqltypes.NewValue(sqltypes.Int64, node.Val)
case FloatVal:
v, err = sqltypes.NewValue(sqltypes.Float64, node.Val)
default:
return nil
}
if err != nil {
return nil
}
return sqltypes.ValueBindVariable(v)
}
return nil
}
func (nz *normalizer) newName() string {
for {
newName := fmt.Sprintf("%s%d", nz.prefix, nz.counter)
if _, ok := nz.reserved[newName]; !ok {
nz.reserved[newName] = struct{}{}
return newName
}
nz.counter++
}
}
// GetBindvars returns a map of the bind vars referenced in the statement.
// TODO(sougou): This function gets called again from vtgate/planbuilder.
// Ideally, this should be done only once.
func GetBindvars(stmt Statement) map[string]struct{} {
bindvars := make(map[string]struct{})
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *SQLVal:
if node.Type == ValArg {
bindvars[string(node.Val[1:])] = struct{}{}
}
case ListArg:
bindvars[string(node[2:])] = struct{}{}
}
return true, nil
}, stmt)
return bindvars
}

View File

@ -1,119 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"bytes"
"fmt"
"github.com/xwb1989/sqlparser/dependency/querypb"
"github.com/xwb1989/sqlparser/dependency/sqltypes"
)
// ParsedQuery represents a parsed query where
// bind locations are precomputed for fast substitutions.
type ParsedQuery struct {
Query string
bindLocations []bindLocation
}
type bindLocation struct {
offset, length int
}
// NewParsedQuery returns a ParsedQuery of the ast.
func NewParsedQuery(node SQLNode) *ParsedQuery {
buf := NewTrackedBuffer(nil)
buf.Myprintf("%v", node)
return buf.ParsedQuery()
}
// GenerateQuery generates a query by substituting the specified
// bindVariables. The extras parameter specifies special parameters
// that can perform custom encoding.
func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) ([]byte, error) {
if len(pq.bindLocations) == 0 {
return []byte(pq.Query), nil
}
buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query)))
current := 0
for _, loc := range pq.bindLocations {
buf.WriteString(pq.Query[current:loc.offset])
name := pq.Query[loc.offset : loc.offset+loc.length]
if encodable, ok := extras[name[1:]]; ok {
encodable.EncodeSQL(buf)
} else {
supplied, _, err := FetchBindVar(name, bindVariables)
if err != nil {
return nil, err
}
EncodeValue(buf, supplied)
}
current = loc.offset + loc.length
}
buf.WriteString(pq.Query[current:])
return buf.Bytes(), nil
}
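
Illustration only (removed package): a ParsedQuery built from an AST containing a :id placeholder can be re-generated against concrete bind values; the output shown in the comment is indicative, not guaranteed.

```go
package main

import (
	"fmt"

	"github.com/xwb1989/sqlparser"
	"github.com/xwb1989/sqlparser/dependency/querypb"
	"github.com/xwb1989/sqlparser/dependency/sqltypes"
)

func main() {
	stmt, err := sqlparser.Parse("select * from t where id = :id")
	if err != nil {
		panic(err)
	}
	pq := sqlparser.NewParsedQuery(stmt)

	v, err := sqltypes.NewValue(sqltypes.Int64, []byte("42"))
	if err != nil {
		panic(err)
	}
	bindVars := map[string]*querypb.BindVariable{"id": sqltypes.ValueBindVariable(v)}

	query, err := pq.GenerateQuery(bindVars, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(query)) // e.g. select * from t where id = 42
}
```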
// EncodeValue encodes one bind variable value into the query.
func EncodeValue(buf *bytes.Buffer, value *querypb.BindVariable) {
if value.Type != querypb.Type_TUPLE {
// Since we already check for TUPLE, we don't expect an error.
v, _ := sqltypes.BindVariableToValue(value)
v.EncodeSQL(buf)
return
}
// It's a TUPLE.
buf.WriteByte('(')
for i, bv := range value.Values {
if i != 0 {
buf.WriteString(", ")
}
sqltypes.ProtoToValue(bv).EncodeSQL(buf)
}
buf.WriteByte(')')
}
// FetchBindVar resolves the bind variable by fetching it from bindVariables.
func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) {
name = name[1:]
if name[0] == ':' {
name = name[1:]
isList = true
}
supplied, ok := bindVariables[name]
if !ok {
return nil, false, fmt.Errorf("missing bind var %s", name)
}
if isList {
if supplied.Type != querypb.Type_TUPLE {
return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name)
}
if len(supplied.Values) == 0 {
return nil, false, fmt.Errorf("empty list supplied for %s", name)
}
return supplied, true, nil
}
if supplied.Type == querypb.Type_TUPLE {
return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name)
}
return supplied, false, nil
}

View File

@ -1,19 +0,0 @@
package sqlparser
import querypb "github.com/xwb1989/sqlparser/dependency/querypb"
// RedactSQLQuery returns a sql string with the params stripped out for display
func RedactSQLQuery(sql string) (string, error) {
bv := map[string]*querypb.BindVariable{}
sqlStripped, comments := SplitMarginComments(sql)
stmt, err := Parse(sqlStripped)
if err != nil {
return "", err
}
prefix := "redacted"
Normalize(stmt, bv, prefix)
return comments.Leading + String(stmt) + comments.Trailing, nil
}

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,950 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/xwb1989/sqlparser/dependency/bytes2"
"github.com/xwb1989/sqlparser/dependency/sqltypes"
)
const (
defaultBufSize = 4096
eofChar = 0x100
)
// Tokenizer is the struct used to generate SQL
// tokens for the parser.
type Tokenizer struct {
InStream io.Reader
AllowComments bool
ForceEOF bool
lastChar uint16
Position int
lastToken []byte
LastError error
posVarIndex int
ParseTree Statement
partialDDL *DDL
nesting int
multi bool
specialComment *Tokenizer
buf []byte
bufPos int
bufSize int
}
// NewStringTokenizer creates a new Tokenizer for the
// sql string.
func NewStringTokenizer(sql string) *Tokenizer {
buf := []byte(sql)
return &Tokenizer{
buf: buf,
bufSize: len(buf),
}
}
// NewTokenizer creates a new Tokenizer reading a sql
// string from the io.Reader.
func NewTokenizer(r io.Reader) *Tokenizer {
return &Tokenizer{
InStream: r,
buf: make([]byte, defaultBufSize),
}
}
// keywords is a map of mysql keywords that fall into two categories:
// 1) keywords considered reserved by MySQL
// 2) keywords for us to handle specially in sql.y
//
// Those marked as UNUSED are likely reserved keywords. We add them here so that
// when rewriting queries we can properly backtick quote them so they don't cause issues
//
// NOTE: If you add new keywords, add them also to the reserved_keywords or
// non_reserved_keywords grammar in sql.y -- this will allow the keyword to be used
// in identifiers. See the docs for each grammar to determine which one to put it into.
var keywords = map[string]int{
"accessible": UNUSED,
"add": ADD,
"against": AGAINST,
"all": ALL,
"alter": ALTER,
"analyze": ANALYZE,
"and": AND,
"as": AS,
"asc": ASC,
"asensitive": UNUSED,
"auto_increment": AUTO_INCREMENT,
"before": UNUSED,
"begin": BEGIN,
"between": BETWEEN,
"bigint": BIGINT,
"binary": BINARY,
"_binary": UNDERSCORE_BINARY,
"bit": BIT,
"blob": BLOB,
"bool": BOOL,
"boolean": BOOLEAN,
"both": UNUSED,
"by": BY,
"call": UNUSED,
"cascade": UNUSED,
"case": CASE,
"cast": CAST,
"change": UNUSED,
"char": CHAR,
"character": CHARACTER,
"charset": CHARSET,
"check": UNUSED,
"collate": COLLATE,
"column": COLUMN,
"comment": COMMENT_KEYWORD,
"committed": COMMITTED,
"commit": COMMIT,
"condition": UNUSED,
"constraint": CONSTRAINT,
"continue": UNUSED,
"convert": CONVERT,
"substr": SUBSTR,
"substring": SUBSTRING,
"create": CREATE,
"cross": CROSS,
"current_date": CURRENT_DATE,
"current_time": CURRENT_TIME,
"current_timestamp": CURRENT_TIMESTAMP,
"current_user": UNUSED,
"cursor": UNUSED,
"database": DATABASE,
"databases": DATABASES,
"day_hour": UNUSED,
"day_microsecond": UNUSED,
"day_minute": UNUSED,
"day_second": UNUSED,
"date": DATE,
"datetime": DATETIME,
"dec": UNUSED,
"decimal": DECIMAL,
"declare": UNUSED,
"default": DEFAULT,
"delayed": UNUSED,
"delete": DELETE,
"desc": DESC,
"describe": DESCRIBE,
"deterministic": UNUSED,
"distinct": DISTINCT,
"distinctrow": UNUSED,
"div": DIV,
"double": DOUBLE,
"drop": DROP,
"duplicate": DUPLICATE,
"each": UNUSED,
"else": ELSE,
"elseif": UNUSED,
"enclosed": UNUSED,
"end": END,
"enum": ENUM,
"escape": ESCAPE,
"escaped": UNUSED,
"exists": EXISTS,
"exit": UNUSED,
"explain": EXPLAIN,
"expansion": EXPANSION,
"extended": EXTENDED,
"false": FALSE,
"fetch": UNUSED,
"float": FLOAT_TYPE,
"float4": UNUSED,
"float8": UNUSED,
"for": FOR,
"force": FORCE,
"foreign": FOREIGN,
"from": FROM,
"full": FULL,
"fulltext": FULLTEXT,
"generated": UNUSED,
"geometry": GEOMETRY,
"geometrycollection": GEOMETRYCOLLECTION,
"get": UNUSED,
"global": GLOBAL,
"grant": UNUSED,
"group": GROUP,
"group_concat": GROUP_CONCAT,
"having": HAVING,
"high_priority": UNUSED,
"hour_microsecond": UNUSED,
"hour_minute": UNUSED,
"hour_second": UNUSED,
"if": IF,
"ignore": IGNORE,
"in": IN,
"index": INDEX,
"infile": UNUSED,
"inout": UNUSED,
"inner": INNER,
"insensitive": UNUSED,
"insert": INSERT,
"int": INT,
"int1": UNUSED,
"int2": UNUSED,
"int3": UNUSED,
"int4": UNUSED,
"int8": UNUSED,
"integer": INTEGER,
"interval": INTERVAL,
"into": INTO,
"io_after_gtids": UNUSED,
"is": IS,
"isolation": ISOLATION,
"iterate": UNUSED,
"join": JOIN,
"json": JSON,
"key": KEY,
"keys": KEYS,
"key_block_size": KEY_BLOCK_SIZE,
"kill": UNUSED,
"language": LANGUAGE,
"last_insert_id": LAST_INSERT_ID,
"leading": UNUSED,
"leave": UNUSED,
"left": LEFT,
"less": LESS,
"level": LEVEL,
"like": LIKE,
"limit": LIMIT,
"linear": UNUSED,
"lines": UNUSED,
"linestring": LINESTRING,
"load": UNUSED,
"localtime": LOCALTIME,
"localtimestamp": LOCALTIMESTAMP,
"lock": LOCK,
"long": UNUSED,
"longblob": LONGBLOB,
"longtext": LONGTEXT,
"loop": UNUSED,
"low_priority": UNUSED,
"master_bind": UNUSED,
"match": MATCH,
"maxvalue": MAXVALUE,
"mediumblob": MEDIUMBLOB,
"mediumint": MEDIUMINT,
"mediumtext": MEDIUMTEXT,
"middleint": UNUSED,
"minute_microsecond": UNUSED,
"minute_second": UNUSED,
"mod": MOD,
"mode": MODE,
"modifies": UNUSED,
"multilinestring": MULTILINESTRING,
"multipoint": MULTIPOINT,
"multipolygon": MULTIPOLYGON,
"names": NAMES,
"natural": NATURAL,
"nchar": NCHAR,
"next": NEXT,
"not": NOT,
"no_write_to_binlog": UNUSED,
"null": NULL,
"numeric": NUMERIC,
"offset": OFFSET,
"on": ON,
"only": ONLY,
"optimize": OPTIMIZE,
"optimizer_costs": UNUSED,
"option": UNUSED,
"optionally": UNUSED,
"or": OR,
"order": ORDER,
"out": UNUSED,
"outer": OUTER,
"outfile": UNUSED,
"partition": PARTITION,
"point": POINT,
"polygon": POLYGON,
"precision": UNUSED,
"primary": PRIMARY,
"processlist": PROCESSLIST,
"procedure": PROCEDURE,
"query": QUERY,
"range": UNUSED,
"read": READ,
"reads": UNUSED,
"read_write": UNUSED,
"real": REAL,
"references": UNUSED,
"regexp": REGEXP,
"release": UNUSED,
"rename": RENAME,
"reorganize": REORGANIZE,
"repair": REPAIR,
"repeat": UNUSED,
"repeatable": REPEATABLE,
"replace": REPLACE,
"require": UNUSED,
"resignal": UNUSED,
"restrict": UNUSED,
"return": UNUSED,
"revoke": UNUSED,
"right": RIGHT,
"rlike": REGEXP,
"rollback": ROLLBACK,
"schema": SCHEMA,
"schemas": UNUSED,
"second_microsecond": UNUSED,
"select": SELECT,
"sensitive": UNUSED,
"separator": SEPARATOR,
"serializable": SERIALIZABLE,
"session": SESSION,
"set": SET,
"share": SHARE,
"show": SHOW,
"signal": UNUSED,
"signed": SIGNED,
"smallint": SMALLINT,
"spatial": SPATIAL,
"specific": UNUSED,
"sql": UNUSED,
"sqlexception": UNUSED,
"sqlstate": UNUSED,
"sqlwarning": UNUSED,
"sql_big_result": UNUSED,
"sql_cache": SQL_CACHE,
"sql_calc_found_rows": UNUSED,
"sql_no_cache": SQL_NO_CACHE,
"sql_small_result": UNUSED,
"ssl": UNUSED,
"start": START,
"starting": UNUSED,
"status": STATUS,
"stored": UNUSED,
"straight_join": STRAIGHT_JOIN,
"stream": STREAM,
"table": TABLE,
"tables": TABLES,
"terminated": UNUSED,
"text": TEXT,
"than": THAN,
"then": THEN,
"time": TIME,
"timestamp": TIMESTAMP,
"tinyblob": TINYBLOB,
"tinyint": TINYINT,
"tinytext": TINYTEXT,
"to": TO,
"trailing": UNUSED,
"transaction": TRANSACTION,
"trigger": TRIGGER,
"true": TRUE,
"truncate": TRUNCATE,
"uncommitted": UNCOMMITTED,
"undo": UNUSED,
"union": UNION,
"unique": UNIQUE,
"unlock": UNUSED,
"unsigned": UNSIGNED,
"update": UPDATE,
"usage": UNUSED,
"use": USE,
"using": USING,
"utc_date": UTC_DATE,
"utc_time": UTC_TIME,
"utc_timestamp": UTC_TIMESTAMP,
"values": VALUES,
"variables": VARIABLES,
"varbinary": VARBINARY,
"varchar": VARCHAR,
"varcharacter": UNUSED,
"varying": UNUSED,
"virtual": UNUSED,
"vindex": VINDEX,
"vindexes": VINDEXES,
"view": VIEW,
"vitess_keyspaces": VITESS_KEYSPACES,
"vitess_shards": VITESS_SHARDS,
"vitess_tablets": VITESS_TABLETS,
"vschema_tables": VSCHEMA_TABLES,
"when": WHEN,
"where": WHERE,
"while": UNUSED,
"with": WITH,
"write": WRITE,
"xor": UNUSED,
"year": YEAR,
"year_month": UNUSED,
"zerofill": ZEROFILL,
}
// keywordStrings contains the reverse mapping of token to keyword strings
var keywordStrings = map[int]string{}
func init() {
for str, id := range keywords {
if id == UNUSED {
continue
}
keywordStrings[id] = str
}
}
// KeywordString returns the string corresponding to the given keyword
func KeywordString(id int) string {
str, ok := keywordStrings[id]
if !ok {
return ""
}
return str
}
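// Illustrative sketch, not part of the original file: a round trip through
// the keyword tables above. Any reserved word maps to its token id via
// `keywords`, and KeywordString recovers the canonical lowercase spelling;
// UNUSED entries are skipped when building keywordStrings, so they map back
// to the empty string.
func exampleKeywordLookup() {
	fmt.Println(KeywordString(keywords["select"]))     // "select"
	fmt.Println(KeywordString(keywords["asensitive"])) // "" (UNUSED keyword)
}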
// Lex returns the next token from the Tokenizer.
// This function is used by go yacc.
func (tkn *Tokenizer) Lex(lval *yySymType) int {
typ, val := tkn.Scan()
for typ == COMMENT {
if tkn.AllowComments {
break
}
typ, val = tkn.Scan()
}
lval.bytes = val
tkn.lastToken = val
return typ
}
// Error is called by go yacc if there's a parsing error.
func (tkn *Tokenizer) Error(err string) {
buf := &bytes2.Buffer{}
if tkn.lastToken != nil {
fmt.Fprintf(buf, "%s at position %v near '%s'", err, tkn.Position, tkn.lastToken)
} else {
fmt.Fprintf(buf, "%s at position %v", err, tkn.Position)
}
tkn.LastError = errors.New(buf.String())
// Try to re-sync to the next statement.
if tkn.lastChar != ';' {
tkn.skipStatement()
}
}
// Scan scans the tokenizer for the next token and returns
// the token type and an optional value.
func (tkn *Tokenizer) Scan() (int, []byte) {
if tkn.specialComment != nil {
// Enter specialComment scan mode, used to scan
// MySQL-specific comments of the form: /*! MySQL-specific code */
specialComment := tkn.specialComment
tok, val := specialComment.Scan()
if tok != 0 {
// Return the specialComment scan result directly.
return tok, val
}
// Leave specialComment scan mode once its stream is fully consumed.
tkn.specialComment = nil
}
if tkn.lastChar == 0 {
tkn.next()
}
if tkn.ForceEOF {
tkn.skipStatement()
return 0, nil
}
tkn.skipBlank()
switch ch := tkn.lastChar; {
case isLetter(ch):
tkn.next()
if ch == 'X' || ch == 'x' {
if tkn.lastChar == '\'' {
tkn.next()
return tkn.scanHex()
}
}
if ch == 'B' || ch == 'b' {
if tkn.lastChar == '\'' {
tkn.next()
return tkn.scanBitLiteral()
}
}
isDbSystemVariable := false
if ch == '@' && tkn.lastChar == '@' {
isDbSystemVariable = true
}
return tkn.scanIdentifier(byte(ch), isDbSystemVariable)
case isDigit(ch):
return tkn.scanNumber(false)
case ch == ':':
return tkn.scanBindVar()
case ch == ';' && tkn.multi:
return 0, nil
default:
tkn.next()
switch ch {
case eofChar:
return 0, nil
case '=', ',', ';', '(', ')', '+', '*', '%', '^', '~':
return int(ch), nil
case '&':
if tkn.lastChar == '&' {
tkn.next()
return AND, nil
}
return int(ch), nil
case '|':
if tkn.lastChar == '|' {
tkn.next()
return OR, nil
}
return int(ch), nil
case '?':
tkn.posVarIndex++
buf := new(bytes2.Buffer)
fmt.Fprintf(buf, ":v%d", tkn.posVarIndex)
return VALUE_ARG, buf.Bytes()
case '.':
if isDigit(tkn.lastChar) {
return tkn.scanNumber(true)
}
return int(ch), nil
case '/':
switch tkn.lastChar {
case '/':
tkn.next()
return tkn.scanCommentType1("//")
case '*':
tkn.next()
switch tkn.lastChar {
case '!':
return tkn.scanMySQLSpecificComment()
default:
return tkn.scanCommentType2()
}
default:
return int(ch), nil
}
case '#':
return tkn.scanCommentType1("#")
case '-':
switch tkn.lastChar {
case '-':
tkn.next()
return tkn.scanCommentType1("--")
case '>':
tkn.next()
if tkn.lastChar == '>' {
tkn.next()
return JSON_UNQUOTE_EXTRACT_OP, nil
}
return JSON_EXTRACT_OP, nil
}
return int(ch), nil
case '<':
switch tkn.lastChar {
case '>':
tkn.next()
return NE, nil
case '<':
tkn.next()
return SHIFT_LEFT, nil
case '=':
tkn.next()
switch tkn.lastChar {
case '>':
tkn.next()
return NULL_SAFE_EQUAL, nil
default:
return LE, nil
}
default:
return int(ch), nil
}
case '>':
switch tkn.lastChar {
case '=':
tkn.next()
return GE, nil
case '>':
tkn.next()
return SHIFT_RIGHT, nil
default:
return int(ch), nil
}
case '!':
if tkn.lastChar == '=' {
tkn.next()
return NE, nil
}
return int(ch), nil
case '\'', '"':
return tkn.scanString(ch, STRING)
case '`':
return tkn.scanLiteralIdentifier()
default:
return LEX_ERROR, []byte{byte(ch)}
}
}
}
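// Illustrative sketch, not part of the original file: driving Scan directly.
// It assumes NewStringTokenizer(sql string) wraps a query string in a
// Tokenizer, as its use in scanMySQLSpecificComment below suggests. Scan
// yields (token type, value) pairs and returns 0 at end of input.
func exampleScanLoop() {
	tkn := NewStringTokenizer("select a, b from t where a = 10")
	for {
		typ, val := tkn.Scan()
		if typ == 0 { // end of input
			break
		}
		fmt.Printf("%d %q\n", typ, val)
	}
}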
// skipStatement scans until EOF or the end of the current statement is encountered.
func (tkn *Tokenizer) skipStatement() {
ch := tkn.lastChar
for ch != ';' && ch != eofChar {
tkn.next()
ch = tkn.lastChar
}
}
func (tkn *Tokenizer) skipBlank() {
ch := tkn.lastChar
for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' {
tkn.next()
ch = tkn.lastChar
}
}
func (tkn *Tokenizer) scanIdentifier(firstByte byte, isDbSystemVariable bool) (int, []byte) {
buffer := &bytes2.Buffer{}
buffer.WriteByte(firstByte)
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || (isDbSystemVariable && isCarat(tkn.lastChar)) {
buffer.WriteByte(byte(tkn.lastChar))
tkn.next()
}
lowered := bytes.ToLower(buffer.Bytes())
loweredStr := string(lowered)
if keywordID, found := keywords[loweredStr]; found {
return keywordID, lowered
}
// dual must always be case-insensitive
if loweredStr == "dual" {
return ID, lowered
}
return ID, buffer.Bytes()
}
func (tkn *Tokenizer) scanHex() (int, []byte) {
buffer := &bytes2.Buffer{}
tkn.scanMantissa(16, buffer)
if tkn.lastChar != '\'' {
return LEX_ERROR, buffer.Bytes()
}
tkn.next()
if buffer.Len()%2 != 0 {
return LEX_ERROR, buffer.Bytes()
}
return HEX, buffer.Bytes()
}
func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
buffer := &bytes2.Buffer{}
tkn.scanMantissa(2, buffer)
if tkn.lastChar != '\'' {
return LEX_ERROR, buffer.Bytes()
}
tkn.next()
return BIT_LITERAL, buffer.Bytes()
}
func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
buffer := &bytes2.Buffer{}
backTickSeen := false
for {
if backTickSeen {
if tkn.lastChar != '`' {
break
}
backTickSeen = false
buffer.WriteByte('`')
tkn.next()
continue
}
// The previous char was not a backtick.
switch tkn.lastChar {
case '`':
backTickSeen = true
case eofChar:
// Premature EOF.
return LEX_ERROR, buffer.Bytes()
default:
buffer.WriteByte(byte(tkn.lastChar))
}
tkn.next()
}
if buffer.Len() == 0 {
return LEX_ERROR, buffer.Bytes()
}
return ID, buffer.Bytes()
}
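// Illustrative sketch, not part of the original file: what
// scanLiteralIdentifier produces. A doubled backtick inside a quoted
// identifier is an escaped literal backtick, while an empty or unterminated
// identifier is a LEX_ERROR.
func exampleLiteralIdentifier() {
	tkn := NewStringTokenizer("`a``b`")
	typ, val := tkn.Scan()
	fmt.Println(typ == ID, string(val)) // true "a`b"
}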
func (tkn *Tokenizer) scanBindVar() (int, []byte) {
buffer := &bytes2.Buffer{}
buffer.WriteByte(byte(tkn.lastChar))
token := VALUE_ARG
tkn.next()
if tkn.lastChar == ':' {
token = LIST_ARG
buffer.WriteByte(byte(tkn.lastChar))
tkn.next()
}
if !isLetter(tkn.lastChar) {
return LEX_ERROR, buffer.Bytes()
}
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || tkn.lastChar == '.' {
buffer.WriteByte(byte(tkn.lastChar))
tkn.next()
}
return token, buffer.Bytes()
}
func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes2.Buffer) {
for digitVal(tkn.lastChar) < base {
tkn.consumeNext(buffer)
}
}
func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) {
token := INTEGRAL
buffer := &bytes2.Buffer{}
if seenDecimalPoint {
token = FLOAT
buffer.WriteByte('.')
tkn.scanMantissa(10, buffer)
goto exponent
}
// 0x construct.
if tkn.lastChar == '0' {
tkn.consumeNext(buffer)
if tkn.lastChar == 'x' || tkn.lastChar == 'X' {
token = HEXNUM
tkn.consumeNext(buffer)
tkn.scanMantissa(16, buffer)
goto exit
}
}
tkn.scanMantissa(10, buffer)
if tkn.lastChar == '.' {
token = FLOAT
tkn.consumeNext(buffer)
tkn.scanMantissa(10, buffer)
}
exponent:
if tkn.lastChar == 'e' || tkn.lastChar == 'E' {
token = FLOAT
tkn.consumeNext(buffer)
if tkn.lastChar == '+' || tkn.lastChar == '-' {
tkn.consumeNext(buffer)
}
tkn.scanMantissa(10, buffer)
}
exit:
// A letter cannot immediately follow a number.
if isLetter(tkn.lastChar) {
return LEX_ERROR, buffer.Bytes()
}
return token, buffer.Bytes()
}
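// Illustrative sketch, not part of the original file: token types produced by
// scanNumber for a few literal forms, derived by reading the branches above
// rather than from verified output:
//
//	"42"    -> INTEGRAL
//	"42.5"  -> FLOAT      (decimal point)
//	".5e3"  -> FLOAT      (leading '.' enters via scanNumber(true))
//	"0xff"  -> HEXNUM     (0x prefix)
//	"1e-9"  -> FLOAT      (exponent)
//	"12ab"  -> LEX_ERROR  (a letter cannot immediately follow a number)
func exampleNumberTokens() {
	tkn := NewStringTokenizer("42 42.5 .5e3 0xff 1e-9")
	for typ, val := tkn.Scan(); typ != 0; typ, val = tkn.Scan() {
		fmt.Println(typ, string(val))
	}
}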
func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
var buffer bytes2.Buffer
for {
ch := tkn.lastChar
if ch == eofChar {
// Unterminated string.
return LEX_ERROR, buffer.Bytes()
}
if ch != delim && ch != '\\' {
buffer.WriteByte(byte(ch))
// Scan ahead to the next interesting character.
start := tkn.bufPos
for ; tkn.bufPos < tkn.bufSize; tkn.bufPos++ {
ch = uint16(tkn.buf[tkn.bufPos])
if ch == delim || ch == '\\' {
break
}
}
buffer.Write(tkn.buf[start:tkn.bufPos])
tkn.Position += (tkn.bufPos - start)
if tkn.bufPos >= tkn.bufSize {
// Reached the end of the buffer without finding a delim or
// escape character.
tkn.next()
continue
}
tkn.bufPos++
tkn.Position++
}
tkn.next() // Read one past the delim or escape character.
if ch == '\\' {
if tkn.lastChar == eofChar {
// The string ends in the middle of an escape sequence.
return LEX_ERROR, buffer.Bytes()
}
if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DontEscape {
ch = tkn.lastChar
} else {
ch = uint16(decodedChar)
}
} else if ch == delim && tkn.lastChar != delim {
// The string is correctly terminated (not a doubled delimiter).
break
}
buffer.WriteByte(byte(ch))
tkn.next()
}
return typ, buffer.Bytes()
}
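// Illustrative sketch, not part of the original file: scanString decodes
// backslash escapes through sqltypes.SQLDecodeMap and treats a doubled
// delimiter as a literal delimiter character, so both quoted forms below
// should yield the same STRING value.
func exampleStringTokens() {
	for _, q := range []string{`'it''s'`, `'it\'s'`} {
		tkn := NewStringTokenizer(q)
		typ, val := tkn.Scan()
		fmt.Println(typ == STRING, string(val)) // true "it's" in both cases
	}
}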
func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) {
buffer := &bytes2.Buffer{}
buffer.WriteString(prefix)
for tkn.lastChar != eofChar {
if tkn.lastChar == '\n' {
tkn.consumeNext(buffer)
break
}
tkn.consumeNext(buffer)
}
return COMMENT, buffer.Bytes()
}
func (tkn *Tokenizer) scanCommentType2() (int, []byte) {
buffer := &bytes2.Buffer{}
buffer.WriteString("/*")
for {
if tkn.lastChar == '*' {
tkn.consumeNext(buffer)
if tkn.lastChar == '/' {
tkn.consumeNext(buffer)
break
}
continue
}
if tkn.lastChar == eofChar {
return LEX_ERROR, buffer.Bytes()
}
tkn.consumeNext(buffer)
}
return COMMENT, buffer.Bytes()
}
func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
buffer := &bytes2.Buffer{}
buffer.WriteString("/*!")
tkn.next()
for {
if tkn.lastChar == '*' {
tkn.consumeNext(buffer)
if tkn.lastChar == '/' {
tkn.consumeNext(buffer)
break
}
continue
}
if tkn.lastChar == eofChar {
return LEX_ERROR, buffer.Bytes()
}
tkn.consumeNext(buffer)
}
_, sql := ExtractMysqlComment(buffer.String())
tkn.specialComment = NewStringTokenizer(sql)
return tkn.Scan()
}
func (tkn *Tokenizer) consumeNext(buffer *bytes2.Buffer) {
if tkn.lastChar == eofChar {
// This should never happen.
panic("unexpected EOF")
}
buffer.WriteByte(byte(tkn.lastChar))
tkn.next()
}
func (tkn *Tokenizer) next() {
if tkn.bufPos >= tkn.bufSize && tkn.InStream != nil {
// Try to refill the buffer.
var err error
tkn.bufPos = 0
if tkn.bufSize, err = tkn.InStream.Read(tkn.buf); err != io.EOF && err != nil {
tkn.LastError = err
}
}
if tkn.bufPos >= tkn.bufSize {
if tkn.lastChar != eofChar {
tkn.Position++
tkn.lastChar = eofChar
}
} else {
tkn.Position++
tkn.lastChar = uint16(tkn.buf[tkn.bufPos])
tkn.bufPos++
}
}
// reset clears any internal state.
func (tkn *Tokenizer) reset() {
tkn.ParseTree = nil
tkn.partialDDL = nil
tkn.specialComment = nil
tkn.posVarIndex = 0
tkn.nesting = 0
tkn.ForceEOF = false
}
func isLetter(ch uint16) bool {
return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '@'
}
func isCarat(ch uint16) bool {
return ch == '.' || ch == '\'' || ch == '"' || ch == '`'
}
func digitVal(ch uint16) int {
switch {
case '0' <= ch && ch <= '9':
return int(ch) - '0'
case 'a' <= ch && ch <= 'f':
return int(ch) - 'a' + 10
case 'A' <= ch && ch <= 'F':
return int(ch) - 'A' + 10
}
return 16 // larger than any legal digit val
}
func isDigit(ch uint16) bool {
return '0' <= ch && ch <= '9'
}
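// Illustrative sketch, not part of the original file: digitVal returns 16 for
// anything that is not a hex digit, so scanMantissa's `digitVal(ch) < base`
// test stops at the first non-digit character for any base up to 16.
func exampleDigitVal() {
	fmt.Println(digitVal('7'), digitVal('a'), digitVal('F'), digitVal('g')) // 7 10 15 16
}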

View File

@ -1,140 +0,0 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"bytes"
"fmt"
)
// NodeFormatter defines the signature of a custom node formatter
// function that can be given to TrackedBuffer for code generation.
type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
// TrackedBuffer is used to rebuild a query from the AST.
// bindLocations keeps track of locations in the buffer that
// use bind variables for efficient future substitutions.
// nodeFormatter is the formatting function the buffer will
// use to format a node. By default (nil), the node's own
// Format method is used, but you can supply a different
// formatting function to generate a query that differs from
// the default.
type TrackedBuffer struct {
*bytes.Buffer
bindLocations []bindLocation
nodeFormatter NodeFormatter
}
// NewTrackedBuffer creates a new TrackedBuffer.
func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
return &TrackedBuffer{
Buffer: new(bytes.Buffer),
nodeFormatter: nodeFormatter,
}
}
// WriteNode initiates the writing of a single SQLNode tree by passing
// it through to Myprintf with a default format string.
func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
buf.Myprintf("%v", node)
return buf
}
// Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
// Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
// which case it adds tracking info for future substitutions.
//
// The name must be something other than the usual Printf() to avoid "go vet"
// warnings due to our custom format specifiers.
func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
end := len(format)
fieldnum := 0
for i := 0; i < end; {
lasti := i
for i < end && format[i] != '%' {
i++
}
if i > lasti {
buf.WriteString(format[lasti:i])
}
if i >= end {
break
}
i++ // '%'
switch format[i] {
case 'c':
switch v := values[fieldnum].(type) {
case byte:
buf.WriteByte(v)
case rune:
buf.WriteRune(v)
default:
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
}
case 's':
switch v := values[fieldnum].(type) {
case []byte:
buf.Write(v)
case string:
buf.WriteString(v)
default:
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
}
case 'v':
node := values[fieldnum].(SQLNode)
if buf.nodeFormatter == nil {
node.Format(buf)
} else {
buf.nodeFormatter(buf, node)
}
case 'a':
buf.WriteArg(values[fieldnum].(string))
default:
panic("unexpected")
}
fieldnum++
i++
}
}
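// Illustrative sketch, not part of the original file: the custom format
// specifiers accepted by Myprintf. %s takes a string or []byte, %c a byte or
// rune, %v formats an SQLNode, and %a writes the argument text while
// recording a bind-variable location via WriteArg.
func exampleMyprintf() {
	buf := NewTrackedBuffer(nil)
	buf.Myprintf("select %s from t where id = %a", "col", ":id")
	fmt.Println(buf.String())      // select col from t where id = :id
	fmt.Println(buf.HasBindVars()) // true
}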
// WriteArg writes a value argument into the buffer along with
// tracking information for future substitutions. arg must contain
// the ":" or "::" prefix.
func (buf *TrackedBuffer) WriteArg(arg string) {
buf.bindLocations = append(buf.bindLocations, bindLocation{
offset: buf.Len(),
length: len(arg),
})
buf.WriteString(arg)
}
// ParsedQuery returns a ParsedQuery that contains bind
// locations for easy substitution.
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
}
// HasBindVars returns true if the parsed query uses bind vars.
func (buf *TrackedBuffer) HasBindVars() bool {
return len(buf.bindLocations) != 0
}
// BuildParsedQuery builds a ParsedQuery from the input.
func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
buf := NewTrackedBuffer(nil)
buf.Myprintf(in, vars...)
return buf.ParsedQuery()
}
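// Illustrative sketch, not part of the original file: BuildParsedQuery wraps
// the same machinery and returns a ParsedQuery whose bind locations point at
// the %a arguments for later substitution.
func exampleBuildParsedQuery() {
	pq := BuildParsedQuery("select * from t where a = %a and b = %a", ":a", ":b")
	fmt.Println(pq.Query) // select * from t where a = :a and b = :b
}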

24 vendor/vendor.json vendored
View File

@ -97,6 +97,18 @@
"revision": "f6be1abbb5abd0517522f850dd785990d373da7e",
"revisionTime": "2017-09-13T22:19:17Z"
},
{
"checksumSHA1": "Xmp7mYQyG/1fyIahOyTyN9yZamY=",
"path": "github.com/alecthomas/participle",
"revision": "bf8340a459bd383e5eb7d44a9a1b3af23b6cf8cd",
"revisionTime": "2019-01-03T08:53:15Z"
},
{
"checksumSHA1": "0R8Lqt4DtU8+7Eq1mL7Hd+cjDOI=",
"path": "github.com/alecthomas/participle/lexer",
"revision": "bf8340a459bd383e5eb7d44a9a1b3af23b6cf8cd",
"revisionTime": "2019-01-03T08:53:15Z"
},
{
"checksumSHA1": "tX0Bq1gzqskL98nnB1X2rDqxH18=",
"path": "github.com/aliyun/aliyun-oss-go-sdk/oss",
@ -644,10 +656,10 @@
"revisionTime": "2019-01-20T10:05:29Z"
},
{
"checksumSHA1": "pxgHNx36gpRdhSqtaE5fqp7lrAA=",
"checksumSHA1": "ik77jlf0oMQTlSndP85DlIVOnOY=",
"path": "github.com/minio/parquet-go",
"revision": "1014bfb4d0e323e3fbf6683e3519a98b0721f5cc",
"revisionTime": "2019-01-14T09:43:57Z"
"revision": "7a17a919eeed02c393f3117a9ed1ac6df0da9aa5",
"revisionTime": "2019-01-18T04:40:39Z"
},
{
"checksumSHA1": "N4WRPw4p3AN958RH/O53kUsJacQ=",
@ -888,12 +900,6 @@
"revision": "ceec8f93295a060cdb565ec25e4ccf17941dbd55",
"revisionTime": "2016-11-14T21:01:44Z"
},
{
"checksumSHA1": "6ksZHYhLc3yOzTbcWKb3bDENhD4=",
"path": "github.com/xwb1989/sqlparser",
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
"revisionTime": "2018-06-06T15:21:19Z"
},
{
"checksumSHA1": "L/Q8Ylbo+wnj5whDFfMxxwyxmdo=",
"path": "github.com/xwb1989/sqlparser/dependency/bytes2",