mirror of https://github.com/minio/minio.git
Add new SQL parser to support S3 Select syntax (#7102)
- New parser written from scratch, allows easier and complete parsing of the full S3 Select SQL syntax. Parser definition is directly provided by the AST defined for the SQL grammar. - Bring support to parse and interpret SQL involving JSON path expressions; evaluation of JSON path expressions will be subsequently added. - Bring automatic type inference and conversion for untyped values (e.g. CSV data).
This commit is contained in:
parent
0a28c28a8c
commit
2786055df4
|
@ -185,6 +185,12 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
|
|||
return getObjectNInfo(ctx, bucket, object, rs, r.Header, readLock, ObjectOptions{})
|
||||
}
|
||||
|
||||
objInfo, err := getObjectInfo(ctx, bucket, object, opts)
|
||||
if err != nil {
|
||||
writeErrorResponse(w, toAPIErrorCode(ctx, err), r.URL, guessIsBrowserReq(r))
|
||||
return
|
||||
}
|
||||
|
||||
if err = s3Select.Open(getObject); err != nil {
|
||||
if serr, ok := err.(s3select.SelectError); ok {
|
||||
w.WriteHeader(serr.HTTPStatusCode())
|
||||
|
@ -198,12 +204,6 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
|
|||
s3Select.Evaluate(w)
|
||||
s3Select.Close()
|
||||
|
||||
objInfo, err := getObjectInfo(ctx, bucket, object, opts)
|
||||
if err != nil {
|
||||
logger.LogIf(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get host and port from Request.RemoteAddr.
|
||||
host, port, err := net.SplitHostPort(handlers.GetSourceIP(r))
|
||||
if err != nil {
|
||||
|
|
|
@ -3,6 +3,15 @@ Traditional retrieval of objects is always as whole entities, i.e GetObject for
|
|||
|
||||
> This implementation is compatible with AWS S3 Select API
|
||||
|
||||
### Implemention status:
|
||||
|
||||
- Full S3 SQL syntax is supported
|
||||
- All aggregation, conditional, type-conversion and strings SQL functions are supported
|
||||
- JSONPath expressions are not yet evaluated
|
||||
- Large numbers (more than 64-bit) are not yet supported
|
||||
- Date related functions are not yet supported (EXTRACT, DATE_DIFF, etc)
|
||||
- S3's reserved keywords list is not yet respected
|
||||
|
||||
## 1. Prerequisites
|
||||
- Install Minio Server from [here](http://docs.minio.io/docs/minio-quickstart-guide).
|
||||
- Familiarity with AWS S3 API
|
||||
|
|
|
@ -32,7 +32,11 @@ type Record struct {
|
|||
nameIndexMap map[string]int64
|
||||
}
|
||||
|
||||
// Get - gets the value for a column name.
|
||||
// Get - gets the value for a column name. CSV fields do not have any
|
||||
// defined type (other than the default string). So this function
|
||||
// always returns fields using sql.FromBytes so that the type
|
||||
// specified/implied by the query can be used, or can be automatically
|
||||
// converted based on the query.
|
||||
func (r *Record) Get(name string) (*sql.Value, error) {
|
||||
index, found := r.nameIndexMap[name]
|
||||
if !found {
|
||||
|
@ -40,11 +44,12 @@ func (r *Record) Get(name string) (*sql.Value, error) {
|
|||
}
|
||||
|
||||
if index >= int64(len(r.csvRecord)) {
|
||||
// No value found for column 'name', hence return empty string for compatibility.
|
||||
return sql.NewString(""), nil
|
||||
// No value found for column 'name', hence return null
|
||||
// value
|
||||
return sql.FromNull(), nil
|
||||
}
|
||||
|
||||
return sql.NewString(r.csvRecord[index]), nil
|
||||
return sql.FromBytes([]byte(r.csvRecord[index])), nil
|
||||
}
|
||||
|
||||
// Set - sets the value for a column name.
|
||||
|
|
|
@ -37,15 +37,15 @@ func (r *Record) Get(name string) (*sql.Value, error) {
|
|||
result := gjson.GetBytes(r.data, name)
|
||||
switch result.Type {
|
||||
case gjson.Null:
|
||||
return sql.NewNull(), nil
|
||||
return sql.FromNull(), nil
|
||||
case gjson.False:
|
||||
return sql.NewBool(false), nil
|
||||
return sql.FromBool(false), nil
|
||||
case gjson.Number:
|
||||
return sql.NewFloat(result.Float()), nil
|
||||
return sql.FromFloat(result.Float()), nil
|
||||
case gjson.String:
|
||||
return sql.NewString(result.String()), nil
|
||||
return sql.FromString(result.String()), nil
|
||||
case gjson.True:
|
||||
return sql.NewBool(true), nil
|
||||
return sql.FromBool(true), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported gjson value %v; %v", result, result.Type)
|
||||
|
@ -54,19 +54,20 @@ func (r *Record) Get(name string) (*sql.Value, error) {
|
|||
// Set - sets the value for a column name.
|
||||
func (r *Record) Set(name string, value *sql.Value) (err error) {
|
||||
var v interface{}
|
||||
switch value.Type() {
|
||||
case sql.Null:
|
||||
v = value.NullValue()
|
||||
case sql.Bool:
|
||||
v = value.BoolValue()
|
||||
case sql.Int:
|
||||
v = value.IntValue()
|
||||
case sql.Float:
|
||||
v = value.FloatValue()
|
||||
case sql.String:
|
||||
v = value.StringValue()
|
||||
default:
|
||||
return fmt.Errorf("unsupported sql value %v and type %v", value, value.Type())
|
||||
if b, ok := value.ToBool(); ok {
|
||||
v = b
|
||||
} else if f, ok := value.ToFloat(); ok {
|
||||
v = f
|
||||
} else if i, ok := value.ToInt(); ok {
|
||||
v = i
|
||||
} else if s, ok := value.ToString(); ok {
|
||||
v = s
|
||||
} else if value.IsNull() {
|
||||
v = nil
|
||||
} else if b, ok := value.ToBytes(); ok {
|
||||
v = string(b)
|
||||
} else {
|
||||
return fmt.Errorf("unsupported sql value %v and type %v", value, value.GetTypeString())
|
||||
}
|
||||
|
||||
name = strings.Replace(name, "*", "__ALL__", -1)
|
||||
|
|
|
@ -32,7 +32,7 @@ type Reader struct {
|
|||
}
|
||||
|
||||
// Read - reads single record.
|
||||
func (r *Reader) Read() (sql.Record, error) {
|
||||
func (r *Reader) Read() (rec sql.Record, rerr error) {
|
||||
parquetRecord, err := r.file.Read()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
|
@ -43,39 +43,41 @@ func (r *Reader) Read() (sql.Record, error) {
|
|||
}
|
||||
|
||||
record := json.NewRecord()
|
||||
for name, v := range parquetRecord {
|
||||
f := func(name string, v parquetgo.Value) bool {
|
||||
if v.Value == nil {
|
||||
if err = record.Set(name, sql.NewNull()); err != nil {
|
||||
return nil, errParquetParsingError(err)
|
||||
if err := record.Set(name, sql.FromNull()); err != nil {
|
||||
rerr = errParquetParsingError(err)
|
||||
}
|
||||
|
||||
continue
|
||||
return rerr == nil
|
||||
}
|
||||
|
||||
var value *sql.Value
|
||||
switch v.Type {
|
||||
case parquetgen.Type_BOOLEAN:
|
||||
value = sql.NewBool(v.Value.(bool))
|
||||
value = sql.FromBool(v.Value.(bool))
|
||||
case parquetgen.Type_INT32:
|
||||
value = sql.NewInt(int64(v.Value.(int32)))
|
||||
value = sql.FromInt(int64(v.Value.(int32)))
|
||||
case parquetgen.Type_INT64:
|
||||
value = sql.NewInt(v.Value.(int64))
|
||||
value = sql.FromInt(int64(v.Value.(int64)))
|
||||
case parquetgen.Type_FLOAT:
|
||||
value = sql.NewFloat(float64(v.Value.(float32)))
|
||||
value = sql.FromFloat(float64(v.Value.(float32)))
|
||||
case parquetgen.Type_DOUBLE:
|
||||
value = sql.NewFloat(v.Value.(float64))
|
||||
value = sql.FromFloat(v.Value.(float64))
|
||||
case parquetgen.Type_INT96, parquetgen.Type_BYTE_ARRAY, parquetgen.Type_FIXED_LEN_BYTE_ARRAY:
|
||||
value = sql.NewString(string(v.Value.([]byte)))
|
||||
value = sql.FromString(string(v.Value.([]byte)))
|
||||
default:
|
||||
return nil, errParquetParsingError(nil)
|
||||
rerr = errParquetParsingError(nil)
|
||||
return false
|
||||
}
|
||||
|
||||
if err = record.Set(name, value); err != nil {
|
||||
return nil, errParquetParsingError(err)
|
||||
rerr = errParquetParsingError(err)
|
||||
}
|
||||
return rerr == nil
|
||||
}
|
||||
|
||||
return record, nil
|
||||
parquetRecord.Range(f)
|
||||
return record, rerr
|
||||
}
|
||||
|
||||
// Close - closes underlaying readers.
|
||||
|
|
|
@ -105,7 +105,7 @@ func (input *InputSerialization) UnmarshalXML(d *xml.Decoder, start xml.StartEle
|
|||
found++
|
||||
}
|
||||
if !parsedInput.ParquetArgs.IsEmpty() {
|
||||
if parsedInput.CompressionType != noneType {
|
||||
if parsedInput.CompressionType != "" && parsedInput.CompressionType != noneType {
|
||||
return errInvalidRequestParameter(fmt.Errorf("CompressionType must be NONE for Parquet format"))
|
||||
}
|
||||
|
||||
|
@ -178,7 +178,7 @@ type S3Select struct {
|
|||
Output OutputSerialization `xml:"OutputSerialization"`
|
||||
Progress RequestProgress `xml:"RequestProgress"`
|
||||
|
||||
statement *sql.Select
|
||||
statement *sql.SelectStatement
|
||||
progressReader *progressReader
|
||||
recordReader recordReader
|
||||
}
|
||||
|
@ -209,12 +209,12 @@ func (s3Select *S3Select) UnmarshalXML(d *xml.Decoder, start xml.StartElement) e
|
|||
return errMissingRequiredParameter(fmt.Errorf("OutputSerialization must be provided"))
|
||||
}
|
||||
|
||||
statement, err := sql.NewSelect(parsedS3Select.Expression)
|
||||
statement, err := sql.ParseSelectStatement(parsedS3Select.Expression)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsedS3Select.statement = statement
|
||||
parsedS3Select.statement = &statement
|
||||
|
||||
*s3Select = S3Select(parsedS3Select)
|
||||
return nil
|
||||
|
@ -334,6 +334,14 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
|
|||
}
|
||||
|
||||
for {
|
||||
if s3Select.statement.LimitReached() {
|
||||
if err = writer.SendStats(s3Select.getProgress()); err != nil {
|
||||
// FIXME: log this error.
|
||||
err = nil
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if inputRecord, err = s3Select.recordReader.Read(); err != nil {
|
||||
if err != io.EOF {
|
||||
break
|
||||
|
@ -358,19 +366,25 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
|
|||
break
|
||||
}
|
||||
|
||||
outputRecord = s3Select.outputRecord()
|
||||
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
|
||||
break
|
||||
}
|
||||
if s3Select.statement.IsAggregated() {
|
||||
if err = s3Select.statement.AggregateRow(inputRecord); err != nil {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
outputRecord = s3Select.outputRecord()
|
||||
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if !s3Select.statement.IsAggregated() {
|
||||
if !sendRecord() {
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("SQL Err: %#v\n", err)
|
||||
if serr := writer.SendError("InternalError", err.Error()); serr != nil {
|
||||
// FIXME: log errors.
|
||||
}
|
||||
|
|
|
@ -0,0 +1,318 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Aggregation Function name constants
|
||||
const (
|
||||
aggFnAvg FuncName = "AVG"
|
||||
aggFnCount FuncName = "COUNT"
|
||||
aggFnMax FuncName = "MAX"
|
||||
aggFnMin FuncName = "MIN"
|
||||
aggFnSum FuncName = "SUM"
|
||||
)
|
||||
|
||||
var (
|
||||
errNonNumericArg = func(fnStr FuncName) error {
|
||||
return fmt.Errorf("%s() requires a numeric argument", fnStr)
|
||||
}
|
||||
errInvalidAggregation = errors.New("Invalid aggregation seen")
|
||||
)
|
||||
|
||||
type aggVal struct {
|
||||
runningSum *Value
|
||||
runningCount int64
|
||||
runningMax, runningMin *Value
|
||||
|
||||
// Stores if at least one record has been seen
|
||||
seen bool
|
||||
}
|
||||
|
||||
func newAggVal(fn FuncName) *aggVal {
|
||||
switch fn {
|
||||
case aggFnAvg, aggFnSum:
|
||||
return &aggVal{runningSum: FromInt(0)}
|
||||
case aggFnMin:
|
||||
return &aggVal{runningMin: FromInt(0)}
|
||||
case aggFnMax:
|
||||
return &aggVal{runningMax: FromInt(0)}
|
||||
default:
|
||||
return &aggVal{}
|
||||
}
|
||||
}
|
||||
|
||||
// evalAggregationNode - performs partial computation using the
|
||||
// current row and stores the result.
|
||||
//
|
||||
// On success, it returns (nil, nil).
|
||||
func (e *FuncExpr) evalAggregationNode(r Record) error {
|
||||
// It is assumed that this function is called only when
|
||||
// `e` is an aggregation function.
|
||||
|
||||
var val *Value
|
||||
var err error
|
||||
funcName := e.getFunctionName()
|
||||
if aggFnCount == funcName {
|
||||
if e.Count.StarArg {
|
||||
// Handle COUNT(*)
|
||||
e.aggregate.runningCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
val, err = e.Count.ExprArg.evalNode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Evaluate the (only) argument
|
||||
val, err = e.SFunc.ArgsList[0].evalNode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if val.IsNull() {
|
||||
// E.g. the column or field does not exist in the
|
||||
// record - in all such cases the aggregation is not
|
||||
// updated.
|
||||
return nil
|
||||
}
|
||||
|
||||
argVal := val
|
||||
if funcName != aggFnCount {
|
||||
// All aggregation functions, except COUNT require a
|
||||
// numeric argument.
|
||||
|
||||
// Here, we diverge from Amazon S3 behavior by
|
||||
// inferring untyped values are numbers.
|
||||
if i, ok := argVal.bytesToInt(); ok {
|
||||
argVal.setInt(i)
|
||||
} else if f, ok := argVal.bytesToFloat(); ok {
|
||||
argVal.setFloat(f)
|
||||
} else {
|
||||
return errNonNumericArg(funcName)
|
||||
}
|
||||
}
|
||||
|
||||
// Mark that we have seen one non-null value.
|
||||
isFirstRow := false
|
||||
if !e.aggregate.seen {
|
||||
e.aggregate.seen = true
|
||||
isFirstRow = true
|
||||
}
|
||||
|
||||
switch funcName {
|
||||
case aggFnCount:
|
||||
// For all non-null values, the count is incremented.
|
||||
e.aggregate.runningCount++
|
||||
|
||||
case aggFnAvg:
|
||||
e.aggregate.runningCount++
|
||||
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
|
||||
|
||||
case aggFnMin:
|
||||
err = e.aggregate.runningMin.minmax(argVal, false, isFirstRow)
|
||||
|
||||
case aggFnMax:
|
||||
err = e.aggregate.runningMax.minmax(argVal, true, isFirstRow)
|
||||
|
||||
case aggFnSum:
|
||||
err = e.aggregate.runningSum.arithOp(opPlus, argVal)
|
||||
|
||||
default:
|
||||
err = errInvalidAggregation
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *AliasedExpression) aggregateRow(r Record) error {
|
||||
return e.Expression.aggregateRow(r)
|
||||
}
|
||||
|
||||
func (e *Expression) aggregateRow(r Record) error {
|
||||
for _, ex := range e.And {
|
||||
err := ex.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *AndCondition) aggregateRow(r Record) error {
|
||||
for _, ex := range e.Condition {
|
||||
err := ex.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Condition) aggregateRow(r Record) error {
|
||||
if e.Operand != nil {
|
||||
return e.Operand.aggregateRow(r)
|
||||
}
|
||||
return e.Not.aggregateRow(r)
|
||||
}
|
||||
|
||||
func (e *ConditionOperand) aggregateRow(r Record) error {
|
||||
err := e.Operand.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.ConditionRHS == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case e.ConditionRHS.Compare != nil:
|
||||
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
|
||||
case e.ConditionRHS.Between != nil:
|
||||
err = e.ConditionRHS.Between.Start.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.ConditionRHS.Between.End.aggregateRow(r)
|
||||
case e.ConditionRHS.In != nil:
|
||||
for _, elt := range e.ConditionRHS.In.Expressions {
|
||||
err = elt.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case e.ConditionRHS.Like != nil:
|
||||
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
|
||||
default:
|
||||
return errInvalidASTNode
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Operand) aggregateRow(r Record) error {
|
||||
err := e.Left.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rt := range e.Right {
|
||||
err = rt.Right.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *MultOp) aggregateRow(r Record) error {
|
||||
err := e.Left.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rt := range e.Right {
|
||||
err = rt.Right.aggregateRow(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *UnaryTerm) aggregateRow(r Record) error {
|
||||
if e.Negated != nil {
|
||||
return e.Negated.Term.aggregateRow(r)
|
||||
}
|
||||
return e.Primary.aggregateRow(r)
|
||||
}
|
||||
|
||||
func (e *PrimaryTerm) aggregateRow(r Record) error {
|
||||
switch {
|
||||
case e.SubExpression != nil:
|
||||
return e.SubExpression.aggregateRow(r)
|
||||
case e.FuncCall != nil:
|
||||
return e.FuncCall.aggregateRow(r)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *FuncExpr) aggregateRow(r Record) error {
|
||||
switch e.getFunctionName() {
|
||||
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
|
||||
return e.evalAggregationNode(r)
|
||||
default:
|
||||
// TODO: traverse arguments and call aggregateRow on
|
||||
// them if they could be an ancestor of an
|
||||
// aggregation.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAggregate() implementation for each AST node follows. This is
|
||||
// called after calling aggregateRow() on each input row, to calculate
|
||||
// the final aggregate result.
|
||||
|
||||
func (e *Expression) getAggregate() (*Value, error) {
|
||||
return e.evalNode(nil)
|
||||
}
|
||||
|
||||
func (e *FuncExpr) getAggregate() (*Value, error) {
|
||||
switch e.getFunctionName() {
|
||||
case aggFnCount:
|
||||
return FromFloat(float64(e.aggregate.runningCount)), nil
|
||||
|
||||
case aggFnAvg:
|
||||
if e.aggregate.runningCount == 0 {
|
||||
// No rows were seen by AVG.
|
||||
return FromNull(), nil
|
||||
}
|
||||
err := e.aggregate.runningSum.arithOp(opDivide, FromInt(e.aggregate.runningCount))
|
||||
return e.aggregate.runningSum, err
|
||||
|
||||
case aggFnMin:
|
||||
if !e.aggregate.seen {
|
||||
// No rows were seen by MIN
|
||||
return FromNull(), nil
|
||||
}
|
||||
return e.aggregate.runningMin, nil
|
||||
|
||||
case aggFnMax:
|
||||
if !e.aggregate.seen {
|
||||
// No rows were seen by MAX
|
||||
return FromNull(), nil
|
||||
}
|
||||
return e.aggregate.runningMax, nil
|
||||
|
||||
case aggFnSum:
|
||||
// TODO: check if returning 0 when no rows were seen
|
||||
// by SUM is expected behavior.
|
||||
return e.aggregate.runningSum, nil
|
||||
|
||||
default:
|
||||
// TODO:
|
||||
}
|
||||
|
||||
return nil, errInvalidAggregation
|
||||
}
|
|
@ -0,0 +1,290 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Query analysis - The query is analyzed to determine if it involves
|
||||
// aggregation.
|
||||
//
|
||||
// Aggregation functions - An expression that involves aggregation of
|
||||
// rows in some manner. Requires all input rows to be processed,
|
||||
// before a result is returned.
|
||||
//
|
||||
// Row function - An expression that depends on a value in the
|
||||
// row. They have an output for each input row.
|
||||
//
|
||||
// Some types of a queries are not valid. For example, an aggregation
|
||||
// function combined with a row function is meaningless ("AVG(s.Age) +
|
||||
// s.Salary"). Analysis determines if such a scenario exists so an
|
||||
// error can be returned.
|
||||
|
||||
var (
|
||||
// Fatal error for query processing.
|
||||
errNestedAggregation = errors.New("Cannot nest aggregations")
|
||||
errFunctionNotImplemented = errors.New("Function is not yet implemented")
|
||||
errUnexpectedInvalidNode = errors.New("Unexpected node value")
|
||||
errInvalidKeypath = errors.New("A provided keypath is invalid")
|
||||
)
|
||||
|
||||
// qProp contains analysis info about an SQL term.
|
||||
type qProp struct {
|
||||
isAggregation, isRowFunc bool
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
// `combine` combines a pair of `qProp`s, so that errors are
|
||||
// propagated correctly, and checks that an aggregation is not being
|
||||
// combined with a row-function term.
|
||||
func (p *qProp) combine(q qProp) {
|
||||
switch {
|
||||
case p.err != nil:
|
||||
// Do nothing
|
||||
case q.err != nil:
|
||||
p.err = q.err
|
||||
default:
|
||||
p.isAggregation = p.isAggregation || q.isAggregation
|
||||
p.isRowFunc = p.isRowFunc || q.isRowFunc
|
||||
if p.isAggregation && p.isRowFunc {
|
||||
p.err = errNestedAggregation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *SelectExpression) analyze(s *Select) (result qProp) {
|
||||
if e.All {
|
||||
return qProp{isRowFunc: true}
|
||||
}
|
||||
|
||||
for _, ex := range e.Expressions {
|
||||
result.combine(ex.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *AliasedExpression) analyze(s *Select) qProp {
|
||||
return e.Expression.analyze(s)
|
||||
}
|
||||
|
||||
func (e *Expression) analyze(s *Select) (result qProp) {
|
||||
for _, ac := range e.And {
|
||||
result.combine(ac.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *AndCondition) analyze(s *Select) (result qProp) {
|
||||
for _, ac := range e.Condition {
|
||||
result.combine(ac.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *Condition) analyze(s *Select) (result qProp) {
|
||||
if e.Operand != nil {
|
||||
result = e.Operand.analyze(s)
|
||||
} else {
|
||||
result = e.Not.analyze(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *ConditionOperand) analyze(s *Select) (result qProp) {
|
||||
if e.ConditionRHS == nil {
|
||||
result = e.Operand.analyze(s)
|
||||
} else {
|
||||
result.combine(e.Operand.analyze(s))
|
||||
result.combine(e.ConditionRHS.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *ConditionRHS) analyze(s *Select) (result qProp) {
|
||||
switch {
|
||||
case e.Compare != nil:
|
||||
result = e.Compare.Operand.analyze(s)
|
||||
case e.Between != nil:
|
||||
result.combine(e.Between.Start.analyze(s))
|
||||
result.combine(e.Between.End.analyze(s))
|
||||
case e.In != nil:
|
||||
for _, elt := range e.In.Expressions {
|
||||
result.combine(elt.analyze(s))
|
||||
}
|
||||
case e.Like != nil:
|
||||
result.combine(e.Like.Pattern.analyze(s))
|
||||
if e.Like.EscapeChar != nil {
|
||||
result.combine(e.Like.EscapeChar.analyze(s))
|
||||
}
|
||||
default:
|
||||
result = qProp{err: errUnexpectedInvalidNode}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *Operand) analyze(s *Select) (result qProp) {
|
||||
result.combine(e.Left.analyze(s))
|
||||
for _, r := range e.Right {
|
||||
result.combine(r.Right.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *MultOp) analyze(s *Select) (result qProp) {
|
||||
result.combine(e.Left.analyze(s))
|
||||
for _, r := range e.Right {
|
||||
result.combine(r.Right.analyze(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *UnaryTerm) analyze(s *Select) (result qProp) {
|
||||
if e.Negated != nil {
|
||||
result = e.Negated.Term.analyze(s)
|
||||
} else {
|
||||
result = e.Primary.analyze(s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
|
||||
switch {
|
||||
case e.Value != nil:
|
||||
result = qProp{}
|
||||
|
||||
case e.JPathExpr != nil:
|
||||
// Check if the path expression is valid
|
||||
if len(e.JPathExpr.PathExpr) > 0 {
|
||||
if e.JPathExpr.BaseKey.String() != s.From.As {
|
||||
result = qProp{err: errInvalidKeypath}
|
||||
return
|
||||
}
|
||||
}
|
||||
result = qProp{isRowFunc: true}
|
||||
|
||||
case e.SubExpression != nil:
|
||||
result = e.SubExpression.analyze(s)
|
||||
|
||||
case e.FuncCall != nil:
|
||||
result = e.FuncCall.analyze(s)
|
||||
|
||||
default:
|
||||
result = qProp{err: errUnexpectedInvalidNode}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *FuncExpr) analyze(s *Select) (result qProp) {
|
||||
funcName := e.getFunctionName()
|
||||
|
||||
switch funcName {
|
||||
case sqlFnCast:
|
||||
return e.Cast.Expr.analyze(s)
|
||||
|
||||
case sqlFnExtract:
|
||||
return e.Extract.From.analyze(s)
|
||||
|
||||
// Handle aggregation function calls
|
||||
case aggFnAvg, aggFnMax, aggFnMin, aggFnSum, aggFnCount:
|
||||
// Initialize accumulator
|
||||
e.aggregate = newAggVal(funcName)
|
||||
|
||||
var exprA qProp
|
||||
if funcName == aggFnCount {
|
||||
if e.Count.StarArg {
|
||||
return qProp{isAggregation: true}
|
||||
}
|
||||
|
||||
exprA = e.Count.ExprArg.analyze(s)
|
||||
} else {
|
||||
if len(e.SFunc.ArgsList) != 1 {
|
||||
return qProp{err: fmt.Errorf("%s takes exactly one argument", funcName)}
|
||||
}
|
||||
exprA = e.SFunc.ArgsList[0].analyze(s)
|
||||
}
|
||||
|
||||
if exprA.err != nil {
|
||||
return exprA
|
||||
}
|
||||
if exprA.isAggregation {
|
||||
return qProp{err: errNestedAggregation}
|
||||
}
|
||||
return qProp{isAggregation: true}
|
||||
|
||||
case sqlFnCoalesce:
|
||||
if len(e.SFunc.ArgsList) == 0 {
|
||||
return qProp{err: fmt.Errorf("%s needs at least one argument", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnNullIf:
|
||||
if len(e.SFunc.ArgsList) != 2 {
|
||||
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnCharLength, sqlFnCharacterLength:
|
||||
if len(e.SFunc.ArgsList) != 1 {
|
||||
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnLower, sqlFnUpper:
|
||||
if len(e.SFunc.ArgsList) != 1 {
|
||||
return qProp{err: fmt.Errorf("%s needs exactly 2 arguments", string(funcName))}
|
||||
}
|
||||
for _, arg := range e.SFunc.ArgsList {
|
||||
result.combine(arg.analyze(s))
|
||||
}
|
||||
return result
|
||||
|
||||
case sqlFnSubstring:
|
||||
errVal := fmt.Errorf("Invalid argument(s) to %s", string(funcName))
|
||||
result.combine(e.Substring.Expr.analyze(s))
|
||||
switch {
|
||||
case e.Substring.From != nil:
|
||||
result.combine(e.Substring.From.analyze(s))
|
||||
if e.Substring.For != nil {
|
||||
result.combine(e.Substring.Expr.analyze(s))
|
||||
}
|
||||
case e.Substring.Arg2 != nil:
|
||||
result.combine(e.Substring.Arg2.analyze(s))
|
||||
if e.Substring.Arg3 != nil {
|
||||
result.combine(e.Substring.Arg3.analyze(s))
|
||||
}
|
||||
default:
|
||||
result.err = errVal
|
||||
}
|
||||
return result
|
||||
|
||||
}
|
||||
|
||||
// TODO: implement other functions
|
||||
return qProp{err: errFunctionNotImplemented}
|
||||
}
|
|
@ -1,175 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ArithOperator - arithmetic operator.
|
||||
type ArithOperator string
|
||||
|
||||
const (
|
||||
// Add operator '+'.
|
||||
Add ArithOperator = "+"
|
||||
|
||||
// Subtract operator '-'.
|
||||
Subtract ArithOperator = "-"
|
||||
|
||||
// Multiply operator '*'.
|
||||
Multiply ArithOperator = "*"
|
||||
|
||||
// Divide operator '/'.
|
||||
Divide ArithOperator = "/"
|
||||
|
||||
// Modulo operator '%'.
|
||||
Modulo ArithOperator = "%"
|
||||
)
|
||||
|
||||
// arithExpr - arithmetic function.
|
||||
type arithExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
operator ArithOperator
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *arithExpr) String() string {
|
||||
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
|
||||
}
|
||||
|
||||
func (f *arithExpr) compute(lv, rv *Value) (*Value, error) {
|
||||
leftValueType := lv.Type()
|
||||
rightValueType := rv.Type()
|
||||
if !leftValueType.isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValueType)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValueType.isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValueType)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue := lv.FloatValue()
|
||||
rightValue := rv.FloatValue()
|
||||
|
||||
var result float64
|
||||
switch f.operator {
|
||||
case Add:
|
||||
result = leftValue + rightValue
|
||||
case Subtract:
|
||||
result = leftValue - rightValue
|
||||
case Multiply:
|
||||
result = leftValue * rightValue
|
||||
case Divide:
|
||||
result = leftValue / rightValue
|
||||
case Modulo:
|
||||
result = float64(int64(leftValue) % int64(rightValue))
|
||||
}
|
||||
|
||||
if leftValueType == Float || rightValueType == Float {
|
||||
return NewFloat(result), nil
|
||||
}
|
||||
|
||||
return NewInt(int64(result)), nil
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *arithExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return f.compute(leftValue, rightValue)
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *arithExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
lv, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rv, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f.compute(lv, rv)
|
||||
}
|
||||
|
||||
// Type - returns arithmeticFunction or aggregateFunction type.
|
||||
func (f *arithExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Float as return type.
|
||||
func (f *arithExpr) ReturnType() Type {
|
||||
return Float
|
||||
}
|
||||
|
||||
// newArithExpr - creates new arithmetic function.
|
||||
func newArithExpr(operator ArithOperator, left, right Expr) (*arithExpr, error) {
|
||||
if !left.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if !right.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := arithmeticFunction
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
switch right.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &arithExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
|
@ -1,636 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ComparisonOperator - comparison operator.
|
||||
type ComparisonOperator string
|
||||
|
||||
const (
|
||||
// Equal operator '='.
|
||||
Equal ComparisonOperator = "="
|
||||
|
||||
// NotEqual operator '!=' or '<>'.
|
||||
NotEqual ComparisonOperator = "!="
|
||||
|
||||
// LessThan operator '<'.
|
||||
LessThan ComparisonOperator = "<"
|
||||
|
||||
// GreaterThan operator '>'.
|
||||
GreaterThan ComparisonOperator = ">"
|
||||
|
||||
// LessThanEqual operator '<='.
|
||||
LessThanEqual ComparisonOperator = "<="
|
||||
|
||||
// GreaterThanEqual operator '>='.
|
||||
GreaterThanEqual ComparisonOperator = ">="
|
||||
|
||||
// Between operator 'BETWEEN'
|
||||
Between ComparisonOperator = "between"
|
||||
|
||||
// In operator 'IN'
|
||||
In ComparisonOperator = "in"
|
||||
|
||||
// Like operator 'LIKE'
|
||||
Like ComparisonOperator = "like"
|
||||
|
||||
// NotBetween operator 'NOT BETWEEN'
|
||||
NotBetween ComparisonOperator = "not between"
|
||||
|
||||
// NotIn operator 'NOT IN'
|
||||
NotIn ComparisonOperator = "not in"
|
||||
|
||||
// NotLike operator 'NOT LIKE'
|
||||
NotLike ComparisonOperator = "not like"
|
||||
|
||||
// IsNull operator 'IS NULL'
|
||||
IsNull ComparisonOperator = "is null"
|
||||
|
||||
// IsNotNull operator 'IS NOT NULL'
|
||||
IsNotNull ComparisonOperator = "is not null"
|
||||
)
|
||||
|
||||
// String - returns string representation of this operator.
|
||||
func (operator ComparisonOperator) String() string {
|
||||
return strings.ToUpper((string(operator)))
|
||||
}
|
||||
|
||||
func equal(leftValue, rightValue *Value) (bool, error) {
|
||||
switch {
|
||||
case leftValue.Type() == Null && rightValue.Type() == Null:
|
||||
return true, nil
|
||||
case leftValue.Type() == Bool && rightValue.Type() == Bool:
|
||||
return leftValue.BoolValue() == rightValue.BoolValue(), nil
|
||||
case (leftValue.Type() == Int || leftValue.Type() == Float) &&
|
||||
(rightValue.Type() == Int || rightValue.Type() == Float):
|
||||
return leftValue.FloatValue() == rightValue.FloatValue(), nil
|
||||
case leftValue.Type() == String && rightValue.Type() == String:
|
||||
return leftValue.StringValue() == rightValue.StringValue(), nil
|
||||
case leftValue.Type() == Timestamp && rightValue.Type() == Timestamp:
|
||||
return leftValue.TimeValue() == rightValue.TimeValue(), nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("left value type %v and right value type %v are incompatible for equality check", leftValue.Type(), rightValue.Type())
|
||||
}
|
||||
|
||||
// comparisonExpr - comparison function.
|
||||
type comparisonExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
to Expr
|
||||
operator ComparisonOperator
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *comparisonExpr) String() string {
|
||||
switch f.operator {
|
||||
case Equal, NotEqual, LessThan, GreaterThan, LessThanEqual, GreaterThanEqual, In, Like, NotIn, NotLike:
|
||||
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
|
||||
case Between, NotBetween:
|
||||
return fmt.Sprintf("(%v %v %v AND %v)", f.left, f.operator, f.right, f.to)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(%v %v %v %v)", f.left, f.right, f.to, f.operator)
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) equal(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := equal(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notEqual(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := equal(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) lessThan(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() < rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) greaterThan(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() > rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) lessThanEqual(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() <= rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) greaterThanEqual(leftValue, rightValue *Value) (*Value, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
if !rightValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(leftValue.FloatValue() >= rightValue.FloatValue()), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) computeBetween(leftValue, fromValue, toValue *Value) (bool, error) {
|
||||
if !leftValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
if !fromValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: from side expression evaluated to %v; not to number", f, fromValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
if !toValue.Type().isNumber() {
|
||||
err := fmt.Errorf("%v: to side expression evaluated to %v; not to number", f, toValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return leftValue.FloatValue() >= fromValue.FloatValue() &&
|
||||
leftValue.FloatValue() <= toValue.FloatValue(), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) between(leftValue, fromValue, toValue *Value) (*Value, error) {
|
||||
result, err := f.computeBetween(leftValue, fromValue, toValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notBetween(leftValue, fromValue, toValue *Value) (*Value, error) {
|
||||
result, err := f.computeBetween(leftValue, fromValue, toValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) computeIn(leftValue, rightValue *Value) (found bool, err error) {
|
||||
if rightValue.Type() != Array {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to Array", f, rightValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
values := rightValue.ArrayValue()
|
||||
|
||||
for i := range values {
|
||||
found, err = equal(leftValue, values[i])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if found {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) in(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeIn(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notIn(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeIn(leftValue, rightValue)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) computeLike(leftValue, rightValue *Value) (matched bool, err error) {
|
||||
if leftValue.Type() != String {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to string", f, leftValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
if rightValue.Type() != String {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to string", f, rightValue.Type())
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
matched, err = regexp.MatchString(rightValue.StringValue(), leftValue.StringValue())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%v: %v", f, err)
|
||||
return false, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) like(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeLike(leftValue, rightValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) notLike(leftValue, rightValue *Value) (*Value, error) {
|
||||
result, err := f.computeLike(leftValue, rightValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewBool(!result), nil
|
||||
}
|
||||
|
||||
func (f *comparisonExpr) compute(leftValue, rightValue, toValue *Value) (*Value, error) {
|
||||
switch f.operator {
|
||||
case Equal:
|
||||
return f.equal(leftValue, rightValue)
|
||||
case NotEqual:
|
||||
return f.notEqual(leftValue, rightValue)
|
||||
case LessThan:
|
||||
return f.lessThan(leftValue, rightValue)
|
||||
case GreaterThan:
|
||||
return f.greaterThan(leftValue, rightValue)
|
||||
case LessThanEqual:
|
||||
return f.lessThanEqual(leftValue, rightValue)
|
||||
case GreaterThanEqual:
|
||||
return f.greaterThanEqual(leftValue, rightValue)
|
||||
case Between:
|
||||
return f.between(leftValue, rightValue, toValue)
|
||||
case In:
|
||||
return f.in(leftValue, rightValue)
|
||||
case Like:
|
||||
return f.like(leftValue, rightValue)
|
||||
case NotBetween:
|
||||
return f.notBetween(leftValue, rightValue, toValue)
|
||||
case NotIn:
|
||||
return f.notIn(leftValue, rightValue)
|
||||
case NotLike:
|
||||
return f.notLike(leftValue, rightValue)
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("unexpected expression %v", f))
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *comparisonExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toValue *Value
|
||||
if f.to != nil {
|
||||
toValue, err = f.to.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return f.compute(leftValue, rightValue, toValue)
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *comparisonExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var toValue *Value
|
||||
if f.to != nil {
|
||||
toValue, err = f.to.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return f.compute(leftValue, rightValue, toValue)
|
||||
}
|
||||
|
||||
// Type - returns comparisonFunction or aggregateFunction type.
|
||||
func (f *comparisonExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *comparisonExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newComparisonExpr - creates new comparison function.
|
||||
func newComparisonExpr(operator ComparisonOperator, funcs ...Expr) (*comparisonExpr, error) {
|
||||
funcType := comparisonFunction
|
||||
switch operator {
|
||||
case Equal, NotEqual:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if !right.ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for equality check", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch right.Type() {
|
||||
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for equality check", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
|
||||
case LessThan, GreaterThan, LessThanEqual, GreaterThanEqual:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if !right.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
switch right.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
|
||||
case In, NotIn:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v is incompatible for equality check", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if right.ReturnType() != Array {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v is incompatible for equality check", operator, right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch right.Type() {
|
||||
case Array, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
|
||||
case Like, NotLike:
|
||||
if len(funcs) != 2 {
|
||||
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not string", operator, left, left.ReturnType())
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
right := funcs[1]
|
||||
if !right.ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v, not string", operator, right, right.ReturnType())
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case String, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch right.Type() {
|
||||
case String, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
case Between, NotBetween:
|
||||
if len(funcs) != 3 {
|
||||
panic(fmt.Sprintf("too many values in funcs %v", funcs))
|
||||
}
|
||||
|
||||
left := funcs[0]
|
||||
if !left.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
from := funcs[1]
|
||||
if !from.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: from expression %v evaluate to %v, not number", operator, from, from.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
to := funcs[2]
|
||||
if !to.ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("operator %v: to expression %v evaluate to %v, not number", operator, to, to.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if left.Type() == aggregateFunction || from.Type() == aggregateFunction || to.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
switch left.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch from.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: from expression %v return type %v is incompatible for aggregate evaluation", operator, from, from.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
switch to.Type() {
|
||||
case Int, Float, aggregateFunction:
|
||||
default:
|
||||
err := fmt.Errorf("operator %v: to expression %v return type %v is incompatible for aggregate evaluation", operator, to, to.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: left,
|
||||
right: from,
|
||||
to: to,
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
case IsNull, IsNotNull:
|
||||
if len(funcs) != 1 {
|
||||
panic(fmt.Sprintf("too many values in funcs %v", funcs))
|
||||
}
|
||||
|
||||
if funcs[0].Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
}
|
||||
|
||||
if operator == IsNull {
|
||||
operator = Equal
|
||||
} else {
|
||||
operator = NotEqual
|
||||
}
|
||||
|
||||
return &comparisonExpr{
|
||||
left: funcs[0],
|
||||
right: newValueExpr(NewNull()),
|
||||
operator: operator,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errParseUnknownOperator(fmt.Errorf("unknown operator %v", operator))
|
||||
}
|
|
@ -43,42 +43,6 @@ func (err *s3Error) Error() string {
|
|||
return err.message
|
||||
}
|
||||
|
||||
func errUnsupportedSQLStructure(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "UnsupportedSqlStructure",
|
||||
message: "Encountered an unsupported SQL structure. Check the SQL Reference.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseUnsupportedSelect(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseUnsupportedSelect",
|
||||
message: "The SQL expression contains an unsupported use of SELECT.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseAsteriskIsNotAloneInSelectList(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseAsteriskIsNotAloneInSelectList",
|
||||
message: "Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseInvalidContextForWildcardInSelectList(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseInvalidContextForWildcardInSelectList",
|
||||
message: "Invalid use of * in SELECT list in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errInvalidDataType(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "InvalidDataType",
|
||||
|
@ -88,55 +52,10 @@ func errInvalidDataType(err error) *s3Error {
|
|||
}
|
||||
}
|
||||
|
||||
func errUnsupportedFunction(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "UnsupportedFunction",
|
||||
message: "Encountered an unsupported SQL function.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseNonUnaryAgregateFunctionCall(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseNonUnaryAgregateFunctionCall",
|
||||
message: "Only one argument is supported for aggregate functions in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errIncorrectSQLFunctionArgumentType(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "IncorrectSqlFunctionArgumentType",
|
||||
message: "Incorrect type of arguments in function call in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errEvaluatorInvalidArguments(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "EvaluatorInvalidArguments",
|
||||
message: "Incorrect number of arguments in the function call in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errUnsupportedSQLOperation(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "UnsupportedSqlOperation",
|
||||
message: "Encountered an unsupported SQL operation.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseUnknownOperator(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseUnknownOperator",
|
||||
message: "The SQL expression contains an invalid operator.",
|
||||
message: "Incorrect type of arguments in function call.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
|
@ -151,64 +70,28 @@ func errLikeInvalidInputs(err error) *s3Error {
|
|||
}
|
||||
}
|
||||
|
||||
func errExternalEvalException(err error) *s3Error {
|
||||
func errQueryParseFailure(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ExternalEvalException",
|
||||
message: "The query cannot be evaluated. Check the file and try again.",
|
||||
code: "ParseSelectFailure",
|
||||
message: err.Error(),
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errValueParseFailure(err error) *s3Error {
|
||||
func errQueryAnalysisFailure(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ValueParseFailure",
|
||||
message: "Time stamp parse failure in the SQL expression.",
|
||||
code: "InvalidQuery",
|
||||
message: err.Error(),
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errEvaluatorBindingDoesNotExist(err error) *s3Error {
|
||||
func errBadTableName(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "EvaluatorBindingDoesNotExist",
|
||||
message: "A column name or a path provided does not exist in the SQL expression.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errInternalError(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "InternalError",
|
||||
message: "Encountered an internal error.",
|
||||
statusCode: 500,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseInvalidTypeParam(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseInvalidTypeParam",
|
||||
message: "The SQL expression contains an invalid parameter value.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errParseUnsupportedSyntax(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "ParseUnsupportedSyntax",
|
||||
message: "The SQL expression contains unsupported syntax.",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
func errInvalidKeyPath(err error) *s3Error {
|
||||
return &s3Error{
|
||||
code: "InvalidKeyPath",
|
||||
message: "Key path in the SQL expression is invalid.",
|
||||
code: "BadTableName",
|
||||
message: "The table name is not supported",
|
||||
statusCode: 400,
|
||||
cause: err,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,361 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidASTNode = errors.New("invalid AST Node")
|
||||
errExpectedBool = errors.New("expected bool")
|
||||
errLikeNonStrArg = errors.New("LIKE clause requires string arguments")
|
||||
errLikeInvalidEscape = errors.New("LIKE clause has invalid ESCAPE character")
|
||||
errNotImplemented = errors.New("not implemented")
|
||||
)
|
||||
|
||||
// AST Node Evaluation functions
|
||||
//
|
||||
// During evaluation, the query is known to be valid, as analysis is
|
||||
// complete. The only errors possible are due to value type
|
||||
// mismatches, etc.
|
||||
//
|
||||
// If an aggregation node is present as a descendant (when
|
||||
// e.prop.isAggregation is true), we call evalNode on all child nodes,
|
||||
// check for errors, but do not perform any combining of the results
|
||||
// of child nodes. The final result row is returned after all rows are
|
||||
// processed, and the `getAggregate` function is called.
|
||||
|
||||
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
|
||||
return e.Expression.evalNode(r)
|
||||
}
|
||||
|
||||
func (e *Expression) evalNode(r Record) (*Value, error) {
|
||||
if len(e.And) == 1 {
|
||||
// In this case, result is not required to be boolean
|
||||
// type.
|
||||
return e.And[0].evalNode(r)
|
||||
}
|
||||
|
||||
// Compute OR of conditions
|
||||
result := false
|
||||
for _, ex := range e.And {
|
||||
res, err := ex.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, ok := res.ToBool()
|
||||
if !ok {
|
||||
return nil, errExpectedBool
|
||||
}
|
||||
result = result || b
|
||||
}
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *AndCondition) evalNode(r Record) (*Value, error) {
|
||||
if len(e.Condition) == 1 {
|
||||
// In this case, result does not have to be boolean
|
||||
return e.Condition[0].evalNode(r)
|
||||
}
|
||||
|
||||
// Compute AND of conditions
|
||||
result := true
|
||||
for _, ex := range e.Condition {
|
||||
res, err := ex.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, ok := res.ToBool()
|
||||
if !ok {
|
||||
return nil, errExpectedBool
|
||||
}
|
||||
result = result && b
|
||||
}
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *Condition) evalNode(r Record) (*Value, error) {
|
||||
if e.Operand != nil {
|
||||
// In this case, result does not have to be boolean
|
||||
return e.Operand.evalNode(r)
|
||||
}
|
||||
|
||||
// Compute NOT of condition
|
||||
res, err := e.Not.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b, ok := res.ToBool()
|
||||
if !ok {
|
||||
return nil, errExpectedBool
|
||||
}
|
||||
return FromBool(!b), nil
|
||||
}
|
||||
|
||||
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
|
||||
opVal, opErr := e.Operand.evalNode(r)
|
||||
if opErr != nil || e.ConditionRHS == nil {
|
||||
return opVal, opErr
|
||||
}
|
||||
|
||||
// Need to evaluate the ConditionRHS
|
||||
switch {
|
||||
case e.ConditionRHS.Compare != nil:
|
||||
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
|
||||
if cmpRErr != nil {
|
||||
return nil, cmpRErr
|
||||
}
|
||||
|
||||
b, err := opVal.compareOp(e.ConditionRHS.Compare.Operator, cmpRight)
|
||||
return FromBool(b), err
|
||||
|
||||
case e.ConditionRHS.Between != nil:
|
||||
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
|
||||
|
||||
case e.ConditionRHS.Like != nil:
|
||||
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
|
||||
|
||||
case e.ConditionRHS.In != nil:
|
||||
return e.ConditionRHS.In.evalInNode(r, opVal)
|
||||
|
||||
default:
|
||||
return nil, errInvalidASTNode
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
|
||||
stVal, stErr := e.Start.evalNode(r)
|
||||
if stErr != nil {
|
||||
return nil, stErr
|
||||
}
|
||||
|
||||
endVal, endErr := e.End.evalNode(r)
|
||||
if endErr != nil {
|
||||
return nil, endErr
|
||||
}
|
||||
|
||||
part1, err1 := stVal.compareOp(opLte, arg)
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
|
||||
part2, err2 := arg.compareOp(opLte, endVal)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
|
||||
result := part1 && part2
|
||||
if e.Not {
|
||||
result = !result
|
||||
}
|
||||
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
|
||||
inferTypeAsString(arg)
|
||||
|
||||
s, ok := arg.ToString()
|
||||
if !ok {
|
||||
err := errLikeNonStrArg
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
pattern, err1 := e.Pattern.evalNode(r)
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
|
||||
// Infer pattern as string (in case it is untyped)
|
||||
inferTypeAsString(pattern)
|
||||
|
||||
patternStr, ok := pattern.ToString()
|
||||
if !ok {
|
||||
err := errLikeNonStrArg
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
escape := runeZero
|
||||
if e.EscapeChar != nil {
|
||||
escapeVal, err2 := e.EscapeChar.evalNode(r)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
|
||||
inferTypeAsString(escapeVal)
|
||||
|
||||
escapeStr, ok := escapeVal.ToString()
|
||||
if !ok {
|
||||
err := errLikeNonStrArg
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
|
||||
if len([]rune(escapeStr)) > 1 {
|
||||
err := errLikeInvalidEscape
|
||||
return nil, errLikeInvalidInputs(err)
|
||||
}
|
||||
}
|
||||
|
||||
matchResult, err := evalSQLLike(s, patternStr, escape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.Not {
|
||||
matchResult = !matchResult
|
||||
}
|
||||
|
||||
return FromBool(matchResult), nil
|
||||
}
|
||||
|
||||
func (e *In) evalInNode(r Record, arg *Value) (*Value, error) {
|
||||
result := false
|
||||
for _, elt := range e.Expressions {
|
||||
eltVal, err := elt.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// FIXME: type inference?
|
||||
|
||||
// Types must match.
|
||||
if arg.vType != eltVal.vType {
|
||||
// match failed.
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.value == eltVal.value {
|
||||
result = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return FromBool(result), nil
|
||||
}
|
||||
|
||||
func (e *Operand) evalNode(r Record) (*Value, error) {
|
||||
lval, lerr := e.Left.evalNode(r)
|
||||
if lerr != nil || len(e.Right) == 0 {
|
||||
return lval, lerr
|
||||
}
|
||||
|
||||
// Process remaining child nodes - result must be
|
||||
// numeric. This AST node is for terms separated by + or -
|
||||
// symbols.
|
||||
for _, rightTerm := range e.Right {
|
||||
op := rightTerm.Op
|
||||
rval, rerr := rightTerm.Right.evalNode(r)
|
||||
if rerr != nil {
|
||||
return nil, rerr
|
||||
}
|
||||
err := lval.arithOp(op, rval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return lval, nil
|
||||
}
|
||||
|
||||
func (e *MultOp) evalNode(r Record) (*Value, error) {
|
||||
lval, lerr := e.Left.evalNode(r)
|
||||
if lerr != nil || len(e.Right) == 0 {
|
||||
return lval, lerr
|
||||
}
|
||||
|
||||
// Process other child nodes - result must be numeric. This
|
||||
// AST node is for terms separated by *, / or % symbols.
|
||||
for _, rightTerm := range e.Right {
|
||||
op := rightTerm.Op
|
||||
rval, rerr := rightTerm.Right.evalNode(r)
|
||||
if rerr != nil {
|
||||
return nil, rerr
|
||||
}
|
||||
|
||||
err := lval.arithOp(op, rval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return lval, nil
|
||||
}
|
||||
|
||||
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
|
||||
if e.Negated == nil {
|
||||
return e.Primary.evalNode(r)
|
||||
}
|
||||
|
||||
v, err := e.Negated.Term.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inferTypeForArithOp(v)
|
||||
if ival, ok := v.ToInt(); ok {
|
||||
return FromInt(-ival), nil
|
||||
} else if fval, ok := v.ToFloat(); ok {
|
||||
return FromFloat(-fval), nil
|
||||
}
|
||||
return nil, errArithMismatchedTypes
|
||||
}
|
||||
|
||||
func (e *JSONPath) evalNode(r Record) (*Value, error) {
|
||||
// Strip the table name from the keypath.
|
||||
keypath := e.String()
|
||||
ps := strings.SplitN(keypath, ".", 2)
|
||||
if len(ps) == 2 {
|
||||
keypath = ps[1]
|
||||
}
|
||||
return r.Get(keypath)
|
||||
}
|
||||
|
||||
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
|
||||
switch {
|
||||
case e.Value != nil:
|
||||
return e.Value.evalNode(r)
|
||||
case e.JPathExpr != nil:
|
||||
return e.JPathExpr.evalNode(r)
|
||||
case e.SubExpression != nil:
|
||||
return e.SubExpression.evalNode(r)
|
||||
case e.FuncCall != nil:
|
||||
return e.FuncCall.evalNode(r)
|
||||
}
|
||||
return nil, errInvalidASTNode
|
||||
}
|
||||
|
||||
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
|
||||
switch e.getFunctionName() {
|
||||
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
|
||||
return e.getAggregate()
|
||||
default:
|
||||
return e.evalSQLFnNode(r)
|
||||
}
|
||||
}
|
||||
|
||||
// evalNode on a literal value is independent of the node being an
|
||||
// aggregation or a row function - it always returns a value.
|
||||
func (e *LitValue) evalNode(_ Record) (res *Value, err error) {
|
||||
switch {
|
||||
case e.Number != nil:
|
||||
return floatToValue(*e.Number), nil
|
||||
case e.String != nil:
|
||||
return FromString(string(*e.String)), nil
|
||||
case e.Boolean != nil:
|
||||
return FromBool(bool(*e.Boolean)), nil
|
||||
}
|
||||
return FromNull(), nil
|
||||
}
|
|
@ -1,160 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Expr - a SQL expression type.
|
||||
type Expr interface {
|
||||
AggregateValue() (*Value, error)
|
||||
Eval(record Record) (*Value, error)
|
||||
ReturnType() Type
|
||||
Type() Type
|
||||
}
|
||||
|
||||
// aliasExpr - aliases expression by alias.
|
||||
type aliasExpr struct {
|
||||
alias string
|
||||
expr Expr
|
||||
}
|
||||
|
||||
// String - returns string representation of this expression.
|
||||
func (expr *aliasExpr) String() string {
|
||||
return fmt.Sprintf("(%v AS %v)", expr.expr, expr.alias)
|
||||
}
|
||||
|
||||
// Eval - evaluates underlaying expression for given record and returns evaluated result.
|
||||
func (expr *aliasExpr) Eval(record Record) (*Value, error) {
|
||||
return expr.expr.Eval(record)
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value from underlaying expression.
|
||||
func (expr *aliasExpr) AggregateValue() (*Value, error) {
|
||||
return expr.expr.AggregateValue()
|
||||
}
|
||||
|
||||
// Type - returns underlaying expression type.
|
||||
func (expr *aliasExpr) Type() Type {
|
||||
return expr.expr.Type()
|
||||
}
|
||||
|
||||
// ReturnType - returns underlaying expression's return type.
|
||||
func (expr *aliasExpr) ReturnType() Type {
|
||||
return expr.expr.ReturnType()
|
||||
}
|
||||
|
||||
// newAliasExpr - creates new alias expression.
|
||||
func newAliasExpr(alias string, expr Expr) *aliasExpr {
|
||||
return &aliasExpr{alias, expr}
|
||||
}
|
||||
|
||||
// starExpr - asterisk (*) expression.
|
||||
type starExpr struct {
|
||||
}
|
||||
|
||||
// String - returns string representation of this expression.
|
||||
func (expr *starExpr) String() string {
|
||||
return "*"
|
||||
}
|
||||
|
||||
// Eval - returns given args as map value.
|
||||
func (expr *starExpr) Eval(record Record) (*Value, error) {
|
||||
return newRecordValue(record), nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns nil value.
|
||||
func (expr *starExpr) AggregateValue() (*Value, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Type - returns record type.
|
||||
func (expr *starExpr) Type() Type {
|
||||
return record
|
||||
}
|
||||
|
||||
// ReturnType - returns record as return type.
|
||||
func (expr *starExpr) ReturnType() Type {
|
||||
return record
|
||||
}
|
||||
|
||||
// newStarExpr - returns new asterisk (*) expression.
|
||||
func newStarExpr() *starExpr {
|
||||
return &starExpr{}
|
||||
}
|
||||
|
||||
type valueExpr struct {
|
||||
value *Value
|
||||
}
|
||||
|
||||
func (expr *valueExpr) String() string {
|
||||
return expr.value.String()
|
||||
}
|
||||
|
||||
func (expr *valueExpr) Eval(record Record) (*Value, error) {
|
||||
return expr.value, nil
|
||||
}
|
||||
|
||||
func (expr *valueExpr) AggregateValue() (*Value, error) {
|
||||
return expr.value, nil
|
||||
}
|
||||
|
||||
func (expr *valueExpr) Type() Type {
|
||||
return expr.value.Type()
|
||||
}
|
||||
|
||||
func (expr *valueExpr) ReturnType() Type {
|
||||
return expr.value.Type()
|
||||
}
|
||||
|
||||
func newValueExpr(value *Value) *valueExpr {
|
||||
return &valueExpr{value: value}
|
||||
}
|
||||
|
||||
type columnExpr struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (expr *columnExpr) String() string {
|
||||
return expr.name
|
||||
}
|
||||
|
||||
func (expr *columnExpr) Eval(record Record) (*Value, error) {
|
||||
value, err := record.Get(expr.name)
|
||||
if err != nil {
|
||||
return nil, errEvaluatorBindingDoesNotExist(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (expr *columnExpr) AggregateValue() (*Value, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (expr *columnExpr) Type() Type {
|
||||
return column
|
||||
}
|
||||
|
||||
func (expr *columnExpr) ReturnType() Type {
|
||||
return column
|
||||
}
|
||||
|
||||
func newColumnExpr(columnName string) *columnExpr {
|
||||
return &columnExpr{name: columnName}
|
||||
}
|
|
@ -0,0 +1,433 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FuncName - SQL function name.
|
||||
type FuncName string
|
||||
|
||||
// SQL Function name constants
|
||||
const (
|
||||
// Conditionals
|
||||
sqlFnCoalesce FuncName = "COALESCE"
|
||||
sqlFnNullIf FuncName = "NULLIF"
|
||||
|
||||
// Conversion
|
||||
sqlFnCast FuncName = "CAST"
|
||||
|
||||
// Date and time
|
||||
sqlFnDateAdd FuncName = "DATE_ADD"
|
||||
sqlFnDateDiff FuncName = "DATE_DIFF"
|
||||
sqlFnExtract FuncName = "EXTRACT"
|
||||
sqlFnToString FuncName = "TO_STRING"
|
||||
sqlFnToTimestamp FuncName = "TO_TIMESTAMP"
|
||||
sqlFnUTCNow FuncName = "UTCNOW"
|
||||
|
||||
// String
|
||||
sqlFnCharLength FuncName = "CHAR_LENGTH"
|
||||
sqlFnCharacterLength FuncName = "CHARACTER_LENGTH"
|
||||
sqlFnLower FuncName = "LOWER"
|
||||
sqlFnSubstring FuncName = "SUBSTRING"
|
||||
sqlFnTrim FuncName = "TRIM"
|
||||
sqlFnUpper FuncName = "UPPER"
|
||||
)
|
||||
|
||||
// Allowed cast types
|
||||
const (
|
||||
castBool = "BOOL"
|
||||
castInt = "INT"
|
||||
castInteger = "INTEGER"
|
||||
castString = "STRING"
|
||||
castFloat = "FLOAT"
|
||||
castDecimal = "DECIMAL"
|
||||
castNumeric = "NUMERIC"
|
||||
castTimestamp = "TIMESTAMP"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnimplementedCast = errors.New("This cast not yet implemented")
|
||||
errNonStringTrimArg = errors.New("TRIM() received a non-string argument")
|
||||
)
|
||||
|
||||
func (e *FuncExpr) getFunctionName() FuncName {
|
||||
switch {
|
||||
case e.SFunc != nil:
|
||||
return FuncName(strings.ToUpper(e.SFunc.FunctionName))
|
||||
case e.Count != nil:
|
||||
return FuncName(aggFnCount)
|
||||
case e.Cast != nil:
|
||||
return sqlFnCast
|
||||
case e.Substring != nil:
|
||||
return sqlFnSubstring
|
||||
case e.Extract != nil:
|
||||
return sqlFnExtract
|
||||
case e.Trim != nil:
|
||||
return sqlFnTrim
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
|
||||
// function.
|
||||
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
|
||||
// Handle functions that have phrase arguments
|
||||
switch e.getFunctionName() {
|
||||
case sqlFnCast:
|
||||
expr := e.Cast.Expr
|
||||
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType))
|
||||
return
|
||||
|
||||
case sqlFnSubstring:
|
||||
return handleSQLSubstring(r, e.Substring)
|
||||
|
||||
case sqlFnExtract:
|
||||
return nil, errNotImplemented
|
||||
|
||||
case sqlFnTrim:
|
||||
return handleSQLTrim(r, e.Trim)
|
||||
}
|
||||
|
||||
// For all simple argument functions, we evaluate the arguments here
|
||||
argVals := make([]*Value, len(e.SFunc.ArgsList))
|
||||
for i, arg := range e.SFunc.ArgsList {
|
||||
argVals[i], err = arg.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch e.getFunctionName() {
|
||||
case sqlFnCoalesce:
|
||||
return coalesce(r, argVals)
|
||||
|
||||
case sqlFnNullIf:
|
||||
return nullif(r, argVals[0], argVals[1])
|
||||
|
||||
case sqlFnCharLength, sqlFnCharacterLength:
|
||||
return charlen(r, argVals[0])
|
||||
|
||||
case sqlFnLower:
|
||||
return lowerCase(r, argVals[0])
|
||||
|
||||
case sqlFnUpper:
|
||||
return upperCase(r, argVals[0])
|
||||
|
||||
case sqlFnDateAdd, sqlFnDateDiff, sqlFnToString, sqlFnToTimestamp, sqlFnUTCNow:
|
||||
// TODO: implement
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
return nil, errInvalidASTNode
|
||||
}
|
||||
}
|
||||
|
||||
func coalesce(r Record, args []*Value) (res *Value, err error) {
|
||||
for _, arg := range args {
|
||||
if arg.IsNull() {
|
||||
continue
|
||||
}
|
||||
return arg, nil
|
||||
}
|
||||
return FromNull(), nil
|
||||
}
|
||||
|
||||
func nullif(r Record, v1, v2 *Value) (res *Value, err error) {
|
||||
// Handle Null cases
|
||||
if v1.IsNull() || v2.IsNull() {
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
err = inferTypesForCmp(v1, v2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
atleastOneNumeric := v1.isNumeric() || v2.isNumeric()
|
||||
bothNumeric := v1.isNumeric() && v2.isNumeric()
|
||||
if atleastOneNumeric || !bothNumeric {
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
if v1.vType != v2.vType {
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
cmpResult, cmpErr := v1.compareOp(opEq, v2)
|
||||
if cmpErr != nil {
|
||||
return nil, cmpErr
|
||||
}
|
||||
|
||||
if cmpResult {
|
||||
return FromNull(), nil
|
||||
}
|
||||
|
||||
return v1, nil
|
||||
}
|
||||
|
||||
func charlen(r Record, v *Value) (*Value, error) {
|
||||
inferTypeAsString(v)
|
||||
s, ok := v.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s/%s expects a string argument", sqlFnCharLength, sqlFnCharacterLength)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
return FromInt(int64(len(s))), nil
|
||||
}
|
||||
|
||||
func lowerCase(r Record, v *Value) (*Value, error) {
|
||||
inferTypeAsString(v)
|
||||
s, ok := v.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s expects a string argument", sqlFnLower)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
return FromString(strings.ToLower(s)), nil
|
||||
}
|
||||
|
||||
func upperCase(r Record, v *Value) (*Value, error) {
|
||||
inferTypeAsString(v)
|
||||
s, ok := v.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s expects a string argument", sqlFnUpper)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
return FromString(strings.ToUpper(s)), nil
|
||||
}
|
||||
|
||||
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
|
||||
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
|
||||
// SUBSTRING('abc', 2, 1) are supported.
|
||||
|
||||
// Evaluate the string argument
|
||||
v1, err := e.Expr.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inferTypeAsString(v1)
|
||||
s, ok := v1.ToString()
|
||||
if !ok {
|
||||
err := fmt.Errorf("Incorrect argument type passed to %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
// Assemble other arguments
|
||||
arg2, arg3 := e.From, e.For
|
||||
// Check if the second form of substring is being used
|
||||
if e.From == nil {
|
||||
arg2, arg3 = e.Arg2, e.Arg3
|
||||
}
|
||||
|
||||
// Evaluate the FROM argument
|
||||
v2, err := arg2.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inferTypeForArithOp(v2)
|
||||
startIdx, ok := v2.ToInt()
|
||||
if !ok {
|
||||
err := fmt.Errorf("Incorrect type for start index argument in %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
length := -1
|
||||
// Evaluate the optional FOR argument
|
||||
if arg3 != nil {
|
||||
v3, err := arg3.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inferTypeForArithOp(v3)
|
||||
lenInt, ok := v3.ToInt()
|
||||
if !ok {
|
||||
err := fmt.Errorf("Incorrect type for length argument in %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
length = int(lenInt)
|
||||
if length < 0 {
|
||||
err := fmt.Errorf("Negative length argument in %s", sqlFnSubstring)
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := evalSQLSubstring(s, int(startIdx), length)
|
||||
return FromString(res), err
|
||||
}
|
||||
|
||||
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
|
||||
charsV, cerr := e.TrimChars.evalNode(r)
|
||||
if cerr != nil {
|
||||
return nil, cerr
|
||||
}
|
||||
inferTypeAsString(charsV)
|
||||
chars, ok := charsV.ToString()
|
||||
if !ok {
|
||||
return nil, errNonStringTrimArg
|
||||
}
|
||||
|
||||
fromV, ferr := e.TrimFrom.evalNode(r)
|
||||
if ferr != nil {
|
||||
return nil, ferr
|
||||
}
|
||||
from, ok := fromV.ToString()
|
||||
if !ok {
|
||||
return nil, errNonStringTrimArg
|
||||
}
|
||||
|
||||
result, terr := evalSQLTrim(e.TrimWhere, chars, from)
|
||||
if terr != nil {
|
||||
return nil, terr
|
||||
}
|
||||
return FromString(result), nil
|
||||
}
|
||||
|
||||
func errUnsupportedCast(fromType, toType string) error {
|
||||
return fmt.Errorf("Cannot cast from %v to %v", fromType, toType)
|
||||
}
|
||||
|
||||
func errCastFailure(msg string) error {
|
||||
return fmt.Errorf("Error casting: %s", msg)
|
||||
}
|
||||
|
||||
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
|
||||
v, err := e.evalNode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println("Cast to ", castType)
|
||||
|
||||
switch castType {
|
||||
case castInt, castInteger:
|
||||
i, err := intCast(v)
|
||||
return FromInt(i), err
|
||||
|
||||
case castFloat:
|
||||
f, err := floatCast(v)
|
||||
return FromFloat(f), err
|
||||
|
||||
case castString:
|
||||
s, err := stringCast(v)
|
||||
return FromString(s), err
|
||||
|
||||
case castBool, castDecimal, castNumeric, castTimestamp:
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
return nil, errUnimplementedCast
|
||||
}
|
||||
}
|
||||
|
||||
func intCast(v *Value) (int64, error) {
|
||||
// This conversion truncates floating point numbers to
|
||||
// integer.
|
||||
strToInt := func(s string) (int64, bool) {
|
||||
i, errI := strconv.ParseInt(s, 10, 64)
|
||||
if errI == nil {
|
||||
return i, true
|
||||
}
|
||||
f, errF := strconv.ParseFloat(s, 64)
|
||||
if errF == nil {
|
||||
return int64(f), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
// Truncate fractional part
|
||||
return int64(v.value.(float64)), nil
|
||||
case typeInt:
|
||||
return v.value.(int64), nil
|
||||
case typeString:
|
||||
// Parse as number, truncate floating point if
|
||||
// needed.
|
||||
s, _ := v.ToString()
|
||||
res, ok := strToInt(s)
|
||||
if !ok {
|
||||
return 0, errCastFailure("could not parse as int")
|
||||
}
|
||||
return res, nil
|
||||
case typeBytes:
|
||||
// Parse as number, truncate floating point if
|
||||
// needed.
|
||||
b, _ := v.ToBytes()
|
||||
s := string(b)
|
||||
res, ok := strToInt(s)
|
||||
if !ok {
|
||||
return 0, errCastFailure("could not parse as int")
|
||||
}
|
||||
return res, nil
|
||||
|
||||
default:
|
||||
return 0, errUnsupportedCast(v.GetTypeString(), castInt)
|
||||
}
|
||||
}
|
||||
|
||||
func floatCast(v *Value) (float64, error) {
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
return v.value.(float64), nil
|
||||
case typeInt:
|
||||
return float64(v.value.(int64)), nil
|
||||
case typeString:
|
||||
f, err := strconv.ParseFloat(v.value.(string), 64)
|
||||
if err != nil {
|
||||
return 0, errCastFailure("could not parse as float")
|
||||
}
|
||||
return f, nil
|
||||
case typeBytes:
|
||||
b, _ := v.ToBytes()
|
||||
f, err := strconv.ParseFloat(string(b), 64)
|
||||
if err != nil {
|
||||
return 0, errCastFailure("could not parse as float")
|
||||
}
|
||||
return f, nil
|
||||
default:
|
||||
return 0, errUnsupportedCast(v.GetTypeString(), castFloat)
|
||||
}
|
||||
}
|
||||
|
||||
func stringCast(v *Value) (string, error) {
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
f, _ := v.ToFloat()
|
||||
return fmt.Sprintf("%v", f), nil
|
||||
case typeInt:
|
||||
i, _ := v.ToInt()
|
||||
return fmt.Sprintf("%v", i), nil
|
||||
case typeString:
|
||||
s, _ := v.ToString()
|
||||
return s, nil
|
||||
case typeBytes:
|
||||
b, _ := v.ToBytes()
|
||||
return string(b), nil
|
||||
case typeBool:
|
||||
b, _ := v.ToBool()
|
||||
return fmt.Sprintf("%v", b), nil
|
||||
case typeNull:
|
||||
// FIXME: verify this case is correct
|
||||
return fmt.Sprintf("NULL"), nil
|
||||
}
|
||||
// This does not happen
|
||||
return "", nil
|
||||
}
|
|
@ -1,550 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FuncName - SQL function name.
|
||||
type FuncName string
|
||||
|
||||
const (
|
||||
// Avg - aggregate SQL function AVG().
|
||||
Avg FuncName = "AVG"
|
||||
|
||||
// Count - aggregate SQL function COUNT().
|
||||
Count FuncName = "COUNT"
|
||||
|
||||
// Max - aggregate SQL function MAX().
|
||||
Max FuncName = "MAX"
|
||||
|
||||
// Min - aggregate SQL function MIN().
|
||||
Min FuncName = "MIN"
|
||||
|
||||
// Sum - aggregate SQL function SUM().
|
||||
Sum FuncName = "SUM"
|
||||
|
||||
// Coalesce - conditional SQL function COALESCE().
|
||||
Coalesce FuncName = "COALESCE"
|
||||
|
||||
// NullIf - conditional SQL function NULLIF().
|
||||
NullIf FuncName = "NULLIF"
|
||||
|
||||
// ToTimestamp - conversion SQL function TO_TIMESTAMP().
|
||||
ToTimestamp FuncName = "TO_TIMESTAMP"
|
||||
|
||||
// UTCNow - date SQL function UTCNOW().
|
||||
UTCNow FuncName = "UTCNOW"
|
||||
|
||||
// CharLength - string SQL function CHAR_LENGTH().
|
||||
CharLength FuncName = "CHAR_LENGTH"
|
||||
|
||||
// CharacterLength - string SQL function CHARACTER_LENGTH() same as CHAR_LENGTH().
|
||||
CharacterLength FuncName = "CHARACTER_LENGTH"
|
||||
|
||||
// Lower - string SQL function LOWER().
|
||||
Lower FuncName = "LOWER"
|
||||
|
||||
// Substring - string SQL function SUBSTRING().
|
||||
Substring FuncName = "SUBSTRING"
|
||||
|
||||
// Trim - string SQL function TRIM().
|
||||
Trim FuncName = "TRIM"
|
||||
|
||||
// Upper - string SQL function UPPER().
|
||||
Upper FuncName = "UPPER"
|
||||
|
||||
// DateAdd FuncName = "DATE_ADD"
|
||||
// DateDiff FuncName = "DATE_DIFF"
|
||||
// Extract FuncName = "EXTRACT"
|
||||
// ToString FuncName = "TO_STRING"
|
||||
// Cast FuncName = "CAST" // CAST('2007-04-05T14:30Z' AS TIMESTAMP)
|
||||
)
|
||||
|
||||
func isAggregateFuncName(s string) bool {
|
||||
switch FuncName(s) {
|
||||
case Avg, Count, Max, Min, Sum:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func callForNumber(f Expr, record Record) (*Value, error) {
|
||||
value, err := f.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !value.Type().isNumber() {
|
||||
err := fmt.Errorf("%v evaluated to %v; not to number", f, value.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func callForInt(f Expr, record Record) (*Value, error) {
|
||||
value, err := f.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value.Type() != Int {
|
||||
err := fmt.Errorf("%v evaluated to %v; not to int", f, value.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func callForString(f Expr, record Record) (*Value, error) {
|
||||
value, err := f.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value.Type() != String {
|
||||
err := fmt.Errorf("%v evaluated to %v; not to string", f, value.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// funcExpr - SQL function.
|
||||
type funcExpr struct {
|
||||
args []Expr
|
||||
name FuncName
|
||||
|
||||
sumValue float64
|
||||
countValue int64
|
||||
maxValue float64
|
||||
minValue float64
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *funcExpr) String() string {
|
||||
var argStrings []string
|
||||
for _, arg := range f.args {
|
||||
argStrings = append(argStrings, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v(%v)", f.name, strings.Join(argStrings, ","))
|
||||
}
|
||||
|
||||
func (f *funcExpr) sum(record Record) (*Value, error) {
|
||||
value, err := callForNumber(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.sumValue += value.FloatValue()
|
||||
f.countValue++
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) count(record Record) (*Value, error) {
|
||||
value, err := f.args[0].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value.valueType != Null {
|
||||
f.countValue++
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) max(record Record) (*Value, error) {
|
||||
value, err := callForNumber(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := value.FloatValue()
|
||||
if v > f.maxValue {
|
||||
f.maxValue = v
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) min(record Record) (*Value, error) {
|
||||
value, err := callForNumber(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := value.FloatValue()
|
||||
if v < f.minValue {
|
||||
f.minValue = v
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) charLength(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewInt(int64(len(value.StringValue()))), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) trim(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewString(strings.TrimSpace(value.StringValue())), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) lower(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewString(strings.ToLower(value.StringValue())), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) upper(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewString(strings.ToUpper(value.StringValue())), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) substring(record Record) (*Value, error) {
|
||||
stringValue, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offsetValue, err := callForInt(f.args[1], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lengthValue *Value
|
||||
if len(f.args) == 3 {
|
||||
lengthValue, err = callForInt(f.args[2], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
value := stringValue.StringValue()
|
||||
offset := int(offsetValue.FloatValue())
|
||||
if offset < 0 || offset > len(value) {
|
||||
offset = 0
|
||||
}
|
||||
length := len(value)
|
||||
if lengthValue != nil {
|
||||
length = int(lengthValue.FloatValue())
|
||||
if length < 0 || length > len(value) {
|
||||
length = len(value)
|
||||
}
|
||||
}
|
||||
|
||||
return NewString(value[offset:length]), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) coalesce(record Record) (*Value, error) {
|
||||
values := make([]*Value, len(f.args))
|
||||
var err error
|
||||
for i := range f.args {
|
||||
values[i], err = f.args[i].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
if values[i].Type() != Null {
|
||||
return values[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return values[0], nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) nullIf(record Record) (*Value, error) {
|
||||
value1, err := f.args[0].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value2, err := f.args[1].Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := equal(value1, value2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result {
|
||||
return NewNull(), nil
|
||||
}
|
||||
|
||||
return value1, nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) toTimeStamp(record Record) (*Value, error) {
|
||||
value, err := callForString(f.args[0], record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err := time.Parse(time.RFC3339, value.StringValue())
|
||||
if err != nil {
|
||||
err := fmt.Errorf("%v: value '%v': %v", f, value, err)
|
||||
return nil, errValueParseFailure(err)
|
||||
}
|
||||
|
||||
return NewTime(t), nil
|
||||
}
|
||||
|
||||
func (f *funcExpr) utcNow(record Record) (*Value, error) {
|
||||
return NewTime(time.Now().UTC()), nil
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *funcExpr) Eval(record Record) (*Value, error) {
|
||||
switch f.name {
|
||||
case Avg, Sum:
|
||||
return f.sum(record)
|
||||
case Count:
|
||||
return f.count(record)
|
||||
case Max:
|
||||
return f.max(record)
|
||||
case Min:
|
||||
return f.min(record)
|
||||
case Coalesce:
|
||||
return f.coalesce(record)
|
||||
case NullIf:
|
||||
return f.nullIf(record)
|
||||
case ToTimestamp:
|
||||
return f.toTimeStamp(record)
|
||||
case UTCNow:
|
||||
return f.utcNow(record)
|
||||
case Substring:
|
||||
return f.substring(record)
|
||||
case CharLength, CharacterLength:
|
||||
return f.charLength(record)
|
||||
case Trim:
|
||||
return f.trim(record)
|
||||
case Lower:
|
||||
return f.lower(record)
|
||||
case Upper:
|
||||
return f.upper(record)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("unsupported aggregate function %v", f.name))
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *funcExpr) AggregateValue() (*Value, error) {
|
||||
switch f.name {
|
||||
case Avg:
|
||||
return NewFloat(f.sumValue / float64(f.countValue)), nil
|
||||
case Count:
|
||||
return NewInt(f.countValue), nil
|
||||
case Max:
|
||||
return NewFloat(f.maxValue), nil
|
||||
case Min:
|
||||
return NewFloat(f.minValue), nil
|
||||
case Sum:
|
||||
return NewFloat(f.sumValue), nil
|
||||
}
|
||||
|
||||
err := fmt.Errorf("%v is not aggreate function", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
// Type - returns Function or aggregateFunction type.
|
||||
func (f *funcExpr) Type() Type {
|
||||
switch f.name {
|
||||
case Avg, Count, Max, Min, Sum:
|
||||
return aggregateFunction
|
||||
}
|
||||
|
||||
return function
|
||||
}
|
||||
|
||||
// ReturnType - returns respective primitive type depending on SQL function.
|
||||
func (f *funcExpr) ReturnType() Type {
|
||||
switch f.name {
|
||||
case Avg, Max, Min, Sum:
|
||||
return Float
|
||||
case Count:
|
||||
return Int
|
||||
case CharLength, CharacterLength, Trim, Lower, Upper, Substring:
|
||||
return String
|
||||
case ToTimestamp, UTCNow:
|
||||
return Timestamp
|
||||
case Coalesce, NullIf:
|
||||
return column
|
||||
}
|
||||
|
||||
return function
|
||||
}
|
||||
|
||||
// newFuncExpr - creates new SQL function.
|
||||
func newFuncExpr(funcName FuncName, funcs ...Expr) (*funcExpr, error) {
|
||||
switch funcName {
|
||||
case Avg, Max, Min, Sum:
|
||||
if len(funcs) != 1 {
|
||||
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errParseNonUnaryAgregateFunctionCall(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isNumberKind() {
|
||||
err := fmt.Errorf("%v(): argument %v evaluate to %v, not number", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case Count:
|
||||
if len(funcs) != 1 {
|
||||
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errParseNonUnaryAgregateFunctionCall(err)
|
||||
}
|
||||
|
||||
switch funcs[0].ReturnType() {
|
||||
case Null, Bool, Int, Float, String, Timestamp, column, record:
|
||||
default:
|
||||
err := fmt.Errorf("%v(): argument %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case CharLength, CharacterLength, Trim, Lower, Upper, ToTimestamp:
|
||||
if len(funcs) != 1 {
|
||||
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("%v(): argument %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case Coalesce:
|
||||
if len(funcs) < 1 {
|
||||
err := fmt.Errorf("%v(): one or more argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
for i := range funcs {
|
||||
if !funcs[i].ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("%v(): argument-%v %v evaluate to %v is incompatible", funcName, i+1, funcs[i], funcs[i].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case NullIf:
|
||||
if len(funcs) != 2 {
|
||||
err := fmt.Errorf("%v(): exactly two arguments expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
if !funcs[1].ReturnType().isBaseKind() {
|
||||
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v is incompatible", funcName, funcs[1], funcs[1].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case UTCNow:
|
||||
if len(funcs) != 0 {
|
||||
err := fmt.Errorf("%v(): no argument expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
|
||||
case Substring:
|
||||
if len(funcs) < 2 || len(funcs) > 3 {
|
||||
err := fmt.Errorf("%v(): exactly two or three arguments expected; got %v", funcName, len(funcs))
|
||||
return nil, errEvaluatorInvalidArguments(err)
|
||||
}
|
||||
|
||||
if !funcs[0].ReturnType().isStringKind() {
|
||||
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
if !funcs[1].ReturnType().isIntKind() {
|
||||
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v, not int", funcName, funcs[1], funcs[1].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
|
||||
if len(funcs) > 2 {
|
||||
if !funcs[2].ReturnType().isIntKind() {
|
||||
err := fmt.Errorf("%v(): argument-3 %v evaluate to %v, not int", funcName, funcs[2], funcs[2].ReturnType())
|
||||
return nil, errIncorrectSQLFunctionArgumentType(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &funcExpr{
|
||||
args: funcs,
|
||||
name: funcName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errUnsupportedFunction(fmt.Errorf("unknown function name %v", funcName))
|
||||
}
|
|
@ -1,336 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import "fmt"
|
||||
|
||||
// andExpr - logical AND function.
|
||||
type andExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *andExpr) String() string {
|
||||
return fmt.Sprintf("(%v AND %v)", f.left, f.right)
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *andExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
_, err = f.right.Eval(record)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if !leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *andExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if !leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// Type - returns logicalFunction or aggregateFunction type.
|
||||
func (f *andExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *andExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newAndExpr - creates new AND logical function.
|
||||
func newAndExpr(left, right Expr) (*andExpr, error) {
|
||||
if !left.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator AND: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if !right.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator AND: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := logicalFunction
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
if left.Type() == column {
|
||||
err := fmt.Errorf("operator AND: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
if right.Type() == column {
|
||||
err := fmt.Errorf("operator AND: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &andExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// orExpr - logical OR function.
|
||||
type orExpr struct {
|
||||
left Expr
|
||||
right Expr
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *orExpr) String() string {
|
||||
return fmt.Sprintf("(%v OR %v)", f.left, f.right)
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *orExpr) Eval(record Record) (*Value, error) {
|
||||
leftValue, err := f.left.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
_, err = f.right.Eval(record)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *orExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
leftValue, err := f.left.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leftValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
if leftValue.BoolValue() {
|
||||
return leftValue, nil
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return rightValue, nil
|
||||
}
|
||||
|
||||
// Type - returns logicalFunction or aggregateFunction type.
|
||||
func (f *orExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *orExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newOrExpr - creates new OR logical function.
|
||||
func newOrExpr(left, right Expr) (*orExpr, error) {
|
||||
if !left.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator OR: left side expression %v evaluate to %v, not bool", left, left.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
if !right.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator OR: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := logicalFunction
|
||||
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
if left.Type() == column {
|
||||
err := fmt.Errorf("operator OR: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
|
||||
if right.Type() == column {
|
||||
err := fmt.Errorf("operator OR: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
|
||||
return nil, errUnsupportedSQLOperation(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &orExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// notExpr - logical NOT function.
|
||||
type notExpr struct {
|
||||
right Expr
|
||||
funcType Type
|
||||
}
|
||||
|
||||
// String - returns string representation of this function.
|
||||
func (f *notExpr) String() string {
|
||||
return fmt.Sprintf("(%v)", f.right)
|
||||
}
|
||||
|
||||
// Call - evaluates this function for given arg values and returns result as Value.
|
||||
func (f *notExpr) Eval(record Record) (*Value, error) {
|
||||
rightValue, err := f.right.Eval(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.funcType == aggregateFunction {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!rightValue.BoolValue()), nil
|
||||
}
|
||||
|
||||
// AggregateValue - returns aggregated value.
|
||||
func (f *notExpr) AggregateValue() (*Value, error) {
|
||||
if f.funcType != aggregateFunction {
|
||||
err := fmt.Errorf("%v is not aggreate expression", f)
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
rightValue, err := f.right.AggregateValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rightValue.Type() != Bool {
|
||||
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
|
||||
return nil, errExternalEvalException(err)
|
||||
}
|
||||
|
||||
return NewBool(!rightValue.BoolValue()), nil
|
||||
}
|
||||
|
||||
// Type - returns logicalFunction or aggregateFunction type.
|
||||
func (f *notExpr) Type() Type {
|
||||
return f.funcType
|
||||
}
|
||||
|
||||
// ReturnType - returns Bool as return type.
|
||||
func (f *notExpr) ReturnType() Type {
|
||||
return Bool
|
||||
}
|
||||
|
||||
// newNotExpr - creates new NOT logical function.
|
||||
func newNotExpr(right Expr) (*notExpr, error) {
|
||||
if !right.ReturnType().isBoolKind() {
|
||||
err := fmt.Errorf("operator NOT: right side expression %v evaluate to %v; not bool", right, right.ReturnType())
|
||||
return nil, errInvalidDataType(err)
|
||||
}
|
||||
|
||||
funcType := logicalFunction
|
||||
if right.Type() == aggregateFunction {
|
||||
funcType = aggregateFunction
|
||||
}
|
||||
|
||||
return ¬Expr{
|
||||
right: right,
|
||||
funcType: funcType,
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,329 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle"
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
// Types with custom Capture interface for parsing
|
||||
|
||||
// Boolean is a type for a parsed Boolean literal
|
||||
type Boolean bool
|
||||
|
||||
// Capture interface used by participle
|
||||
func (b *Boolean) Capture(values []string) error {
|
||||
*b = strings.ToLower(values[0]) == "true"
|
||||
return nil
|
||||
}
|
||||
|
||||
// LiteralString is a type for parsed SQL string literals
|
||||
type LiteralString string
|
||||
|
||||
// Capture interface used by participle
|
||||
func (ls *LiteralString) Capture(values []string) error {
|
||||
// Remove enclosing single quote
|
||||
n := len(values[0])
|
||||
r := values[0][1 : n-1]
|
||||
// Translate doubled quotes
|
||||
*ls = LiteralString(strings.Replace(r, "''", "'", -1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ObjectKey is a type for parsed strings occurring in key paths
|
||||
type ObjectKey struct {
|
||||
Lit *LiteralString `parser:" \"[\" @LitString \"]\""`
|
||||
ID *Identifier `parser:"| \".\" @@"`
|
||||
}
|
||||
|
||||
// QuotedIdentifier is a type for parsed strings that are double
|
||||
// quoted.
|
||||
type QuotedIdentifier string
|
||||
|
||||
// Capture inferface used by participle
|
||||
func (qi *QuotedIdentifier) Capture(values []string) error {
|
||||
// Remove enclosing quotes
|
||||
n := len(values[0])
|
||||
r := values[0][1 : n-1]
|
||||
|
||||
// Translate doubled quotes
|
||||
*qi = QuotedIdentifier(strings.Replace(r, `""`, `"`, -1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Types representing AST of SQL statement. Only SELECT is supported.
|
||||
|
||||
// Select is the top level AST node type
|
||||
type Select struct {
|
||||
Expression *SelectExpression `parser:"\"SELECT\" @@"`
|
||||
From *TableExpression `parser:"\"FROM\" @@"`
|
||||
Where *Expression `parser:"[ \"WHERE\" @@ ]"`
|
||||
Limit *LitValue `parser:"[ \"LIMIT\" @@ ]"`
|
||||
}
|
||||
|
||||
// SelectExpression represents the items requested in the select
|
||||
// statement
|
||||
type SelectExpression struct {
|
||||
All bool `parser:" @\"*\""`
|
||||
Expressions []*AliasedExpression `parser:"| @@ { \",\" @@ }"`
|
||||
|
||||
prop qProp
|
||||
}
|
||||
|
||||
// TableExpression represents the FROM clause
|
||||
type TableExpression struct {
|
||||
Table *JSONPath `parser:"@@"`
|
||||
As string `parser:"( \"AS\"? @Ident )?"`
|
||||
}
|
||||
|
||||
// JSONPathElement represents a keypath component
|
||||
type JSONPathElement struct {
|
||||
Key *ObjectKey `parser:" @@"` // ['name'] and .name forms
|
||||
Index *uint64 `parser:"| \"[\" @Number \"]\""` // [3] form
|
||||
ObjectWildcard bool `parser:"| @\".*\""` // .* form
|
||||
ArrayWildcard bool `parser:"| @\"[*]\""` // [*] form
|
||||
}
|
||||
|
||||
// JSONPath represents a keypath
|
||||
type JSONPath struct {
|
||||
BaseKey *Identifier `parser:" @@"`
|
||||
PathExpr []*JSONPathElement `parser:"(@@)*"`
|
||||
}
|
||||
|
||||
// AliasedExpression is an expression that can be optionally named
|
||||
type AliasedExpression struct {
|
||||
Expression *Expression `parser:"@@"`
|
||||
As string `parser:"[ \"AS\" @Ident ]"`
|
||||
}
|
||||
|
||||
// Grammar for Expression
|
||||
//
|
||||
// Expression → AndCondition ("OR" AndCondition)*
|
||||
// AndCondition → Condition ("AND" Condition)*
|
||||
// Condition → "NOT" Condition | ConditionExpression
|
||||
// ConditionExpression → ValueExpression ("=" | "<>" | "<=" | ">=" | "<" | ">") ValueExpression
|
||||
// | ValueExpression "LIKE" ValueExpression ("ESCAPE" LitString)?
|
||||
// | ValueExpression ("NOT"? "BETWEEN" ValueExpression "AND" ValueExpression)
|
||||
// | ValueExpression "IN" "(" Expression ("," Expression)* ")"
|
||||
// | ValueExpression
|
||||
// ValueExpression → Operand
|
||||
//
|
||||
// Operand grammar follows below
|
||||
|
||||
// Expression represents a logical disjunction of clauses
|
||||
type Expression struct {
|
||||
And []*AndCondition `parser:"@@ ( \"OR\" @@ )*"`
|
||||
}
|
||||
|
||||
// AndCondition represents logical conjunction of clauses
|
||||
type AndCondition struct {
|
||||
Condition []*Condition `parser:"@@ ( \"AND\" @@ )*"`
|
||||
}
|
||||
|
||||
// Condition represents a negation or a condition operand
|
||||
type Condition struct {
|
||||
Operand *ConditionOperand `parser:" @@"`
|
||||
Not *Condition `parser:"| \"NOT\" @@"`
|
||||
}
|
||||
|
||||
// ConditionOperand is a operand followed by an an optional operation
|
||||
// expression
|
||||
type ConditionOperand struct {
|
||||
Operand *Operand `parser:"@@"`
|
||||
ConditionRHS *ConditionRHS `parser:"@@?"`
|
||||
}
|
||||
|
||||
// ConditionRHS represents the right-hand-side of Compare, Between, In
|
||||
// or Like expressions.
|
||||
type ConditionRHS struct {
|
||||
Compare *Compare `parser:" @@"`
|
||||
Between *Between `parser:"| @@"`
|
||||
In *In `parser:"| \"IN\" \"(\" @@ \")\""`
|
||||
Like *Like `parser:"| @@"`
|
||||
}
|
||||
|
||||
// Compare represents the RHS of a comparison expression
|
||||
type Compare struct {
|
||||
Operator string `parser:"@( \"<>\" | \"<=\" | \">=\" | \"=\" | \"<\" | \">\" | \"!=\" )"`
|
||||
Operand *Operand `parser:" @@"`
|
||||
}
|
||||
|
||||
// Like represents the RHS of a LIKE expression
|
||||
type Like struct {
|
||||
Not bool `parser:" @\"NOT\"? "`
|
||||
Pattern *Operand `parser:" \"LIKE\" @@ "`
|
||||
EscapeChar *Operand `parser:" (\"ESCAPE\" @@)? "`
|
||||
}
|
||||
|
||||
// Between represents the RHS of a BETWEEN expression
|
||||
type Between struct {
|
||||
Not bool `parser:" @\"NOT\"? "`
|
||||
Start *Operand `parser:" \"BETWEEN\" @@ "`
|
||||
End *Operand `parser:" \"AND\" @@ "`
|
||||
}
|
||||
|
||||
// In represents the RHS of an IN expression
|
||||
type In struct {
|
||||
Expressions []*Expression `parser:"@@ ( \",\" @@ )*"`
|
||||
}
|
||||
|
||||
// Grammar for Operand:
|
||||
//
|
||||
// operand → multOp ( ("-" | "+") multOp )*
|
||||
// multOp → unary ( ("/" | "*" | "%") unary )*
|
||||
// unary → "-" unary | primary
|
||||
// primary → Value | Variable | "(" expression ")"
|
||||
//
|
||||
|
||||
// An Operand is a single term followed by an optional sequence of
|
||||
// terms separated by +/-
|
||||
type Operand struct {
|
||||
Left *MultOp `parser:"@@"`
|
||||
Right []*OpFactor `parser:"(@@)*"`
|
||||
}
|
||||
|
||||
// OpFactor represents the right-side of a +/- operation.
|
||||
type OpFactor struct {
|
||||
Op string `parser:"@(\"+\" | \"-\")"`
|
||||
Right *MultOp `parser:"@@"`
|
||||
}
|
||||
|
||||
// MultOp represents a single term followed by an optional sequence of
|
||||
// terms separated by *, / or % operators.
|
||||
type MultOp struct {
|
||||
Left *UnaryTerm `parser:"@@"`
|
||||
Right []*OpUnaryTerm `parser:"(@@)*"`
|
||||
}
|
||||
|
||||
// OpUnaryTerm represents the right side of *, / or % binary operations.
|
||||
type OpUnaryTerm struct {
|
||||
Op string `parser:"@(\"*\" | \"/\" | \"%\")"`
|
||||
Right *UnaryTerm `parser:"@@"`
|
||||
}
|
||||
|
||||
// UnaryTerm represents a single negated term or a primary term
|
||||
type UnaryTerm struct {
|
||||
Negated *NegatedTerm `parser:" @@"`
|
||||
Primary *PrimaryTerm `parser:"| @@"`
|
||||
}
|
||||
|
||||
// NegatedTerm has a leading minus sign.
|
||||
type NegatedTerm struct {
|
||||
Term *PrimaryTerm `parser:"\"-\" @@"`
|
||||
}
|
||||
|
||||
// PrimaryTerm represents a Value, Path expression, a Sub-expression
|
||||
// or a function call.
|
||||
type PrimaryTerm struct {
|
||||
Value *LitValue `parser:" @@"`
|
||||
JPathExpr *JSONPath `parser:"| @@"`
|
||||
SubExpression *Expression `parser:"| \"(\" @@ \")\""`
|
||||
// Include function expressions here.
|
||||
FuncCall *FuncExpr `parser:"| @@"`
|
||||
}
|
||||
|
||||
// FuncExpr represents a function call
|
||||
type FuncExpr struct {
|
||||
SFunc *SimpleArgFunc `parser:" @@"`
|
||||
Count *CountFunc `parser:"| @@"`
|
||||
Cast *CastFunc `parser:"| @@"`
|
||||
Substring *SubstringFunc `parser:"| @@"`
|
||||
Extract *ExtractFunc `parser:"| @@"`
|
||||
Trim *TrimFunc `parser:"| @@"`
|
||||
|
||||
// Used during evaluation for aggregation funcs
|
||||
aggregate *aggVal
|
||||
}
|
||||
|
||||
// SimpleArgFunc represents functions with simple expression
|
||||
// arguments.
|
||||
type SimpleArgFunc struct {
|
||||
FunctionName string `parser:" @(\"AVG\" | \"MAX\" | \"MIN\" | \"SUM\" | \"COALESCE\" | \"NULLIF\" | \"DATE_ADD\" | \"DATE_DIFF\" | \"TO_STRING\" | \"TO_TIMESTAMP\" | \"UTCNOW\" | \"CHAR_LENGTH\" | \"CHARACTER_LENGTH\" | \"LOWER\" | \"UPPER\") "`
|
||||
|
||||
ArgsList []*Expression `parser:"\"(\" (@@ (\",\" @@)*)?\")\""`
|
||||
}
|
||||
|
||||
// CountFunc represents the COUNT sql function
|
||||
type CountFunc struct {
|
||||
StarArg bool `parser:" \"COUNT\" \"(\" ( @\"*\"?"`
|
||||
ExprArg *Expression `parser:" @@? )! \")\""`
|
||||
}
|
||||
|
||||
// CastFunc represents CAST sql function
|
||||
type CastFunc struct {
|
||||
Expr *Expression `parser:" \"CAST\" \"(\" @@ "`
|
||||
CastType string `parser:" \"AS\" @(\"BOOL\" | \"INT\" | \"INTEGER\" | \"STRING\" | \"FLOAT\" | \"DECIMAL\" | \"NUMERIC\" | \"TIMESTAMP\") \")\" "`
|
||||
}
|
||||
|
||||
// SubstringFunc represents SUBSTRING sql function
|
||||
type SubstringFunc struct {
|
||||
Expr *PrimaryTerm `parser:" \"SUBSTRING\" \"(\" @@ "`
|
||||
From *Operand `parser:" ( \"FROM\" @@ "`
|
||||
For *Operand `parser:" (\"FOR\" @@)? \")\" "`
|
||||
Arg2 *Operand `parser:" | \",\" @@ "`
|
||||
Arg3 *Operand `parser:" (\",\" @@)? \")\" )"`
|
||||
}
|
||||
|
||||
// ExtractFunc represents EXTRACT sql function
|
||||
type ExtractFunc struct {
|
||||
Timeword string `parser:" \"EXTRACT\" \"(\" @( \"YEAR\":Timeword | \"MONTH\":Timeword | \"DAY\":Timeword | \"HOUR\":Timeword | \"MINUTE\":Timeword | \"SECOND\":Timeword | \"TIMEZONE_HOUR\":Timeword | \"TIMEZONE_MINUTE\":Timeword ) "`
|
||||
From *PrimaryTerm `parser:" \"FROM\" @@ \")\" "`
|
||||
}
|
||||
|
||||
// TrimFunc represents TRIM sql function
|
||||
type TrimFunc struct {
|
||||
TrimWhere *string `parser:" \"TRIM\" \"(\" ( @( \"LEADING\" | \"TRAILING\" | \"BOTH\" ) "`
|
||||
TrimChars *PrimaryTerm `parser:" @@? "`
|
||||
TrimFrom *PrimaryTerm `parser:" \"FROM\" )? @@ \")\" "`
|
||||
}
|
||||
|
||||
// LitValue represents a literal value parsed from the sql
|
||||
type LitValue struct {
|
||||
Number *float64 `parser:"( @Number"`
|
||||
String *LiteralString `parser:" | @LitString"`
|
||||
Boolean *Boolean `parser:" | @(\"TRUE\" | \"FALSE\")"`
|
||||
Null bool `parser:" | @\"NULL\")"`
|
||||
}
|
||||
|
||||
// Identifier represents a parsed identifier
|
||||
type Identifier struct {
|
||||
Unquoted *string `parser:" @Ident"`
|
||||
Quoted *QuotedIdentifier `parser:"| @QuotIdent"`
|
||||
}
|
||||
|
||||
var (
|
||||
sqlLexer = lexer.Must(lexer.Regexp(`(\s+)` +
|
||||
`|(?P<Timeword>(?i)\b(?:YEAR|MONTH|DAY|HOUR|MINUTE|SECOND|TIMEZONE_HOUR|TIMEZONE_MINUTE)\b)` +
|
||||
`|(?P<Keyword>(?i)\b(?:SELECT|FROM|TOP|DISTINCT|ALL|WHERE|GROUP|BY|HAVING|UNION|MINUS|EXCEPT|INTERSECT|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|ANY|SOME|BETWEEN|AND|OR|LIKE|ESCAPE|AS|IN|BOOL|INT|INTEGER|STRING|FLOAT|DECIMAL|NUMERIC|TIMESTAMP|AVG|COUNT|MAX|MIN|SUM|COALESCE|NULLIF|CAST|DATE_ADD|DATE_DIFF|EXTRACT|TO_STRING|TO_TIMESTAMP|UTCNOW|CHAR_LENGTH|CHARACTER_LENGTH|LOWER|SUBSTRING|TRIM|UPPER|LEADING|TRAILING|BOTH|FOR)\b)` +
|
||||
`|(?P<Ident>[a-zA-Z_][a-zA-Z0-9_]*)` +
|
||||
`|(?P<QuotIdent>"([^"]*("")?)*")` +
|
||||
`|(?P<Number>\d*\.?\d+([eE][-+]?\d+)?)` +
|
||||
`|(?P<LitString>'([^']*('')?)*')` +
|
||||
`|(?P<Operators><>|!=|<=|>=|\.\*|\[\*\]|[-+*/%,.()=<>\[\]])`,
|
||||
))
|
||||
|
||||
// SQLParser is used to parse SQL statements
|
||||
SQLParser = participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
participle.CaseInsensitive("Timeword"),
|
||||
)
|
||||
)
|
|
@ -0,0 +1,383 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/alecthomas/participle"
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
func TestJSONPathElement(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&JSONPathElement{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
j := JSONPathElement{}
|
||||
cases := []string{
|
||||
// Key
|
||||
"['name']", ".name", `."name"`,
|
||||
|
||||
// Index
|
||||
"[2]", "[0]", "[100]",
|
||||
|
||||
// Object wilcard
|
||||
".*",
|
||||
|
||||
// array wildcard
|
||||
"[*]",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &j)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONPath(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&JSONPath{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
j := JSONPath{}
|
||||
cases := []string{
|
||||
"S3Object",
|
||||
"S3Object.id",
|
||||
"S3Object.book.title",
|
||||
"S3Object.id[1]",
|
||||
"S3Object.id['abc']",
|
||||
"S3Object.id['ab']",
|
||||
"S3Object.words.*.id",
|
||||
"S3Object.words.name[*].val",
|
||||
"S3Object.words.name[*].val[*]",
|
||||
"S3Object.words.name[*].val.*",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &j)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(j, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestIdentifierParsing(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Identifier{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
id := Identifier{}
|
||||
validCases := []string{
|
||||
"a",
|
||||
"_a",
|
||||
"abc_a",
|
||||
"a2",
|
||||
`"abc"`,
|
||||
`"abc\a""ac"`,
|
||||
}
|
||||
for i, tc := range validCases {
|
||||
err := p.ParseString(tc, &id)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(id, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
invalidCases := []string{
|
||||
"+a",
|
||||
"-a",
|
||||
"1a",
|
||||
`"ab`,
|
||||
`abc"`,
|
||||
`aa""a`,
|
||||
`"a"a"`,
|
||||
}
|
||||
for i, tc := range invalidCases {
|
||||
err := p.ParseString(tc, &id)
|
||||
if err == nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// fmt.Println(tc, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLiteralStringParsing(t *testing.T) {
|
||||
var k ObjectKey
|
||||
p := participle.MustBuild(
|
||||
&ObjectKey{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
validCases := []string{
|
||||
"['abc']",
|
||||
"['ab''c']",
|
||||
"['a''b''c']",
|
||||
"['abc-x_1##@(*&(#*))/\\']",
|
||||
}
|
||||
for i, tc := range validCases {
|
||||
err := p.ParseString(tc, &k)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
if string(*k.Lit) == "" {
|
||||
t.Fatalf("Incorrect parse %#v", k)
|
||||
}
|
||||
// repr.Println(k, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
invalidCases := []string{
|
||||
"['abc'']",
|
||||
"['-abc'sc']",
|
||||
"[abc']",
|
||||
"['ac]",
|
||||
}
|
||||
for i, tc := range invalidCases {
|
||||
err := p.ParseString(tc, &k)
|
||||
if err == nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// fmt.Println(tc, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionParsing(t *testing.T) {
|
||||
var fex FuncExpr
|
||||
p := participle.MustBuild(
|
||||
&FuncExpr{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
participle.CaseInsensitive("Timeword"),
|
||||
)
|
||||
|
||||
validCases := []string{
|
||||
"count(*)",
|
||||
"sum(2 + s.id)",
|
||||
"sum(t)",
|
||||
"avg(s.id[1])",
|
||||
"coalesce(s.id[1], 2, 2 + 3)",
|
||||
|
||||
"cast(s as string)",
|
||||
"cast(s AS INT)",
|
||||
"cast(s as DECIMAL)",
|
||||
"extract(YEAR from '2018-01-09')",
|
||||
"extract(month from '2018-01-09')",
|
||||
|
||||
"extract(hour from '2018-01-09')",
|
||||
"extract(day from '2018-01-09')",
|
||||
"substring('abcd' from 2 for 2)",
|
||||
"substring('abcd' from 2)",
|
||||
"substring('abcd' , 2 , 2)",
|
||||
|
||||
"substring('abcd' , 22 )",
|
||||
"trim(' aab ')",
|
||||
"trim(leading from ' aab ')",
|
||||
"trim(trailing from ' aab ')",
|
||||
"trim(both from ' aab ')",
|
||||
|
||||
"trim(both '12' from ' aab ')",
|
||||
"trim(leading '12' from ' aab ')",
|
||||
"trim(trailing '12' from ' aab ')",
|
||||
"count(23)",
|
||||
}
|
||||
for i, tc := range validCases {
|
||||
err := p.ParseString(tc, &fex)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
// repr.Println(fex, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlLexer(t *testing.T) {
|
||||
// s := bytes.NewBuffer([]byte("s.['name'].*.[*].abc.[\"abc\"]"))
|
||||
s := bytes.NewBuffer([]byte("S3Object.words.*.id"))
|
||||
// s := bytes.NewBuffer([]byte("COUNT(Id)"))
|
||||
lex, err := sqlLexer.Lex(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tokens, err := lexer.ConsumeAll(lex)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// for i, t := range tokens {
|
||||
// fmt.Printf("%d: %#v\n", i, t)
|
||||
// }
|
||||
if len(tokens) != 7 {
|
||||
t.Fatalf("Expected 7 got %d", len(tokens))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWhere(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
"select * from s3object",
|
||||
"select a, b from s3object s",
|
||||
"select a, b from s3object as s",
|
||||
"select a, b from s3object as s where a = 1",
|
||||
"select a, b from s3object s where a = 1",
|
||||
"select a, b from s3object where a = 1",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLikeClause(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
`select * from s3object where Name like 'abcd'`,
|
||||
`select Name like 'abc' from s3object`,
|
||||
`select * from s3object where Name not like 'abc'`,
|
||||
`select * from s3object where Name like 'abc' escape 't'`,
|
||||
`select * from s3object where Name like 'a\%' escape '?'`,
|
||||
`select * from s3object where Name not like 'abc\' escape '?'`,
|
||||
`select * from s3object where Name like 'a\%' escape LOWER('?')`,
|
||||
`select * from s3object where Name not like LOWER('Bc\') escape '?'`,
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBetweenClause(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
`select * from s3object where Id between 1 and 2`,
|
||||
`select * from s3object where Id between 1 and 2 and name = 'Ab'`,
|
||||
`select * from s3object where Id not between 1 and 2`,
|
||||
`select * from s3object where Id not between 1 and 2 and name = 'Bc'`,
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromClauseJSONPath(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
"select * from s3object",
|
||||
"select * from s3object[*].name",
|
||||
"select * from s3object[*].books[*]",
|
||||
"select * from s3object[*].books[*].name",
|
||||
"select * from s3object where name > 2",
|
||||
"select * from s3object[*].name where name > 2",
|
||||
"select * from s3object[*].books[*] where name > 2",
|
||||
"select * from s3object[*].books[*].name where name > 2",
|
||||
"select * from s3object[*].books[*] s",
|
||||
"select * from s3object[*].books[*].name as s",
|
||||
"select * from s3object s where name > 2",
|
||||
"select * from s3object[*].name as s where name > 2",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSelectParsing(t *testing.T) {
|
||||
p := participle.MustBuild(
|
||||
&Select{},
|
||||
participle.Lexer(sqlLexer),
|
||||
participle.CaseInsensitive("Keyword"),
|
||||
)
|
||||
|
||||
s := Select{}
|
||||
cases := []string{
|
||||
"select * from s3object where name > 2 or value > 1 or word > 2",
|
||||
"select s.word.id + 2 from s3object s",
|
||||
"select 1-2-3 from s3object s limit 1",
|
||||
}
|
||||
for i, tc := range cases {
|
||||
err := p.ParseString(tc, &s)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
// repr.Println(s, repr.Indent(" "), repr.OmitEmpty(true))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlLexerArithOps(t *testing.T) {
|
||||
s := bytes.NewBuffer([]byte("year from select month hour distinct"))
|
||||
lex, err := sqlLexer.Lex(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tokens, err := lexer.ConsumeAll(lex)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(tokens) != 7 {
|
||||
t.Errorf("Expected 7 got %d", len(tokens))
|
||||
}
|
||||
// for i, t := range tokens {
|
||||
// fmt.Printf("%d: %#v\n", i, t)
|
||||
// }
|
||||
}
|
|
@ -1,529 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
func getColumnName(colName *sqlparser.ColName) string {
|
||||
columnName := colName.Qualifier.Name.String()
|
||||
if qualifier := colName.Qualifier.Qualifier.String(); qualifier != "" {
|
||||
columnName = qualifier + "." + columnName
|
||||
}
|
||||
|
||||
if columnName == "" {
|
||||
columnName = colName.Name.String()
|
||||
} else {
|
||||
columnName = columnName + "." + colName.Name.String()
|
||||
}
|
||||
|
||||
return columnName
|
||||
}
|
||||
|
||||
func newLiteralExpr(parserExpr sqlparser.Expr, tableAlias string) (Expr, error) {
|
||||
switch parserExpr.(type) {
|
||||
case *sqlparser.NullVal:
|
||||
return newValueExpr(NewNull()), nil
|
||||
case sqlparser.BoolVal:
|
||||
return newValueExpr(NewBool((bool(parserExpr.(sqlparser.BoolVal))))), nil
|
||||
case *sqlparser.SQLVal:
|
||||
sqlValue := parserExpr.(*sqlparser.SQLVal)
|
||||
value, err := NewValue(sqlValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newValueExpr(value), nil
|
||||
case *sqlparser.ColName:
|
||||
columnName := getColumnName(parserExpr.(*sqlparser.ColName))
|
||||
if tableAlias != "" {
|
||||
if !strings.HasPrefix(columnName, tableAlias+".") {
|
||||
err := fmt.Errorf("column name %v does not start with table alias %v", columnName, tableAlias)
|
||||
return nil, errInvalidKeyPath(err)
|
||||
}
|
||||
columnName = strings.TrimPrefix(columnName, tableAlias+".")
|
||||
}
|
||||
|
||||
return newColumnExpr(columnName), nil
|
||||
case sqlparser.ValTuple:
|
||||
var valueType Type
|
||||
var values []*Value
|
||||
for i, valExpr := range parserExpr.(sqlparser.ValTuple) {
|
||||
sqlVal, ok := valExpr.(*sqlparser.SQLVal)
|
||||
if !ok {
|
||||
return nil, errParseInvalidTypeParam(fmt.Errorf("value %v in Tuple should be primitive value", i+1))
|
||||
}
|
||||
|
||||
val, err := NewValue(sqlVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
valueType = val.Type()
|
||||
} else if valueType != val.Type() {
|
||||
return nil, errParseInvalidTypeParam(fmt.Errorf("mixed value type is not allowed in Tuple"))
|
||||
}
|
||||
|
||||
values = append(values, val)
|
||||
}
|
||||
|
||||
return newValueExpr(NewArray(values)), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func isExprToComparisonExpr(parserExpr *sqlparser.IsExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func rangeCondToComparisonFunc(parserExpr *sqlparser.RangeCond, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fromExpr, err := newExpr(parserExpr.From, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toExpr, err := newExpr(parserExpr.To, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, fromExpr, toExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() || !fromExpr.Type().isBase() || !toExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toComparisonExpr(parserExpr *sqlparser.ComparisonExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toArithExpr(parserExpr *sqlparser.BinaryExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newArithExpr(ArithOperator(parserExpr.Operator), leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toFuncExpr(parserExpr *sqlparser.FuncExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
funcName := strings.ToUpper(parserExpr.Name.String())
|
||||
if !isSelectExpr && isAggregateFuncName(funcName) {
|
||||
return nil, errUnsupportedSQLOperation(fmt.Errorf("%v() must be used in select expression", funcName))
|
||||
}
|
||||
funcs, aggregatedExprFound, err := newSelectExprs(parserExpr.Exprs, tableAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if aggregatedExprFound {
|
||||
return nil, errIncorrectSQLFunctionArgumentType(fmt.Errorf("%v(): aggregated expression must not be used as argument", funcName))
|
||||
}
|
||||
|
||||
return newFuncExpr(FuncName(funcName), funcs...)
|
||||
}
|
||||
|
||||
func toAndExpr(parserExpr *sqlparser.AndExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newAndExpr(leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toOrExpr(parserExpr *sqlparser.OrExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newOrExpr(leftExpr, rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func toNotExpr(parserExpr *sqlparser.NotExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
rightExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := newNotExpr(rightExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rightExpr.Type() != Bool {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
value, err := f.Eval(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newValueExpr(value), nil
|
||||
}
|
||||
|
||||
func newExpr(parserExpr sqlparser.Expr, tableAlias string, isSelectExpr bool) (Expr, error) {
|
||||
f, err := newLiteralExpr(parserExpr, tableAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f != nil {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
switch parserExpr.(type) {
|
||||
case *sqlparser.ParenExpr:
|
||||
return newExpr(parserExpr.(*sqlparser.ParenExpr).Expr, tableAlias, isSelectExpr)
|
||||
case *sqlparser.IsExpr:
|
||||
return isExprToComparisonExpr(parserExpr.(*sqlparser.IsExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.RangeCond:
|
||||
return rangeCondToComparisonFunc(parserExpr.(*sqlparser.RangeCond), tableAlias, isSelectExpr)
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return toComparisonExpr(parserExpr.(*sqlparser.ComparisonExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.BinaryExpr:
|
||||
return toArithExpr(parserExpr.(*sqlparser.BinaryExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.FuncExpr:
|
||||
return toFuncExpr(parserExpr.(*sqlparser.FuncExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.AndExpr:
|
||||
return toAndExpr(parserExpr.(*sqlparser.AndExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.OrExpr:
|
||||
return toOrExpr(parserExpr.(*sqlparser.OrExpr), tableAlias, isSelectExpr)
|
||||
case *sqlparser.NotExpr:
|
||||
return toNotExpr(parserExpr.(*sqlparser.NotExpr), tableAlias, isSelectExpr)
|
||||
}
|
||||
|
||||
return nil, errParseUnsupportedSyntax(fmt.Errorf("unknown expression type %T; %v", parserExpr, parserExpr))
|
||||
}
|
||||
|
||||
func newSelectExprs(parserSelectExprs []sqlparser.SelectExpr, tableAlias string) ([]Expr, bool, error) {
|
||||
var funcs []Expr
|
||||
starExprFound := false
|
||||
aggregatedExprFound := false
|
||||
|
||||
for _, selectExpr := range parserSelectExprs {
|
||||
switch selectExpr.(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
if starExprFound {
|
||||
return nil, false, errParseAsteriskIsNotAloneInSelectList(nil)
|
||||
}
|
||||
|
||||
aliasedExpr := selectExpr.(*sqlparser.AliasedExpr)
|
||||
f, err := newExpr(aliasedExpr.Expr, tableAlias, true)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if f.Type() == aggregateFunction {
|
||||
if !aggregatedExprFound {
|
||||
aggregatedExprFound = true
|
||||
if len(funcs) > 0 {
|
||||
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
|
||||
}
|
||||
}
|
||||
} else if aggregatedExprFound {
|
||||
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
|
||||
}
|
||||
|
||||
alias := aliasedExpr.As.String()
|
||||
if alias != "" {
|
||||
f = newAliasExpr(alias, f)
|
||||
}
|
||||
|
||||
funcs = append(funcs, f)
|
||||
case *sqlparser.StarExpr:
|
||||
if starExprFound {
|
||||
err := fmt.Errorf("only single star expression allowed")
|
||||
return nil, false, errParseInvalidContextForWildcardInSelectList(err)
|
||||
}
|
||||
starExprFound = true
|
||||
funcs = append(funcs, newStarExpr())
|
||||
default:
|
||||
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("unknown select expression %v", selectExpr))
|
||||
}
|
||||
}
|
||||
|
||||
return funcs, aggregatedExprFound, nil
|
||||
}
|
||||
|
||||
// Select - SQL Select statement.
|
||||
type Select struct {
|
||||
tableName string
|
||||
tableAlias string
|
||||
selectExprs []Expr
|
||||
aggregatedExprFound bool
|
||||
whereExpr Expr
|
||||
}
|
||||
|
||||
// TableAlias - returns table alias name.
|
||||
func (statement *Select) TableAlias() string {
|
||||
return statement.tableAlias
|
||||
}
|
||||
|
||||
// IsSelectAll - returns whether '*' is used in select expression or not.
|
||||
func (statement *Select) IsSelectAll() bool {
|
||||
if len(statement.selectExprs) == 1 {
|
||||
_, ok := statement.selectExprs[0].(*starExpr)
|
||||
return ok
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAggregated - returns whether aggregated functions are used in select expression or not.
|
||||
func (statement *Select) IsAggregated() bool {
|
||||
return statement.aggregatedExprFound
|
||||
}
|
||||
|
||||
// AggregateResult - returns aggregate result as record.
|
||||
func (statement *Select) AggregateResult(output Record) error {
|
||||
if !statement.aggregatedExprFound {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, expr := range statement.selectExprs {
|
||||
value, err := expr.AggregateValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if value == nil {
|
||||
return errInternalError(fmt.Errorf("%v returns <nil> for AggregateValue()", expr))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("_%v", i+1)
|
||||
if _, ok := expr.(*aliasExpr); ok {
|
||||
name = expr.(*aliasExpr).alias
|
||||
}
|
||||
|
||||
if err = output.Set(name, value); err != nil {
|
||||
return errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Eval - evaluates this Select expressions for given record.
|
||||
func (statement *Select) Eval(input, output Record) (Record, error) {
|
||||
if statement.whereExpr != nil {
|
||||
value, err := statement.whereExpr.Eval(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if value == nil || value.valueType != Bool {
|
||||
err = fmt.Errorf("WHERE expression %v returns invalid bool value %v", statement.whereExpr, value)
|
||||
return nil, errInternalError(err)
|
||||
}
|
||||
|
||||
if !value.BoolValue() {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Call selectExprs
|
||||
for i, expr := range statement.selectExprs {
|
||||
value, err := expr.Eval(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if statement.aggregatedExprFound {
|
||||
continue
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("_%v", i+1)
|
||||
switch expr.(type) {
|
||||
case *starExpr:
|
||||
return value.recordValue(), nil
|
||||
case *aliasExpr:
|
||||
name = expr.(*aliasExpr).alias
|
||||
case *columnExpr:
|
||||
name = expr.(*columnExpr).name
|
||||
}
|
||||
|
||||
if err = output.Set(name, value); err != nil {
|
||||
return nil, errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// NewSelect - creates new Select by parsing sql.
|
||||
func NewSelect(sql string) (*Select, error) {
|
||||
stmt, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return nil, errUnsupportedSQLStructure(err)
|
||||
}
|
||||
|
||||
selectStmt, ok := stmt.(*sqlparser.Select)
|
||||
if !ok {
|
||||
return nil, errParseUnsupportedSelect(fmt.Errorf("unsupported SQL statement %v", sql))
|
||||
}
|
||||
|
||||
var tableName, tableAlias string
|
||||
for _, fromExpr := range selectStmt.From {
|
||||
tableExpr := fromExpr.(*sqlparser.AliasedTableExpr)
|
||||
tableName = tableExpr.Expr.(sqlparser.TableName).Name.String()
|
||||
tableAlias = tableExpr.As.String()
|
||||
}
|
||||
|
||||
selectExprs, aggregatedExprFound, err := newSelectExprs(selectStmt.SelectExprs, tableAlias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var whereExpr Expr
|
||||
if selectStmt.Where != nil {
|
||||
whereExpr, err = newExpr(selectStmt.Where.Expr, tableAlias, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Select{
|
||||
tableName: tableName,
|
||||
tableAlias: tableAlias,
|
||||
selectExprs: selectExprs,
|
||||
aggregatedExprFound: aggregatedExprFound,
|
||||
whereExpr: whereExpr,
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errBadLimitSpecified = errors.New("Limit value must be a positive integer")
|
||||
)
|
||||
|
||||
// SelectStatement is the top level parsed and analyzed structure
|
||||
type SelectStatement struct {
|
||||
selectAST *Select
|
||||
|
||||
// Analysis result of the statement
|
||||
selectQProp qProp
|
||||
|
||||
// Result of parsing the limit clause if one is present
|
||||
// (otherwise -1)
|
||||
limitValue int64
|
||||
|
||||
// Count of rows that have been output.
|
||||
outputCount int64
|
||||
}
|
||||
|
||||
// ParseSelectStatement - parses a select query from the given string
|
||||
// and analyzes it.
|
||||
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
|
||||
var selectAST Select
|
||||
err = SQLParser.ParseString(s, &selectAST)
|
||||
if err != nil {
|
||||
err = errQueryParseFailure(err)
|
||||
return
|
||||
}
|
||||
stmt.selectAST = &selectAST
|
||||
|
||||
// Check the parsed limit value
|
||||
stmt.limitValue, err = parseLimit(selectAST.Limit)
|
||||
if err != nil {
|
||||
err = errQueryAnalysisFailure(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze where clause
|
||||
if selectAST.Where != nil {
|
||||
whereQProp := selectAST.Where.analyze(&selectAST)
|
||||
if whereQProp.err != nil {
|
||||
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %v", whereQProp.err))
|
||||
return
|
||||
}
|
||||
|
||||
if whereQProp.isAggregation {
|
||||
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Validate table name
|
||||
tableString := strings.ToLower(selectAST.From.Table.String())
|
||||
if !strings.HasPrefix(tableString, "s3object.") && tableString != "s3object" {
|
||||
err = errBadTableName(errors.New("Table name must be s3object"))
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze main select expression
|
||||
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
|
||||
err = stmt.selectQProp.err
|
||||
if err != nil {
|
||||
fmt.Println("Got Analysis err:", err)
|
||||
err = errQueryAnalysisFailure(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseLimit(v *LitValue) (int64, error) {
|
||||
switch {
|
||||
case v == nil:
|
||||
return -1, nil
|
||||
case v.Number == nil:
|
||||
return -1, errBadLimitSpecified
|
||||
default:
|
||||
r := int64(*v.Number)
|
||||
if r < 0 {
|
||||
return -1, errBadLimitSpecified
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsAggregated returns if the statement involves SQL aggregation
|
||||
func (e *SelectStatement) IsAggregated() bool {
|
||||
return e.selectQProp.isAggregation
|
||||
}
|
||||
|
||||
// AggregateResult - returns the aggregated result after all input
|
||||
// records have been processed. Applies only to aggregation queries.
|
||||
func (e *SelectStatement) AggregateResult(output Record) error {
|
||||
for i, expr := range e.selectAST.Expression.Expressions {
|
||||
v, err := expr.evalNode(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
output.Set(fmt.Sprintf("_%d", i+1), v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AggregateRow - aggregates the input record. Applies only to
|
||||
// aggregation queries.
|
||||
func (e *SelectStatement) AggregateRow(input Record) error {
|
||||
for _, expr := range e.selectAST.Expression.Expressions {
|
||||
err := expr.aggregateRow(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Eval - evaluates the Select statement for the given record. It
|
||||
// applies only to non-aggregation queries.
|
||||
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
|
||||
if whereExpr := e.selectAST.Where; whereExpr != nil {
|
||||
value, err := whereExpr.evalNode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, ok := value.ToBool()
|
||||
if !ok {
|
||||
err = fmt.Errorf("WHERE expression did not return bool")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !b {
|
||||
// Where clause is not satisfied by the row
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
if e.selectAST.Expression.All {
|
||||
// Return the input record for `SELECT * FROM
|
||||
// .. WHERE ..`
|
||||
|
||||
// Update count of records output.
|
||||
if e.limitValue > -1 {
|
||||
e.outputCount++
|
||||
}
|
||||
|
||||
return input, nil
|
||||
}
|
||||
|
||||
for i, expr := range e.selectAST.Expression.Expressions {
|
||||
v, err := expr.evalNode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Pick output column names
|
||||
if expr.As != "" {
|
||||
output.Set(expr.As, v)
|
||||
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
|
||||
output.Set(comp, v)
|
||||
} else {
|
||||
output.Set(fmt.Sprintf("_%d", i+1), v)
|
||||
}
|
||||
}
|
||||
|
||||
// Update count of records output.
|
||||
if e.limitValue > -1 {
|
||||
e.outputCount++
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// LimitReached - returns true if the number of records output has
|
||||
// reached the value of the `LIMIT` clause.
|
||||
func (e *SelectStatement) LimitReached() bool {
|
||||
if e.limitValue == -1 {
|
||||
return false
|
||||
}
|
||||
return e.outputCount >= e.limitValue
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
errMalformedEscapeSequence = errors.New("Malformed escape sequence in LIKE clause")
|
||||
errInvalidTrimArg = errors.New("Trim argument is invalid - this should not happen")
|
||||
errInvalidSubstringIndexLen = errors.New("Substring start index or length falls outside the string")
|
||||
)
|
||||
|
||||
const (
|
||||
percent rune = '%'
|
||||
underscore rune = '_'
|
||||
runeZero rune = 0
|
||||
)
|
||||
|
||||
func evalSQLLike(text, pattern string, escape rune) (match bool, err error) {
|
||||
s := []rune{}
|
||||
prev := runeZero
|
||||
hasLeadingPercent := false
|
||||
patLen := len([]rune(pattern))
|
||||
for i, r := range pattern {
|
||||
if i > 0 && prev == escape {
|
||||
switch r {
|
||||
case percent, escape, underscore:
|
||||
s = append(s, r)
|
||||
prev = r
|
||||
if r == escape {
|
||||
prev = runeZero
|
||||
}
|
||||
default:
|
||||
return false, errMalformedEscapeSequence
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
prev = r
|
||||
|
||||
var ok bool
|
||||
switch r {
|
||||
case percent:
|
||||
if len(s) == 0 {
|
||||
hasLeadingPercent = true
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok = matcher(text, string(s), hasLeadingPercent)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
hasLeadingPercent = true
|
||||
s = []rune{}
|
||||
|
||||
if i == patLen-1 {
|
||||
// Last pattern character is a %, so
|
||||
// we are done.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
case underscore:
|
||||
if len(s) == 0 {
|
||||
text, ok = dropRune(text)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok = matcher(text, string(s), hasLeadingPercent)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
hasLeadingPercent = false
|
||||
|
||||
text, ok = dropRune(text)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
s = []rune{}
|
||||
|
||||
case escape:
|
||||
if i == patLen-1 {
|
||||
return false, errMalformedEscapeSequence
|
||||
}
|
||||
// Otherwise do nothing.
|
||||
|
||||
default:
|
||||
s = append(s, r)
|
||||
}
|
||||
|
||||
}
|
||||
if hasLeadingPercent {
|
||||
return strings.HasSuffix(text, string(s)), nil
|
||||
}
|
||||
return string(s) == text, nil
|
||||
}
|
||||
|
||||
// matcher - Finds `pat` in `text`, and returns the part remainder of
|
||||
// `text`, after the match. If leadingPercent is false, `pat` must be
|
||||
// the prefix of `text`, otherwise it must be a substring.
|
||||
func matcher(text, pat string, leadingPercent bool) (res string, match bool) {
|
||||
if !leadingPercent {
|
||||
res = strings.TrimPrefix(text, pat)
|
||||
if len(text) == len(res) {
|
||||
return "", false
|
||||
}
|
||||
} else {
|
||||
parts := strings.SplitN(text, pat, 2)
|
||||
if len(parts) == 1 {
|
||||
return "", false
|
||||
}
|
||||
res = parts[1]
|
||||
}
|
||||
return res, true
|
||||
}
|
||||
|
||||
func dropRune(text string) (res string, ok bool) {
|
||||
r := []rune(text)
|
||||
if len(r) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return string(r[1:]), true
|
||||
}
|
||||
|
||||
func evalSQLSubstring(s string, startIdx, length int) (res string, err error) {
|
||||
if startIdx <= 0 || startIdx > len(s) {
|
||||
return "", errInvalidSubstringIndexLen
|
||||
}
|
||||
// StartIdx is 1-based in the input
|
||||
startIdx--
|
||||
|
||||
rs := []rune(s)
|
||||
endIdx := len(rs)
|
||||
if length != -1 {
|
||||
if length < 0 || startIdx+length > len(s) {
|
||||
return "", errInvalidSubstringIndexLen
|
||||
}
|
||||
endIdx = startIdx + length
|
||||
}
|
||||
|
||||
return string(rs[startIdx:endIdx]), nil
|
||||
}
|
||||
|
||||
const (
|
||||
trimLeading = "LEADING"
|
||||
trimTrailing = "TRAILING"
|
||||
trimBoth = "BOTH"
|
||||
)
|
||||
|
||||
func evalSQLTrim(where *string, trimChars, text string) (result string, err error) {
|
||||
cutSet := " "
|
||||
if trimChars != "" {
|
||||
cutSet = trimChars
|
||||
}
|
||||
|
||||
trimFunc := strings.Trim
|
||||
switch {
|
||||
case where == nil:
|
||||
case *where == trimBoth:
|
||||
case *where == trimLeading:
|
||||
trimFunc = strings.TrimLeft
|
||||
case *where == trimTrailing:
|
||||
trimFunc = strings.TrimRight
|
||||
default:
|
||||
return "", errInvalidTrimArg
|
||||
}
|
||||
|
||||
return trimFunc(text, cutSet), nil
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEvalSQLLike(t *testing.T) {
|
||||
dropCases := []struct {
|
||||
input, resultExpected string
|
||||
matchExpected bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"a", "", true},
|
||||
{"ab", "b", true},
|
||||
{"தமிழ்", "மிழ்", true},
|
||||
}
|
||||
|
||||
for i, tc := range dropCases {
|
||||
res, ok := dropRune(tc.input)
|
||||
if res != tc.resultExpected || ok != tc.matchExpected {
|
||||
t.Errorf("DropRune Case %d failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
matcherCases := []struct {
|
||||
iText, iPat string
|
||||
iHasLeadingPercent bool
|
||||
resultExpected string
|
||||
matchExpected bool
|
||||
}{
|
||||
{"abcd", "bcd", false, "", false},
|
||||
{"abcd", "bcd", true, "", true},
|
||||
{"abcd", "abcd", false, "", true},
|
||||
{"abcd", "abcd", true, "", true},
|
||||
{"abcd", "ab", false, "cd", true},
|
||||
{"abcd", "ab", true, "cd", true},
|
||||
{"abcd", "bc", false, "", false},
|
||||
{"abcd", "bc", true, "d", true},
|
||||
}
|
||||
|
||||
for i, tc := range matcherCases {
|
||||
res, ok := matcher(tc.iText, tc.iPat, tc.iHasLeadingPercent)
|
||||
if res != tc.resultExpected || ok != tc.matchExpected {
|
||||
t.Errorf("Matcher Case %d failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
evalCases := []struct {
|
||||
iText, iPat string
|
||||
iEsc rune
|
||||
matchExpected bool
|
||||
errExpected error
|
||||
}{
|
||||
{"abcd", "abc", runeZero, false, nil},
|
||||
{"abcd", "abcd", runeZero, true, nil},
|
||||
{"abcd", "abc_", runeZero, true, nil},
|
||||
{"abcd", "_bdd", runeZero, false, nil},
|
||||
{"abcd", "_b_d", runeZero, true, nil},
|
||||
|
||||
{"abcd", "____", runeZero, true, nil},
|
||||
{"abcd", "____%", runeZero, true, nil},
|
||||
{"abcd", "%____", runeZero, true, nil},
|
||||
{"abcd", "%__", runeZero, true, nil},
|
||||
{"abcd", "____", runeZero, true, nil},
|
||||
|
||||
{"", "_", runeZero, false, nil},
|
||||
{"", "%", runeZero, true, nil},
|
||||
{"abcd", "%%%%%", runeZero, true, nil},
|
||||
{"abcd", "_____", runeZero, false, nil},
|
||||
{"abcd", "%%%%%", runeZero, true, nil},
|
||||
|
||||
{"a%%d", `a\%\%d`, '\\', true, nil},
|
||||
{"a%%d", `a\%d`, '\\', false, nil},
|
||||
{`a%%\d`, `a\%\%\\d`, '\\', true, nil},
|
||||
{`a%%\`, `a\%\%\\`, '\\', true, nil},
|
||||
{`a%__%\`, `a\%\_\_\%\\`, '\\', true, nil},
|
||||
|
||||
{`a%__%\`, `a\%\_\_\%_`, '\\', true, nil},
|
||||
{`a%__%\`, `a\%\_\__`, '\\', false, nil},
|
||||
{`a%__%\`, `a\%\_\_%`, '\\', true, nil},
|
||||
{`a%__%\`, `a?%?_?_?%\`, '?', true, nil},
|
||||
}
|
||||
|
||||
for i, tc := range evalCases {
|
||||
// fmt.Println("Case:", i)
|
||||
res, err := evalSQLLike(tc.iText, tc.iPat, tc.iEsc)
|
||||
if res != tc.matchExpected || err != tc.errExpected {
|
||||
t.Errorf("Eval Case %d failed: %v %v", i, res, err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
// Type - value type.
|
||||
type Type string
|
||||
|
||||
const (
|
||||
// Null - represents NULL value type.
|
||||
Null Type = "null"
|
||||
|
||||
// Bool - represents boolean value type.
|
||||
Bool Type = "bool"
|
||||
|
||||
// Int - represents integer value type.
|
||||
Int Type = "int"
|
||||
|
||||
// Float - represents floating point value type.
|
||||
Float Type = "float"
|
||||
|
||||
// String - represents string value type.
|
||||
String Type = "string"
|
||||
|
||||
// Timestamp - represents time value type.
|
||||
Timestamp Type = "timestamp"
|
||||
|
||||
// Array - represents array of values where each value type is one of above.
|
||||
Array Type = "array"
|
||||
|
||||
column Type = "column"
|
||||
record Type = "record"
|
||||
function Type = "function"
|
||||
aggregateFunction Type = "aggregatefunction"
|
||||
arithmeticFunction Type = "arithmeticfunction"
|
||||
comparisonFunction Type = "comparisonfunction"
|
||||
logicalFunction Type = "logicalfunction"
|
||||
|
||||
// Integer Type = "integer" // Same as Int
|
||||
// Decimal Type = "decimal" // Same as Float
|
||||
// Numeric Type = "numeric" // Same as Float
|
||||
)
|
||||
|
||||
func (t Type) isBase() bool {
|
||||
switch t {
|
||||
case Null, Bool, Int, Float, String, Timestamp:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isBaseKind() bool {
|
||||
switch t {
|
||||
case Null, Bool, Int, Float, String, Timestamp, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isNumber() bool {
|
||||
switch t {
|
||||
case Int, Float:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isNumberKind() bool {
|
||||
switch t {
|
||||
case Int, Float, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isIntKind() bool {
|
||||
switch t {
|
||||
case Int, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isBoolKind() bool {
|
||||
switch t {
|
||||
case Bool, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Type) isStringKind() bool {
|
||||
switch t {
|
||||
case String, column:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// String functions
|
||||
|
||||
// String - returns the JSONPath representation
|
||||
func (e *JSONPath) String() string {
|
||||
parts := make([]string, len(e.PathExpr)+1)
|
||||
parts[0] = e.BaseKey.String()
|
||||
for i, pe := range e.PathExpr {
|
||||
parts[i+1] = pe.String()
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
func (e *JSONPathElement) String() string {
|
||||
switch {
|
||||
case e.Key != nil:
|
||||
return e.Key.String()
|
||||
case e.Index != nil:
|
||||
return fmt.Sprintf("[%d]", *e.Index)
|
||||
case e.ObjectWildcard:
|
||||
return ".*"
|
||||
case e.ArrayWildcard:
|
||||
return "[*]"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Removes double quotes in quoted identifiers
|
||||
func (i *Identifier) String() string {
|
||||
if i.Unquoted != nil {
|
||||
return *i.Unquoted
|
||||
}
|
||||
return string(*i.Quoted)
|
||||
}
|
||||
|
||||
func (o *ObjectKey) String() string {
|
||||
if o.Lit != nil {
|
||||
return fmt.Sprintf("['%s']", string(*o.Lit))
|
||||
}
|
||||
return fmt.Sprintf(".%s", o.ID.String())
|
||||
}
|
||||
|
||||
// getLastKeypathComponent checks if the given expression is a path
|
||||
// expression, and if so extracts the last dot separated component of
|
||||
// the path. Otherwise it returns false.
|
||||
func getLastKeypathComponent(e *Expression) (string, bool) {
|
||||
if len(e.And) > 1 ||
|
||||
len(e.And[0].Condition) > 1 ||
|
||||
e.And[0].Condition[0].Not != nil ||
|
||||
e.And[0].Condition[0].Operand.ConditionRHS != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
operand := e.And[0].Condition[0].Operand.Operand
|
||||
if operand.Right != nil ||
|
||||
operand.Left.Right != nil ||
|
||||
operand.Left.Left.Negated != nil ||
|
||||
operand.Left.Left.Primary.JPathExpr == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
keypath := operand.Left.Left.Primary.JPathExpr.String()
|
||||
ps := strings.Split(keypath, ".")
|
||||
return ps[len(ps)-1], true
|
||||
}
|
|
@ -17,220 +17,710 @@
|
|||
package sql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
// Value - represents any primitive value of bool, int, float, string and time.
|
||||
var (
|
||||
errArithMismatchedTypes = errors.New("cannot perform arithmetic on mismatched types")
|
||||
errArithInvalidOperator = errors.New("invalid arithmetic operator")
|
||||
errArithDivideByZero = errors.New("cannot divide by 0")
|
||||
|
||||
errCmpMismatchedTypes = errors.New("cannot compare values of different types")
|
||||
errCmpInvalidBoolOperator = errors.New("invalid comparison operator for boolean arguments")
|
||||
)
|
||||
|
||||
// vType represents the concrete type of a `Value`
|
||||
type vType int
|
||||
|
||||
// Valid values for Type
|
||||
const (
|
||||
typeNull vType = iota + 1
|
||||
typeBool
|
||||
typeString
|
||||
|
||||
// 64-bit signed integer
|
||||
typeInt
|
||||
|
||||
// 64-bit floating point
|
||||
typeFloat
|
||||
|
||||
// This type refers to untyped values, e.g. as read from CSV
|
||||
typeBytes
|
||||
)
|
||||
|
||||
// Value represents a value of restricted type reduced from an
|
||||
// expression represented by an ASTNode. Only one of the fields is
|
||||
// non-nil.
|
||||
//
|
||||
// In cases where we are fetching data from a data source (like csv),
|
||||
// the type may not be determined yet. In these cases, a byte-slice is
|
||||
// used.
|
||||
type Value struct {
|
||||
value interface{}
|
||||
valueType Type
|
||||
value interface{}
|
||||
vType vType
|
||||
}
|
||||
|
||||
// String - represents value as string.
|
||||
func (value *Value) String() string {
|
||||
if value.value == nil {
|
||||
if value.valueType == Null {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
return "<nil>"
|
||||
// GetTypeString returns a string representation for vType
|
||||
func (v *Value) GetTypeString() string {
|
||||
switch v.vType {
|
||||
case typeNull:
|
||||
return "NULL"
|
||||
case typeBool:
|
||||
return "BOOL"
|
||||
case typeString:
|
||||
return "STRING"
|
||||
case typeInt:
|
||||
return "INT"
|
||||
case typeFloat:
|
||||
return "FLOAT"
|
||||
case typeBytes:
|
||||
return "BYTES"
|
||||
}
|
||||
|
||||
switch value.valueType {
|
||||
case String:
|
||||
return fmt.Sprintf("'%v'", value.value)
|
||||
case Array:
|
||||
var valueStrings []string
|
||||
for _, v := range value.value.([]*Value) {
|
||||
valueStrings = append(valueStrings, fmt.Sprintf("%v", v))
|
||||
}
|
||||
return fmt.Sprintf("(%v)", strings.Join(valueStrings, ","))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value.value)
|
||||
return "--"
|
||||
}
|
||||
|
||||
// CSVString - encodes to CSV string.
|
||||
func (value *Value) CSVString() string {
|
||||
if value.valueType == Null {
|
||||
// Repr returns a string representation of value.
|
||||
func (v *Value) Repr() string {
|
||||
switch v.vType {
|
||||
case typeNull:
|
||||
return ":NULL"
|
||||
case typeBool, typeInt, typeFloat:
|
||||
return fmt.Sprintf("%v:%s", v.value, v.GetTypeString())
|
||||
case typeString:
|
||||
return fmt.Sprintf("\"%s\":%s", v.value.(string), v.GetTypeString())
|
||||
case typeBytes:
|
||||
return fmt.Sprintf("\"%s\":BYTES", string(v.value.([]byte)))
|
||||
default:
|
||||
return fmt.Sprintf("%v:INVALID", v.value)
|
||||
}
|
||||
}
|
||||
|
||||
// FromFloat creates a Value from a number
|
||||
func FromFloat(f float64) *Value {
|
||||
return &Value{value: f, vType: typeFloat}
|
||||
}
|
||||
|
||||
// FromInt creates a Value from an int
|
||||
func FromInt(f int64) *Value {
|
||||
return &Value{value: f, vType: typeInt}
|
||||
}
|
||||
|
||||
// FromString creates a Value from a string
|
||||
func FromString(str string) *Value {
|
||||
return &Value{value: str, vType: typeString}
|
||||
}
|
||||
|
||||
// FromBool creates a Value from a bool
|
||||
func FromBool(b bool) *Value {
|
||||
return &Value{value: b, vType: typeBool}
|
||||
}
|
||||
|
||||
// FromNull creates a Value with Null value
|
||||
func FromNull() *Value {
|
||||
return &Value{vType: typeNull}
|
||||
}
|
||||
|
||||
// FromBytes creates a Value from a []byte
|
||||
func FromBytes(b []byte) *Value {
|
||||
return &Value{value: b, vType: typeBytes}
|
||||
}
|
||||
|
||||
// ToFloat works for int and float values
|
||||
func (v *Value) ToFloat() (val float64, ok bool) {
|
||||
switch v.vType {
|
||||
case typeFloat:
|
||||
val, ok = v.value.(float64)
|
||||
case typeInt:
|
||||
var i int64
|
||||
i, ok = v.value.(int64)
|
||||
val = float64(i)
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToInt converts value to int.
|
||||
func (v *Value) ToInt() (val int64, ok bool) {
|
||||
switch v.vType {
|
||||
case typeInt:
|
||||
val, ok = v.value.(int64)
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToString converts value to string.
|
||||
func (v *Value) ToString() (val string, ok bool) {
|
||||
switch v.vType {
|
||||
case typeString:
|
||||
val, ok = v.value.(string)
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ToBool returns the bool value; second return value refers to if the bool
|
||||
// conversion succeeded.
|
||||
func (v *Value) ToBool() (val bool, ok bool) {
|
||||
switch v.vType {
|
||||
case typeBool:
|
||||
return v.value.(bool), true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// ToBytes converts Value to byte-slice.
|
||||
func (v *Value) ToBytes() ([]byte, bool) {
|
||||
switch v.vType {
|
||||
case typeBytes:
|
||||
return v.value.([]byte), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IsNull - checks if value is missing.
|
||||
func (v *Value) IsNull() bool {
|
||||
return v.vType == typeNull
|
||||
}
|
||||
|
||||
func (v *Value) isNumeric() bool {
|
||||
return v.vType == typeInt || v.vType == typeFloat
|
||||
}
|
||||
|
||||
// setters used internally to mutate values
|
||||
|
||||
func (v *Value) setInt(i int64) {
|
||||
v.vType = typeInt
|
||||
v.value = i
|
||||
}
|
||||
|
||||
func (v *Value) setFloat(f float64) {
|
||||
v.vType = typeFloat
|
||||
v.value = f
|
||||
}
|
||||
|
||||
func (v *Value) setString(s string) {
|
||||
v.vType = typeString
|
||||
v.value = s
|
||||
}
|
||||
|
||||
func (v *Value) setBool(b bool) {
|
||||
v.vType = typeBool
|
||||
v.value = b
|
||||
}
|
||||
|
||||
// CSVString - convert to string for CSV serialization
|
||||
func (v *Value) CSVString() string {
|
||||
switch v.vType {
|
||||
case typeNull:
|
||||
return ""
|
||||
case typeBool:
|
||||
return fmt.Sprintf("%v", v.value.(bool))
|
||||
case typeString:
|
||||
return fmt.Sprintf("%s", v.value.(string))
|
||||
case typeInt:
|
||||
return fmt.Sprintf("%v", v.value.(int64))
|
||||
case typeFloat:
|
||||
return fmt.Sprintf("%v", v.value.(float64))
|
||||
case typeBytes:
|
||||
return fmt.Sprintf("%v", string(v.value.([]byte)))
|
||||
default:
|
||||
return "CSV serialization not implemented for this type"
|
||||
}
|
||||
}
|
||||
|
||||
// floatToValue converts a float into int representation if needed.
|
||||
func floatToValue(f float64) *Value {
|
||||
intPart, fracPart := math.Modf(f)
|
||||
if fracPart == 0 {
|
||||
return FromInt(int64(intPart))
|
||||
}
|
||||
return FromFloat(f)
|
||||
}
|
||||
|
||||
// Value comparison functions: we do not expose them outside the
|
||||
// module. Logical operators "<", ">", ">=", "<=" work on strings and
|
||||
// numbers. Equality operators "=", "!=" work on strings,
|
||||
// numbers and booleans.
|
||||
|
||||
// Supported comparison operators
|
||||
const (
|
||||
opLt = "<"
|
||||
opLte = "<="
|
||||
opGt = ">"
|
||||
opGte = ">="
|
||||
opEq = "="
|
||||
opIneq = "!="
|
||||
)
|
||||
|
||||
// When numeric types are compared, type promotions could happen. If
|
||||
// values do not have types (e.g. when reading from CSV), for
|
||||
// comparison operations, automatic type conversion happens by trying
|
||||
// to check if the value is a number (first an integer, then a float),
|
||||
// and falling back to string.
|
||||
func (v *Value) compareOp(op string, a *Value) (res bool, err error) {
|
||||
if !isValidComparisonOperator(op) {
|
||||
return false, errArithInvalidOperator
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value.value)
|
||||
// Check if type conversion/inference is needed - it is needed
|
||||
// if the Value is a byte-slice.
|
||||
err = inferTypesForCmp(v, a)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
isNumeric := v.isNumeric() && a.isNumeric()
|
||||
if isNumeric {
|
||||
intV, ok1i := v.ToInt()
|
||||
intA, ok2i := a.ToInt()
|
||||
if ok1i && ok2i {
|
||||
return intCompare(op, intV, intA), nil
|
||||
}
|
||||
|
||||
// If both values are numeric, then at least one is
|
||||
// float since we got here, so we convert.
|
||||
flV, _ := v.ToFloat()
|
||||
flA, _ := a.ToFloat()
|
||||
return floatCompare(op, flV, flA), nil
|
||||
}
|
||||
|
||||
strV, ok1s := v.ToString()
|
||||
strA, ok2s := a.ToString()
|
||||
if ok1s && ok2s {
|
||||
return stringCompare(op, strV, strA), nil
|
||||
}
|
||||
|
||||
boolV, ok1b := v.ToBool()
|
||||
boolA, ok2b := v.ToBool()
|
||||
if ok1b && ok2b {
|
||||
return boolCompare(op, boolV, boolA)
|
||||
}
|
||||
|
||||
return false, errCmpMismatchedTypes
|
||||
}
|
||||
|
||||
// MarshalJSON - encodes to JSON data.
|
||||
func (value *Value) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(value.value)
|
||||
func inferTypesForCmp(a *Value, b *Value) error {
|
||||
_, okA := a.ToBytes()
|
||||
_, okB := b.ToBytes()
|
||||
switch {
|
||||
case !okA && !okB:
|
||||
// Both Values already have types
|
||||
return nil
|
||||
|
||||
case okA && okB:
|
||||
// Both Values are untyped so try the types in order:
|
||||
// int, float, bool, string
|
||||
|
||||
// Check for numeric inference
|
||||
iA, okAi := a.bytesToInt()
|
||||
iB, okBi := b.bytesToInt()
|
||||
if okAi && okBi {
|
||||
a.setInt(iA)
|
||||
b.setInt(iB)
|
||||
return nil
|
||||
}
|
||||
|
||||
fA, okAf := a.bytesToFloat()
|
||||
fB, okBf := b.bytesToFloat()
|
||||
if okAf && okBf {
|
||||
a.setFloat(fA)
|
||||
b.setFloat(fB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if they int and float combination.
|
||||
if okAi && okBf {
|
||||
a.setInt(iA)
|
||||
b.setFloat(fA)
|
||||
return nil
|
||||
}
|
||||
if okBi && okAf {
|
||||
a.setFloat(fA)
|
||||
b.setInt(iB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not numeric types at this point.
|
||||
|
||||
// Check for bool inference
|
||||
bA, okAb := a.bytesToBool()
|
||||
bB, okBb := b.bytesToBool()
|
||||
if okAb && okBb {
|
||||
a.setBool(bA)
|
||||
b.setBool(bB)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback to string
|
||||
sA := a.bytesToString()
|
||||
sB := b.bytesToString()
|
||||
a.setString(sA)
|
||||
b.setString(sB)
|
||||
return nil
|
||||
|
||||
case okA && !okB:
|
||||
// Here a has `a` is untyped, but `b` has a fixed
|
||||
// type.
|
||||
switch b.vType {
|
||||
case typeString:
|
||||
s := a.bytesToString()
|
||||
a.setString(s)
|
||||
|
||||
case typeInt, typeFloat:
|
||||
if iA, ok := a.bytesToInt(); ok {
|
||||
a.setInt(iA)
|
||||
} else if fA, ok := a.bytesToFloat(); ok {
|
||||
a.setFloat(fA)
|
||||
} else {
|
||||
return fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
|
||||
}
|
||||
|
||||
case typeBool:
|
||||
if bA, ok := a.bytesToBool(); ok {
|
||||
a.setBool(bA)
|
||||
} else {
|
||||
return fmt.Errorf("Could not convert %s to a boolean", string(a.value.([]byte)))
|
||||
}
|
||||
|
||||
default:
|
||||
return errCmpMismatchedTypes
|
||||
}
|
||||
return nil
|
||||
|
||||
case !okA && okB:
|
||||
// swap arguments to avoid repeating code
|
||||
return inferTypesForCmp(b, a)
|
||||
|
||||
default:
|
||||
// Does not happen
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NullValue - returns underlying null value. It panics if value is not null type.
|
||||
func (value *Value) NullValue() *struct{} {
|
||||
if value.valueType == Null {
|
||||
// Value arithmetic functions: we do not expose them outside the
|
||||
// module. All arithmetic works only on numeric values with automatic
|
||||
// promotion to the "larger" type that can represent the value. TODO:
|
||||
// Add support for large number arithmetic.
|
||||
|
||||
// Supported arithmetic operators
|
||||
const (
|
||||
opPlus = "+"
|
||||
opMinus = "-"
|
||||
opDivide = "/"
|
||||
opMultiply = "*"
|
||||
opModulo = "%"
|
||||
)
|
||||
|
||||
// For arithmetic operations, if both values are numeric then the
|
||||
// operation shall succeed. If the types are unknown automatic type
|
||||
// conversion to a number is attempted.
|
||||
func (v *Value) arithOp(op string, a *Value) error {
|
||||
err := inferTypeForArithOp(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = inferTypeForArithOp(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !v.isNumeric() || !a.isNumeric() {
|
||||
return errInvalidDataType(errArithMismatchedTypes)
|
||||
}
|
||||
|
||||
if !isValidArithOperator(op) {
|
||||
return errInvalidDataType(errArithMismatchedTypes)
|
||||
}
|
||||
|
||||
intV, ok1i := v.ToInt()
|
||||
intA, ok2i := a.ToInt()
|
||||
switch {
|
||||
case ok1i && ok2i:
|
||||
res, err := intArithOp(op, intV, intA)
|
||||
v.setInt(res)
|
||||
return err
|
||||
|
||||
default:
|
||||
// Convert arguments to float
|
||||
flV, _ := v.ToFloat()
|
||||
flA, _ := a.ToFloat()
|
||||
res, err := floatArithOp(op, flV, flA)
|
||||
v.setFloat(res)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func inferTypeForArithOp(a *Value) error {
|
||||
if _, ok := a.ToBytes(); !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// BoolValue - returns underlying bool value. It panics if value is not Bool type.
|
||||
func (value *Value) BoolValue() bool {
|
||||
if value.valueType == Bool {
|
||||
return value.value.(bool)
|
||||
if i, ok := a.bytesToInt(); ok {
|
||||
a.setInt(i)
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// IntValue - returns underlying int value. It panics if value is not Int type.
|
||||
func (value *Value) IntValue() int64 {
|
||||
if value.valueType == Int {
|
||||
return value.value.(int64)
|
||||
if f, ok := a.bytesToFloat(); ok {
|
||||
a.setFloat(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested int value but found %T type", value.value))
|
||||
err := fmt.Errorf("Could not convert %s to a number", string(a.value.([]byte)))
|
||||
return errInvalidDataType(err)
|
||||
}
|
||||
|
||||
// FloatValue - returns underlying int/float value as float64. It panics if value is not Int/Float type.
|
||||
func (value *Value) FloatValue() float64 {
|
||||
switch value.valueType {
|
||||
case Int:
|
||||
return float64(value.value.(int64))
|
||||
case Float:
|
||||
return value.value.(float64)
|
||||
// All the bytesTo* functions defined below assume the value is a byte-slice.
|
||||
|
||||
// Converts untyped value into int. The bool return implies success -
|
||||
// it returns false only if there is a conversion failure.
|
||||
func (v *Value) bytesToInt() (int64, bool) {
|
||||
bytes, _ := v.ToBytes()
|
||||
i, err := strconv.ParseInt(string(bytes), 10, 64)
|
||||
return i, err == nil
|
||||
}
|
||||
|
||||
// Converts untyped value into float. The bool return implies success
|
||||
// - it returns false only if there is a conversion failure.
|
||||
func (v *Value) bytesToFloat() (float64, bool) {
|
||||
bytes, _ := v.ToBytes()
|
||||
i, err := strconv.ParseFloat(string(bytes), 64)
|
||||
return i, err == nil
|
||||
}
|
||||
|
||||
// Converts untyped value into bool. The second bool return implies
|
||||
// success - it returns false in case of a conversion failure.
|
||||
func (v *Value) bytesToBool() (val bool, ok bool) {
|
||||
bytes, _ := v.ToBytes()
|
||||
ok = true
|
||||
switch strings.ToLower(string(bytes)) {
|
||||
case "t", "true":
|
||||
val = true
|
||||
case "f", "false":
|
||||
val = false
|
||||
default:
|
||||
ok = false
|
||||
}
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// bytesToString - never fails
|
||||
func (v *Value) bytesToString() string {
|
||||
bytes, _ := v.ToBytes()
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// Calculates minimum or maximum of v and a and assigns the result to
|
||||
// v - it works only on numeric arguments, where `v` is already
|
||||
// assumed to be numeric. Attempts conversion to numeric type for `a`
|
||||
// (first int, then float) only if the underlying values do not have a
|
||||
// type.
|
||||
func (v *Value) minmax(a *Value, isMax, isFirstRow bool) error {
|
||||
err := inferTypeForArithOp(a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested float value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// StringValue - returns underlying string value. It panics if value is not String type.
|
||||
func (value *Value) StringValue() string {
|
||||
if value.valueType == String {
|
||||
return value.value.(string)
|
||||
if !a.isNumeric() {
|
||||
return errArithMismatchedTypes
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested string value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// TimeValue - returns underlying time value. It panics if value is not Timestamp type.
|
||||
func (value *Value) TimeValue() time.Time {
|
||||
if value.valueType == Timestamp {
|
||||
return value.value.(time.Time)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested time value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// ArrayValue - returns underlying value array. It panics if value is not Array type.
|
||||
func (value *Value) ArrayValue() []*Value {
|
||||
if value.valueType == Array {
|
||||
return value.value.([]*Value)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested array value but found %T type", value.value))
|
||||
}
|
||||
|
||||
func (value *Value) recordValue() Record {
|
||||
if value.valueType == record {
|
||||
return value.value.(Record)
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("requested record value but found %T type", value.value))
|
||||
}
|
||||
|
||||
// Type - returns value type.
|
||||
func (value *Value) Type() Type {
|
||||
return value.valueType
|
||||
}
|
||||
|
||||
// Value - returns underneath value interface.
|
||||
func (value *Value) Value() interface{} {
|
||||
return value.value
|
||||
}
|
||||
|
||||
// NewNull - creates new null value.
|
||||
func NewNull() *Value {
|
||||
return &Value{nil, Null}
|
||||
}
|
||||
|
||||
// NewBool - creates new Bool value of b.
|
||||
func NewBool(b bool) *Value {
|
||||
return &Value{b, Bool}
|
||||
}
|
||||
|
||||
// NewInt - creates new Int value of i.
|
||||
func NewInt(i int64) *Value {
|
||||
return &Value{i, Int}
|
||||
}
|
||||
|
||||
// NewFloat - creates new Float value of f.
|
||||
func NewFloat(f float64) *Value {
|
||||
return &Value{f, Float}
|
||||
}
|
||||
|
||||
// NewString - creates new Sring value of s.
|
||||
func NewString(s string) *Value {
|
||||
return &Value{s, String}
|
||||
}
|
||||
|
||||
// NewTime - creates new Time value of t.
|
||||
func NewTime(t time.Time) *Value {
|
||||
return &Value{t, Timestamp}
|
||||
}
|
||||
|
||||
// NewArray - creates new Array value of values.
|
||||
func NewArray(values []*Value) *Value {
|
||||
return &Value{values, Array}
|
||||
}
|
||||
|
||||
func newRecordValue(r Record) *Value {
|
||||
return &Value{r, record}
|
||||
}
|
||||
|
||||
// NewValue - creates new Value from SQLVal v.
|
||||
func NewValue(v *sqlparser.SQLVal) (*Value, error) {
|
||||
switch v.Type {
|
||||
case sqlparser.StrVal:
|
||||
return NewString(string(v.Val)), nil
|
||||
case sqlparser.IntVal:
|
||||
i64, err := strconv.ParseInt(string(v.Val), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// In case of first row, set v to a.
|
||||
if isFirstRow {
|
||||
intA, okI := a.ToInt()
|
||||
if okI {
|
||||
v.setInt(intA)
|
||||
return nil
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.FloatVal:
|
||||
f64, err := strconv.ParseFloat(string(v.Val), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewFloat(f64), nil
|
||||
case sqlparser.HexNum: // represented as 0xDD
|
||||
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.HexVal: // represented as X'0DD'
|
||||
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.BitVal: // represented as B'00'
|
||||
i64, err := strconv.ParseInt(string(v.Val), 2, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewInt(i64), nil
|
||||
case sqlparser.ValArg:
|
||||
// FIXME: the format is unknown and not sure how to handle it.
|
||||
floatA, _ := a.ToFloat()
|
||||
v.setFloat(floatA)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown SQL value %v; %v ", v, v.Type)
|
||||
intV, ok1i := v.ToInt()
|
||||
intA, ok2i := a.ToInt()
|
||||
if ok1i && ok2i {
|
||||
result := intV
|
||||
if !isMax {
|
||||
if intA < result {
|
||||
result = intA
|
||||
}
|
||||
} else {
|
||||
if intA > result {
|
||||
result = intA
|
||||
}
|
||||
}
|
||||
v.setInt(result)
|
||||
return nil
|
||||
}
|
||||
|
||||
floatV, _ := v.ToFloat()
|
||||
floatA, _ := a.ToFloat()
|
||||
var result float64
|
||||
if !isMax {
|
||||
result = math.Min(floatV, floatA)
|
||||
} else {
|
||||
result = math.Max(floatV, floatA)
|
||||
}
|
||||
v.setFloat(result)
|
||||
return nil
|
||||
}
|
||||
|
||||
// inferTypeAsString is used to convert untyped values to string - it
|
||||
// is called when the caller requires a string context to proceed.
|
||||
func inferTypeAsString(v *Value) {
|
||||
b, ok := v.ToBytes()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
v.setString(string(b))
|
||||
}
|
||||
|
||||
func isValidComparisonOperator(op string) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
case opLte:
|
||||
case opGt:
|
||||
case opGte:
|
||||
case opEq:
|
||||
case opIneq:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func intCompare(op string, left, right int64) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
return left < right
|
||||
case opLte:
|
||||
return left <= right
|
||||
case opGt:
|
||||
return left > right
|
||||
case opGte:
|
||||
return left >= right
|
||||
case opEq:
|
||||
return left == right
|
||||
case opIneq:
|
||||
return left != right
|
||||
}
|
||||
// This case does not happen
|
||||
return false
|
||||
}
|
||||
|
||||
func floatCompare(op string, left, right float64) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
return left < right
|
||||
case opLte:
|
||||
return left <= right
|
||||
case opGt:
|
||||
return left > right
|
||||
case opGte:
|
||||
return left >= right
|
||||
case opEq:
|
||||
return left == right
|
||||
case opIneq:
|
||||
return left != right
|
||||
}
|
||||
// This case does not happen
|
||||
return false
|
||||
}
|
||||
|
||||
func stringCompare(op string, left, right string) bool {
|
||||
switch op {
|
||||
case opLt:
|
||||
return left < right
|
||||
case opLte:
|
||||
return left <= right
|
||||
case opGt:
|
||||
return left > right
|
||||
case opGte:
|
||||
return left >= right
|
||||
case opEq:
|
||||
return left == right
|
||||
case opIneq:
|
||||
return left != right
|
||||
}
|
||||
// This case does not happen
|
||||
return false
|
||||
}
|
||||
|
||||
func boolCompare(op string, left, right bool) (bool, error) {
|
||||
switch op {
|
||||
case opEq:
|
||||
return left == right, nil
|
||||
case opIneq:
|
||||
return left != right, nil
|
||||
default:
|
||||
return false, errCmpInvalidBoolOperator
|
||||
}
|
||||
}
|
||||
|
||||
func isValidArithOperator(op string) bool {
|
||||
switch op {
|
||||
case opPlus:
|
||||
case opMinus:
|
||||
case opDivide:
|
||||
case opMultiply:
|
||||
case opModulo:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Overflow errors are ignored.
|
||||
func intArithOp(op string, left, right int64) (int64, error) {
|
||||
switch op {
|
||||
case opPlus:
|
||||
return left + right, nil
|
||||
case opMinus:
|
||||
return left - right, nil
|
||||
case opDivide:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return left / right, nil
|
||||
case opMultiply:
|
||||
return left * right, nil
|
||||
case opModulo:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return left % right, nil
|
||||
}
|
||||
// This does not happen
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Overflow errors are ignored.
|
||||
func floatArithOp(op string, left, right float64) (float64, error) {
|
||||
switch op {
|
||||
case opPlus:
|
||||
return left + right, nil
|
||||
case opMinus:
|
||||
return left - right, nil
|
||||
case opDivide:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return left / right, nil
|
||||
case opMultiply:
|
||||
return left * right, nil
|
||||
case opModulo:
|
||||
if right == 0 {
|
||||
return 0, errArithDivideByZero
|
||||
}
|
||||
return math.Mod(left, right), nil
|
||||
}
|
||||
// This does not happen
|
||||
return 0, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (C) 2017 Alec Thomas
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
of the Software, and to permit persons to whom the Software is furnished to do
|
||||
so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,345 @@
|
|||
# A dead simple parser package for Go
|
||||
|
||||
[![Godoc](https://godoc.org/github.com/alecthomas/participle?status.svg)](http://godoc.org/github.com/alecthomas/participle) [![CircleCI](https://img.shields.io/circleci/project/github/alecthomas/participle.svg)](https://circleci.com/gh/alecthomas/participle)
|
||||
[![Go Report Card](https://goreportcard.com/badge/github.com/alecthomas/participle)](https://goreportcard.com/report/github.com/alecthomas/participle) [![Gitter chat](https://badges.gitter.im/alecthomas.png)](https://gitter.im/alecthomas/Lobby)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
1. [Introduction](#introduction)
|
||||
2. [Limitations](#limitations)
|
||||
3. [Tutorial](#tutorial)
|
||||
4. [Overview](#overview)
|
||||
5. [Annotation syntax](#annotation-syntax)
|
||||
6. [Capturing](#capturing)
|
||||
7. [Streaming](#streaming)
|
||||
8. [Lexing](#lexing)
|
||||
9. [Options](#options)
|
||||
10. [Examples](#examples)
|
||||
11. [Performance](#performance)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
<a id="markdown-introduction" name="introduction"></a>
|
||||
## Introduction
|
||||
|
||||
The goal of this package is to provide a simple, idiomatic and elegant way of
|
||||
defining parsers in Go.
|
||||
|
||||
Participle's method of defining grammars should be familiar to any Go
|
||||
programmer who has used the `encoding/json` package: struct field tags define
|
||||
what and how input is mapped to those same fields. This is not unusual for Go
|
||||
encoders, but is unusual for a parser.
|
||||
|
||||
<a id="markdown-limitations" name="limitations"></a>
|
||||
## Limitations
|
||||
|
||||
Participle parsers are recursive descent. Among other things, this means that they do not support left recursion.
|
||||
|
||||
There is an experimental lookahead option for using precomputed lookahead
|
||||
tables for disambiguation. You can enable this with the parser option
|
||||
`participle.UseLookahead()`.
|
||||
|
||||
Left recursion must be eliminated by restructuring your grammar.
|
||||
|
||||
<a id="markdown-tutorial" name="tutorial"></a>
|
||||
## Tutorial
|
||||
|
||||
A [tutorial](TUTORIAL.md) is available, walking through the creation of an .ini parser.
|
||||
|
||||
<a id="markdown-overview" name="overview"></a>
|
||||
## Overview
|
||||
|
||||
A grammar is an annotated Go structure used to both define the parser grammar,
|
||||
and be the AST output by the parser. As an example, following is the final INI
|
||||
parser from the tutorial.
|
||||
|
||||
```go
|
||||
type INI struct {
|
||||
Properties []*Property `{ @@ }`
|
||||
Sections []*Section `{ @@ }`
|
||||
}
|
||||
|
||||
type Section struct {
|
||||
Identifier string `"[" @Ident "]"`
|
||||
Properties []*Property `{ @@ }`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Key string `@Ident "="`
|
||||
Value *Value `@@`
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
String *string ` @String`
|
||||
Number *float64 `| @Float`
|
||||
}
|
||||
```
|
||||
|
||||
> **Note:** Participle also supports named struct tags (eg. <code>Hello string `parser:"@Ident"`</code>).
|
||||
|
||||
A parser is constructed from a grammar and a lexer:
|
||||
|
||||
```go
|
||||
parser, err := participle.Build(&INI{})
|
||||
```
|
||||
|
||||
Once constructed, the parser is applied to input to produce an AST:
|
||||
|
||||
```go
|
||||
ast := &INI{}
|
||||
err := parser.ParseString("size = 10", ast)
|
||||
// ast == &INI{
|
||||
// Properties: []*Property{
|
||||
// {Key: "size", Value: &Value{Number: &10}},
|
||||
// },
|
||||
// }
|
||||
```
|
||||
|
||||
<a id="markdown-annotation-syntax" name="annotation-syntax"></a>
|
||||
## Annotation syntax
|
||||
|
||||
- `@<expr>` Capture expression into the field.
|
||||
- `@@` Recursively capture using the fields own type.
|
||||
- `<identifier>` Match named lexer token.
|
||||
- `( ... )` Group.
|
||||
- `"..."` Match the literal (note that the lexer must emit tokens matching this literal exactly).
|
||||
- `"...":<identifier>` Match the literal, specifying the exact lexer token type to match.
|
||||
- `<expr> <expr> ...` Match expressions.
|
||||
- `<expr> | <expr>` Match one of the alternatives.
|
||||
|
||||
The following modifiers can be used after any expression:
|
||||
|
||||
- `*` Expression can match zero or more times.
|
||||
- `+` Expression must match one or more times.
|
||||
- `?` Expression can match zero or once.
|
||||
- `!` Require a non-empty match (this is useful with a sequence of optional matches eg. `("a"? "b"? "c"?)!`).
|
||||
|
||||
Supported but deprecated:
|
||||
- `{ ... }` Match 0 or more times (**DEPRECATED** - prefer `( ... )*`).
|
||||
- `[ ... ]` Optional (**DEPRECATED** - prefer `( ... )?`).
|
||||
|
||||
Notes:
|
||||
|
||||
- Each struct is a single production, with each field applied in sequence.
|
||||
- `@<expr>` is the mechanism for capturing matches into the field.
|
||||
- if a struct field is not keyed with "parser", the entire struct tag
|
||||
will be used as the grammar fragment. This allows the grammar syntax to remain
|
||||
clear and simple to maintain.
|
||||
|
||||
<a id="markdown-capturing" name="capturing"></a>
|
||||
## Capturing
|
||||
|
||||
Prefixing any expression in the grammar with `@` will capture matching values
|
||||
for that expression into the corresponding field.
|
||||
|
||||
For example:
|
||||
|
||||
```go
|
||||
// The grammar definition.
|
||||
type Grammar struct {
|
||||
Hello string `@Ident`
|
||||
}
|
||||
|
||||
// The source text to parse.
|
||||
source := "world"
|
||||
|
||||
// After parsing, the resulting AST.
|
||||
result == &Grammar{
|
||||
Hello: "world",
|
||||
}
|
||||
```
|
||||
|
||||
For slice and string fields, each instance of `@` will accumulate into the
|
||||
field (including repeated patterns). Accumulation into other types is not
|
||||
supported.
|
||||
|
||||
A successful capture match into a boolean field will set the field to true.
|
||||
|
||||
For integer and floating point types, a successful capture will be parsed
|
||||
with `strconv.ParseInt()` and `strconv.ParseBool()` respectively.
|
||||
|
||||
Custom control of how values are captured into fields can be achieved by a
|
||||
field type implementing the `Capture` interface (`Capture(values []string)
|
||||
error`).
|
||||
|
||||
<a id="markdown-streaming" name="streaming"></a>
|
||||
## Streaming
|
||||
|
||||
Participle supports streaming parsing. Simply pass a channel of your grammar into
|
||||
`Parse*()`. The grammar will be repeatedly parsed and sent to the channel. Note that
|
||||
the `Parse*()` call will not return until parsing completes, so it should generally be
|
||||
started in a goroutine.
|
||||
|
||||
```go
|
||||
type token struct {
|
||||
Str string ` @Ident`
|
||||
Num int `| @Int`
|
||||
}
|
||||
|
||||
parser, err := participle.Build(&token{})
|
||||
|
||||
tokens := make(chan *token, 128)
|
||||
err := parser.ParseString(`hello 10 11 12 world`, tokens)
|
||||
for token := range tokens {
|
||||
fmt.Printf("%#v\n", token)
|
||||
}
|
||||
```
|
||||
|
||||
<a id="markdown-lexing" name="lexing"></a>
|
||||
## Lexing
|
||||
|
||||
Participle operates on tokens and thus relies on a lexer to convert character
|
||||
streams to tokens.
|
||||
|
||||
Three lexers are provided, varying in speed and flexibility. The fastest lexer
|
||||
is based on the [text/scanner](https://golang.org/pkg/text/scanner/) package
|
||||
but only allows tokens provided by that package. Next fastest is the regexp
|
||||
lexer (`lexer.Regexp()`). The slowest is currently the EBNF based lexer, but it has a large potential for optimisation through code generation.
|
||||
|
||||
To use your own Lexer you will need to implement two interfaces:
|
||||
[Definition](https://godoc.org/github.com/alecthomas/participle/lexer#Definition)
|
||||
and [Lexer](https://godoc.org/github.com/alecthomas/participle/lexer#Lexer).
|
||||
|
||||
<a id="markdown-options" name="options"></a>
|
||||
## Options
|
||||
|
||||
The Parser's behaviour can be configured via [Options](https://godoc.org/github.com/alecthomas/participle#Option).
|
||||
|
||||
<a id="markdown-examples" name="examples"></a>
|
||||
## Examples
|
||||
|
||||
There are several [examples](https://github.com/alecthomas/participle/tree/master/_examples) included:
|
||||
|
||||
Example | Description
|
||||
--------|---------------
|
||||
[BASIC](https://github.com/alecthomas/participle/tree/master/_examples/basic) | A lexer, parser and interpreter for a [rudimentary dialect](https://caml.inria.fr/pub/docs/oreilly-book/html/book-ora058.html) of BASIC.
|
||||
[EBNF](https://github.com/alecthomas/participle/tree/master/_examples/ebnf) | Parser for the form of EBNF used by Go.
|
||||
[Expr](https://github.com/alecthomas/participle/tree/master/_examples/expr) | A basic mathematical expression parser and evaluator.
|
||||
[GraphQL](https://github.com/alecthomas/participle/tree/master/_examples/graphql) | Lexer+parser for GraphQL schemas
|
||||
[HCL](https://github.com/alecthomas/participle/tree/master/_examples/hcl) | A parser for the [HashiCorp Configuration Language](https://github.com/hashicorp/hcl).
|
||||
[INI](https://github.com/alecthomas/participle/tree/master/_examples/ini) | An INI file parser.
|
||||
[Protobuf](https://github.com/alecthomas/participle/tree/master/_examples/protobuf) | A full [Protobuf](https://developers.google.com/protocol-buffers/) version 2 and 3 parser.
|
||||
[SQL](https://github.com/alecthomas/participle/tree/master/_examples/sql) | A *very* rudimentary SQL SELECT parser.
|
||||
[Thrift](https://github.com/alecthomas/participle/tree/master/_examples/thrift) | A full [Thrift](https://thrift.apache.org/docs/idl) parser.
|
||||
[TOML](https://github.com/alecthomas/participle/blob/master/_examples/toml/main.go) | A [TOML](https://github.com/toml-lang/toml) parser.
|
||||
|
||||
Included below is a full GraphQL lexer and parser:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/alecthomas/kong"
|
||||
"github.com/alecthomas/repr"
|
||||
|
||||
"github.com/alecthomas/participle"
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
"github.com/alecthomas/participle/lexer/ebnf"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
Entries []*Entry `{ @@ }`
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
Type *Type ` @@`
|
||||
Schema *Schema `| @@`
|
||||
Enum *Enum `| @@`
|
||||
Scalar string `| "scalar" @Ident`
|
||||
}
|
||||
|
||||
type Enum struct {
|
||||
Name string `"enum" @Ident`
|
||||
Cases []string `"{" { @Ident } "}"`
|
||||
}
|
||||
|
||||
type Schema struct {
|
||||
Fields []*Field `"schema" "{" { @@ } "}"`
|
||||
}
|
||||
|
||||
type Type struct {
|
||||
Name string `"type" @Ident`
|
||||
Implements string `[ "implements" @Ident ]`
|
||||
Fields []*Field `"{" { @@ } "}"`
|
||||
}
|
||||
|
||||
type Field struct {
|
||||
Name string `@Ident`
|
||||
Arguments []*Argument `[ "(" [ @@ { "," @@ } ] ")" ]`
|
||||
Type *TypeRef `":" @@`
|
||||
Annotation string `[ "@" @Ident ]`
|
||||
}
|
||||
|
||||
type Argument struct {
|
||||
Name string `@Ident`
|
||||
Type *TypeRef `":" @@`
|
||||
Default *Value `[ "=" @@ ]`
|
||||
}
|
||||
|
||||
type TypeRef struct {
|
||||
Array *TypeRef `( "[" @@ "]"`
|
||||
Type string ` | @Ident )`
|
||||
NonNullable bool `[ @"!" ]`
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
Symbol string `@Ident`
|
||||
}
|
||||
|
||||
var (
|
||||
graphQLLexer = lexer.Must(ebnf.New(`
|
||||
Comment = ("#" | "//") { "\u0000"…"\uffff"-"\n" } .
|
||||
Ident = (alpha | "_") { "_" | alpha | digit } .
|
||||
Number = ("." | digit) {"." | digit} .
|
||||
Whitespace = " " | "\t" | "\n" | "\r" .
|
||||
Punct = "!"…"/" | ":"…"@" | "["…`+"\"`\""+` | "{"…"~" .
|
||||
|
||||
alpha = "a"…"z" | "A"…"Z" .
|
||||
digit = "0"…"9" .
|
||||
`))
|
||||
|
||||
parser = participle.MustBuild(&File{},
|
||||
participle.Lexer(graphQLLexer),
|
||||
participle.Elide("Comment", "Whitespace"),
|
||||
)
|
||||
|
||||
cli struct {
|
||||
Files []string `arg:"" type:"existingfile" required:"" help:"GraphQL schema files to parse."`
|
||||
}
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := kong.Parse(&cli)
|
||||
for _, file := range cli.Files {
|
||||
ast := &File{}
|
||||
r, err := os.Open(file)
|
||||
ctx.FatalIfErrorf(err)
|
||||
err = parser.Parse(r, ast)
|
||||
r.Close()
|
||||
repr.Println(ast)
|
||||
ctx.FatalIfErrorf(err)
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
<a id="markdown-performance" name="performance"></a>
|
||||
## Performance
|
||||
|
||||
One of the included examples is a complete Thrift parser
|
||||
(shell-style comments are not supported). This gives
|
||||
a convenient baseline for comparing to the PEG based
|
||||
[pigeon](https://github.com/PuerkitoBio/pigeon), which is the parser used by
|
||||
[go-thrift](https://github.com/samuel/go-thrift). Additionally, the pigeon
|
||||
parser is utilising a generated parser, while the participle parser is built at
|
||||
run time.
|
||||
|
||||
You can run the benchmarks yourself, but here's the output on my machine:
|
||||
|
||||
BenchmarkParticipleThrift-4 10000 221818 ns/op 48880 B/op 1240 allocs/op
|
||||
BenchmarkGoThriftParser-4 2000 804709 ns/op 170301 B/op 3086 allocs/op
|
||||
|
||||
On a real life codebase of 47K lines of Thrift, Participle takes 200ms and go-
|
||||
thrift takes 630ms, which aligns quite closely with the benchmarks.
|
|
@ -0,0 +1,255 @@
|
|||
# Participle parser tutorial
|
||||
|
||||
<!-- MarkdownTOC -->
|
||||
|
||||
1. [Introduction](#introduction)
|
||||
1. [The complete grammar](#the-complete-grammar)
|
||||
1. [Root of the .ini AST \(structure, fields\)](#root-of-the-ini-ast-structure-fields)
|
||||
1. [.ini properties \(named tokens, capturing, literals\)](#ini-properties-named-tokens-capturing-literals)
|
||||
1. [.ini property values \(alternates, recursive structs, sequences\)](#ini-property-values-alternates-recursive-structs-sequences)
|
||||
1. [Complete, but limited, .ini grammar \(top-level properties only\)](#complete-but-limited-ini-grammar-top-level-properties-only)
|
||||
1. [Extending our grammar to support sections](#extending-our-grammar-to-support-sections)
|
||||
1. [\(Optional\) Source positional information](#optional-source-positional-information)
|
||||
1. [Parsing using our grammar](#parsing-using-our-grammar)
|
||||
|
||||
<!-- /MarkdownTOC -->
|
||||
|
||||
## Introduction
|
||||
|
||||
Writing a parser in Participle typically involves starting from the "root" of
|
||||
the AST, annotating fields with the grammar, then recursively expanding until
|
||||
it is complete. The AST is expressed via Go data types and the grammar is
|
||||
expressed through struct field tags, as a form of EBNF.
|
||||
|
||||
The parser we're going to create for this tutorial parses .ini files
|
||||
like this:
|
||||
|
||||
```ini
|
||||
age = 21
|
||||
name = "Bob Smith"
|
||||
|
||||
[address]
|
||||
city = "Beverly Hills"
|
||||
postal_code = 90210
|
||||
```
|
||||
|
||||
## The complete grammar
|
||||
|
||||
I think it's useful to see the complete grammar first, to see what we're
|
||||
working towards. Read on below for details.
|
||||
|
||||
```go
|
||||
type INI struct {
|
||||
Properties []*Property `@@*`
|
||||
Sections []*Section `@@*`
|
||||
}
|
||||
|
||||
type Section struct {
|
||||
Identifier string `"[" @Ident "]"`
|
||||
Properties []*Property `@@*`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Key string `@Ident "="`
|
||||
Value *Value `@@`
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
String *string ` @String`
|
||||
Number *float64 `| @Float`
|
||||
}
|
||||
```
|
||||
|
||||
## Root of the .ini AST (structure, fields)
|
||||
|
||||
The first step is to create a root struct for our grammar. In the case of our
|
||||
.ini parser, this struct will contain a sequence of properties:
|
||||
|
||||
```go
|
||||
type INI struct {
|
||||
Properties []*Property
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
}
|
||||
```
|
||||
|
||||
## .ini properties (named tokens, capturing, literals)
|
||||
|
||||
Each property in an .ini file has an identifier key:
|
||||
|
||||
```go
|
||||
type Property struct {
|
||||
Key string
|
||||
}
|
||||
```
|
||||
|
||||
The default lexer tokenises Go source code, and includes an `Ident` token type
|
||||
that matches identifiers. To match this token we simply use the token type
|
||||
name:
|
||||
|
||||
```go
|
||||
type Property struct {
|
||||
Key string `Ident`
|
||||
}
|
||||
```
|
||||
|
||||
This will *match* identifiers, but not *capture* them into the `Key` field. To
|
||||
capture input tokens into AST fields, prefix any grammar node with `@`:
|
||||
|
||||
```go
|
||||
type Property struct {
|
||||
Key string `@Ident`
|
||||
}
|
||||
```
|
||||
|
||||
In .ini files, each key is separated from its value with a literal `=`. To
|
||||
match a literal, enclose the literal in double quotes:
|
||||
|
||||
```go
|
||||
type Property struct {
|
||||
Key string `@Ident "="`
|
||||
}
|
||||
```
|
||||
|
||||
> Note: literals in the grammar must match tokens from the lexer *exactly*. In
|
||||
> this example if the lexer does not output `=` as a distinct token the
|
||||
> grammar will not match.
|
||||
|
||||
## .ini property values (alternates, recursive structs, sequences)
|
||||
|
||||
For the purposes of our example we are only going to support quoted string
|
||||
and numeric property values. As each value can be *either* a string or a float
|
||||
we'll need something akin to a sum type. Go's type system cannot express this
|
||||
directly, so we'll use the common approach of making each element a pointer.
|
||||
The selected "case" will *not* be nil.
|
||||
|
||||
```go
|
||||
type Value struct {
|
||||
String *string
|
||||
Number *float64
|
||||
}
|
||||
```
|
||||
|
||||
> Note: Participle will hydrate pointers as necessary.
|
||||
|
||||
To express matching a set of alternatives we use the `|` operator:
|
||||
|
||||
```go
|
||||
type Value struct {
|
||||
String *string ` @String`
|
||||
Number *float64 `| @Float`
|
||||
}
|
||||
```
|
||||
|
||||
> Note: the grammar can cross fields.
|
||||
|
||||
Next, we'll match values and capture them into the `Property`. To recursively
|
||||
capture structs use `@@` (capture self):
|
||||
|
||||
```go
|
||||
type Property struct {
|
||||
Key string `@Ident "="`
|
||||
Value *Value `@@`
|
||||
}
|
||||
```
|
||||
|
||||
Now that we can parse a `Property` we need to go back to the root of the
|
||||
grammar. We want to parse 0 or more properties. To do this, we use `<expr>*`.
|
||||
Participle will accumulate each match into the slice until matching fails,
|
||||
then move to the next node in the grammar.
|
||||
|
||||
```go
|
||||
type INI struct {
|
||||
Properties []*Property `@@*`
|
||||
}
|
||||
```
|
||||
|
||||
> Note: tokens can also be accumulated into strings, appending each match.
|
||||
|
||||
## Complete, but limited, .ini grammar (top-level properties only)
|
||||
|
||||
We now have a functional, but limited, .ini parser!
|
||||
|
||||
```go
|
||||
type INI struct {
|
||||
Properties []*Property `@@*`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Key string `@Ident "="`
|
||||
Value *Value `@@`
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
String *string ` @String`
|
||||
Number *float64 `| @Float`
|
||||
}
|
||||
```
|
||||
|
||||
## Extending our grammar to support sections
|
||||
|
||||
Adding support for sections is simply a matter of utilising the constructs
|
||||
we've just learnt. A section consists of a header identifier, and a sequence
|
||||
of properties:
|
||||
|
||||
```go
|
||||
type Section struct {
|
||||
Identifier string `"[" @Ident "]"`
|
||||
Properties []*Property `@@*`
|
||||
}
|
||||
```
|
||||
|
||||
Simple!
|
||||
|
||||
Now we just add a sequence of `Section`s to our root node:
|
||||
|
||||
```go
|
||||
type INI struct {
|
||||
Properties []*Property `@@*`
|
||||
Sections []*Section `@@*`
|
||||
}
|
||||
```
|
||||
|
||||
And we're done!
|
||||
|
||||
## (Optional) Source positional information
|
||||
|
||||
If a grammar node includes a field with the name `Pos` and type `lexer.Position`, it will be automatically populated by positional information. eg.
|
||||
|
||||
```go
|
||||
type Value struct {
|
||||
Pos lexer.Position
|
||||
String *string ` @String`
|
||||
Number *float64 `| @Float`
|
||||
}
|
||||
```
|
||||
|
||||
This is useful for error reporting.
|
||||
|
||||
## Parsing using our grammar
|
||||
|
||||
To parse with this grammar we first construct the parser (we'll use the
|
||||
default lexer for now):
|
||||
|
||||
```go
|
||||
parser, err := participle.Build(&INI{})
|
||||
```
|
||||
|
||||
Then create a root node and parse into it with `parser.Parse{,String,Bytes}()`:
|
||||
|
||||
```go
|
||||
ini := &INI{}
|
||||
err = parser.ParseString(`
|
||||
age = 21
|
||||
name = "Bob Smith"
|
||||
|
||||
[address]
|
||||
city = "Beverly Hills"
|
||||
postal_code = 90210
|
||||
`, ini)
|
||||
```
|
||||
|
||||
You can find the full example [here](_examples/ini/main.go), alongside
|
||||
other examples including an SQL `SELECT` parser and a full
|
||||
[Thrift](https://thrift.apache.org/) parser.
|
|
@ -0,0 +1,19 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
// Capture can be implemented by fields in order to transform captured tokens into field values.
|
||||
type Capture interface {
|
||||
Capture(values []string) error
|
||||
}
|
||||
|
||||
// The Parseable interface can be implemented by any element in the grammar to provide custom parsing.
|
||||
type Parseable interface {
|
||||
// Parse into the receiver.
|
||||
//
|
||||
// Should return NextMatch if no tokens matched and parsing should continue.
|
||||
// Nil should be returned if parsing was successful.
|
||||
Parse(lex lexer.PeekingLexer) error
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
type contextFieldSet struct {
|
||||
pos lexer.Position
|
||||
strct reflect.Value
|
||||
field structLexerField
|
||||
fieldValue []reflect.Value
|
||||
}
|
||||
|
||||
// Context for a single parse.
|
||||
type parseContext struct {
|
||||
*rewinder
|
||||
lookahead int
|
||||
caseInsensitive map[rune]bool
|
||||
apply []*contextFieldSet
|
||||
}
|
||||
|
||||
func newParseContext(lex lexer.Lexer, lookahead int, caseInsensitive map[rune]bool) (*parseContext, error) {
|
||||
rew, err := newRewinder(lex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &parseContext{
|
||||
rewinder: rew,
|
||||
caseInsensitive: caseInsensitive,
|
||||
lookahead: lookahead,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Defer adds a function to be applied once a branch has been picked.
|
||||
func (p *parseContext) Defer(pos lexer.Position, strct reflect.Value, field structLexerField, fieldValue []reflect.Value) {
|
||||
p.apply = append(p.apply, &contextFieldSet{pos, strct, field, fieldValue})
|
||||
}
|
||||
|
||||
// Apply deferred functions.
|
||||
func (p *parseContext) Apply() error {
|
||||
for _, apply := range p.apply {
|
||||
if err := setField(apply.pos, apply.strct, apply.field, apply.fieldValue); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
p.apply = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Branch accepts the branch as the correct branch.
|
||||
func (p *parseContext) Accept(branch *parseContext) {
|
||||
p.apply = append(p.apply, branch.apply...)
|
||||
p.rewinder = branch.rewinder
|
||||
}
|
||||
|
||||
// Branch starts a new lookahead branch.
|
||||
func (p *parseContext) Branch() *parseContext {
|
||||
branch := &parseContext{}
|
||||
*branch = *p
|
||||
branch.apply = nil
|
||||
branch.rewinder = p.rewinder.Lookahead()
|
||||
return branch
|
||||
}
|
||||
|
||||
// Stop returns true if parsing should terminate after the given "branch" failed to match.
|
||||
func (p *parseContext) Stop(branch *parseContext) bool {
|
||||
if branch.cursor > p.cursor+p.lookahead {
|
||||
p.Accept(branch)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type rewinder struct {
|
||||
cursor, limit int
|
||||
tokens []lexer.Token
|
||||
}
|
||||
|
||||
func newRewinder(lex lexer.Lexer) (*rewinder, error) {
|
||||
r := &rewinder{}
|
||||
for {
|
||||
t, err := lex.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t.EOF() {
|
||||
break
|
||||
}
|
||||
r.tokens = append(r.tokens, t)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *rewinder) Next() (lexer.Token, error) {
|
||||
if r.cursor >= len(r.tokens) {
|
||||
return lexer.EOFToken(lexer.Position{}), nil
|
||||
}
|
||||
r.cursor++
|
||||
return r.tokens[r.cursor-1], nil
|
||||
}
|
||||
|
||||
func (r *rewinder) Peek(n int) (lexer.Token, error) {
|
||||
i := r.cursor + n
|
||||
if i >= len(r.tokens) {
|
||||
return lexer.EOFToken(lexer.Position{}), nil
|
||||
}
|
||||
return r.tokens[i], nil
|
||||
}
|
||||
|
||||
// Lookahead returns a new rewinder usable for lookahead.
|
||||
func (r *rewinder) Lookahead() *rewinder {
|
||||
clone := &rewinder{}
|
||||
*clone = *r
|
||||
clone.limit = clone.cursor
|
||||
return clone
|
||||
}
|
||||
|
||||
// Keep this lookahead rewinder.
|
||||
func (r *rewinder) Keep() {
|
||||
r.limit = 0
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
// Package participle constructs parsers from definitions in struct tags and parses directly into
|
||||
// those structs. The approach is philosophically similar to how other marshallers work in Go,
|
||||
// "unmarshalling" an instance of a grammar into a struct.
|
||||
//
|
||||
// The supported annotation syntax is:
|
||||
//
|
||||
// - `@<expr>` Capture expression into the field.
|
||||
// - `@@` Recursively capture using the fields own type.
|
||||
// - `<identifier>` Match named lexer token.
|
||||
// - `( ... )` Group.
|
||||
// - `"..."` Match the literal (note that the lexer must emit tokens matching this literal exactly).
|
||||
// - `"...":<identifier>` Match the literal, specifying the exact lexer token type to match.
|
||||
// - `<expr> <expr> ...` Match expressions.
|
||||
// - `<expr> | <expr>` Match one of the alternatives.
|
||||
//
|
||||
// The following modifiers can be used after any expression:
|
||||
//
|
||||
// - `*` Expression can match zero or more times.
|
||||
// - `+` Expression must match one or more times.
|
||||
// - `?` Expression can match zero or once.
|
||||
// - `!` Require a non-empty match (this is useful with a sequence of optional matches eg. `("a"? "b"? "c"?)!`).
|
||||
//
|
||||
// Supported but deprecated:
|
||||
//
|
||||
// - `{ ... }` Match 0 or more times (**DEPRECATED** - prefer `( ... )*`).
|
||||
// - `[ ... ]` Optional (**DEPRECATED** - prefer `( ... )?`).
|
||||
//
|
||||
// Here's an example of an EBNF grammar.
|
||||
//
|
||||
// type Group struct {
|
||||
// Expression *Expression `"(" @@ ")"`
|
||||
// }
|
||||
//
|
||||
// type Option struct {
|
||||
// Expression *Expression `"[" @@ "]"`
|
||||
// }
|
||||
//
|
||||
// type Repetition struct {
|
||||
// Expression *Expression `"{" @@ "}"`
|
||||
// }
|
||||
//
|
||||
// type Literal struct {
|
||||
// Start string `@String` // lexer.Lexer token "String"
|
||||
// End string `("…" @String)?`
|
||||
// }
|
||||
//
|
||||
// type Term struct {
|
||||
// Name string ` @Ident`
|
||||
// Literal *Literal `| @@`
|
||||
// Group *Group `| @@`
|
||||
// Option *Option `| @@`
|
||||
// Repetition *Repetition `| @@`
|
||||
// }
|
||||
//
|
||||
// type Sequence struct {
|
||||
// Terms []*Term `@@+`
|
||||
// }
|
||||
//
|
||||
// type Expression struct {
|
||||
// Alternatives []*Sequence `@@ ("|" @@)*`
|
||||
// }
|
||||
//
|
||||
// type Expressions []*Expression
|
||||
//
|
||||
// type Production struct {
|
||||
// Name string `@Ident "="`
|
||||
// Expressions Expressions `@@+ "."`
|
||||
// }
|
||||
//
|
||||
// type EBNF struct {
|
||||
// Productions []*Production `@@*`
|
||||
// }
|
||||
package participle
|
|
@ -0,0 +1,7 @@
|
|||
module github.com/alecthomas/participle
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/testify v1.2.2
|
||||
)
|
|
@ -0,0 +1,6 @@
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
|
@ -0,0 +1,324 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"text/scanner"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
type generatorContext struct {
|
||||
lexer.Definition
|
||||
typeNodes map[reflect.Type]node
|
||||
symbolsToIDs map[rune]string
|
||||
}
|
||||
|
||||
func newGeneratorContext(lex lexer.Definition) *generatorContext {
|
||||
return &generatorContext{
|
||||
Definition: lex,
|
||||
typeNodes: map[reflect.Type]node{},
|
||||
symbolsToIDs: lexer.SymbolsByRune(lex),
|
||||
}
|
||||
}
|
||||
|
||||
// Takes a type and builds a tree of nodes out of it.
|
||||
func (g *generatorContext) parseType(t reflect.Type) (_ node, returnedError error) {
|
||||
rt := t
|
||||
t = indirectType(t)
|
||||
if n, ok := g.typeNodes[t]; ok {
|
||||
return n, nil
|
||||
}
|
||||
if rt.Implements(parseableType) {
|
||||
return &parseable{rt.Elem()}, nil
|
||||
}
|
||||
if reflect.PtrTo(rt).Implements(parseableType) {
|
||||
return &parseable{rt}, nil
|
||||
}
|
||||
switch t.Kind() {
|
||||
case reflect.Slice, reflect.Ptr:
|
||||
t = indirectType(t.Elem())
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("expected a struct but got %T", t)
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Struct:
|
||||
slexer, err := lexStruct(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &strct{typ: t}
|
||||
g.typeNodes[t] = out // Ensure we avoid infinite recursion.
|
||||
if slexer.NumField() == 0 {
|
||||
return nil, fmt.Errorf("can not parse into empty struct %s", t)
|
||||
}
|
||||
defer decorate(&returnedError, func() string { return slexer.Field().Name })
|
||||
e, err := g.parseDisjunction(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if e == nil {
|
||||
return nil, fmt.Errorf("no grammar found in %s", t)
|
||||
}
|
||||
if token, _ := slexer.Peek(); !token.EOF() {
|
||||
return nil, fmt.Errorf("unexpected input %q", token.Value)
|
||||
}
|
||||
out.expr = e
|
||||
return out, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%s should be a struct or should implement the Parseable interface", t)
|
||||
}
|
||||
|
||||
func (g *generatorContext) parseDisjunction(slexer *structLexer) (node, error) {
|
||||
out := &disjunction{}
|
||||
for {
|
||||
n, err := g.parseSequence(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out.nodes = append(out.nodes, n)
|
||||
if token, _ := slexer.Peek(); token.Type != '|' {
|
||||
break
|
||||
}
|
||||
_, err = slexer.Next() // |
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(out.nodes) == 1 {
|
||||
return out.nodes[0], nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (g *generatorContext) parseSequence(slexer *structLexer) (node, error) {
|
||||
head := &sequence{}
|
||||
cursor := head
|
||||
loop:
|
||||
for {
|
||||
if token, err := slexer.Peek(); err != nil {
|
||||
return nil, err
|
||||
} else if token.Type == lexer.EOF {
|
||||
break loop
|
||||
}
|
||||
term, err := g.parseTerm(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if term == nil {
|
||||
break loop
|
||||
}
|
||||
if cursor.node == nil {
|
||||
cursor.head = true
|
||||
cursor.node = term
|
||||
} else {
|
||||
cursor.next = &sequence{node: term}
|
||||
cursor = cursor.next
|
||||
}
|
||||
}
|
||||
if head.node == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if head.next == nil {
|
||||
return head.node, nil
|
||||
}
|
||||
return head, nil
|
||||
}
|
||||
|
||||
func (g *generatorContext) parseTermNoModifiers(slexer *structLexer) (node, error) {
|
||||
t, err := slexer.Peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var out node
|
||||
switch t.Type {
|
||||
case '@':
|
||||
out, err = g.parseCapture(slexer)
|
||||
case scanner.String, scanner.RawString, scanner.Char:
|
||||
out, err = g.parseLiteral(slexer)
|
||||
case '[':
|
||||
return g.parseOptional(slexer)
|
||||
case '{':
|
||||
return g.parseRepetition(slexer)
|
||||
case '(':
|
||||
out, err = g.parseGroup(slexer)
|
||||
case scanner.Ident:
|
||||
out, err = g.parseReference(slexer)
|
||||
case lexer.EOF:
|
||||
_, _ = slexer.Next()
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (g *generatorContext) parseTerm(slexer *structLexer) (node, error) {
|
||||
out, err := g.parseTermNoModifiers(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.parseModifier(slexer, out)
|
||||
}
|
||||
|
||||
// Parse modifiers: ?, *, + and/or !
|
||||
func (g *generatorContext) parseModifier(slexer *structLexer, expr node) (node, error) {
|
||||
out := &group{expr: expr}
|
||||
t, err := slexer.Peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t.Type {
|
||||
case '!':
|
||||
out.mode = groupMatchNonEmpty
|
||||
case '+':
|
||||
out.mode = groupMatchOneOrMore
|
||||
case '*':
|
||||
out.mode = groupMatchZeroOrMore
|
||||
case '?':
|
||||
out.mode = groupMatchZeroOrOne
|
||||
default:
|
||||
return expr, nil
|
||||
}
|
||||
_, _ = slexer.Next()
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// @<expression> captures <expression> into the current field.
|
||||
func (g *generatorContext) parseCapture(slexer *structLexer) (node, error) {
|
||||
_, _ = slexer.Next()
|
||||
token, err := slexer.Peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
field := slexer.Field()
|
||||
if token.Type == '@' {
|
||||
_, _ = slexer.Next()
|
||||
n, err := g.parseType(field.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &capture{field, n}, nil
|
||||
}
|
||||
if indirectType(field.Type).Kind() == reflect.Struct && !field.Type.Implements(captureType) {
|
||||
return nil, fmt.Errorf("structs can only be parsed with @@ or by implementing the Capture interface")
|
||||
}
|
||||
n, err := g.parseTermNoModifiers(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &capture{field, n}, nil
|
||||
}
|
||||
|
||||
// A reference in the form <identifier> refers to a named token from the lexer.
|
||||
func (g *generatorContext) parseReference(slexer *structLexer) (node, error) { // nolint: interfacer
|
||||
token, err := slexer.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.Type != scanner.Ident {
|
||||
return nil, fmt.Errorf("expected identifier but got %q", token)
|
||||
}
|
||||
typ, ok := g.Symbols()[token.Value]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown token type %q", token)
|
||||
}
|
||||
return &reference{typ: typ, identifier: token.Value}, nil
|
||||
}
|
||||
|
||||
// [ <expression> ] optionally matches <expression>.
|
||||
func (g *generatorContext) parseOptional(slexer *structLexer) (node, error) {
|
||||
_, _ = slexer.Next() // [
|
||||
disj, err := g.parseDisjunction(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := &group{expr: disj, mode: groupMatchZeroOrOne}
|
||||
next, err := slexer.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if next.Type != ']' {
|
||||
return nil, fmt.Errorf("expected ] but got %q", next)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// { <expression> } matches 0 or more repititions of <expression>
|
||||
func (g *generatorContext) parseRepetition(slexer *structLexer) (node, error) {
|
||||
_, _ = slexer.Next() // {
|
||||
disj, err := g.parseDisjunction(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := &group{expr: disj, mode: groupMatchZeroOrMore}
|
||||
next, err := slexer.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if next.Type != '}' {
|
||||
return nil, fmt.Errorf("expected } but got %q", next)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ( <expression> ) groups a sub-expression
|
||||
func (g *generatorContext) parseGroup(slexer *structLexer) (node, error) {
|
||||
_, _ = slexer.Next() // (
|
||||
disj, err := g.parseDisjunction(slexer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next, err := slexer.Next() // )
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if next.Type != ')' {
|
||||
return nil, fmt.Errorf("expected ) but got %q", next)
|
||||
}
|
||||
return &group{expr: disj}, nil
|
||||
}
|
||||
|
||||
// A literal string.
|
||||
//
|
||||
// Note that for this to match, the tokeniser must be able to produce this string. For example,
|
||||
// if the tokeniser only produces individual characters but the literal is "hello", or vice versa.
|
||||
func (g *generatorContext) parseLiteral(lex *structLexer) (node, error) { // nolint: interfacer
|
||||
token, err := lex.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.Type != scanner.String && token.Type != scanner.RawString && token.Type != scanner.Char {
|
||||
return nil, fmt.Errorf("expected quoted string but got %q", token)
|
||||
}
|
||||
s := token.Value
|
||||
t := rune(-1)
|
||||
token, err = lex.Peek()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.Value == ":" && (token.Type == scanner.Char || token.Type == ':') {
|
||||
_, _ = lex.Next()
|
||||
token, err = lex.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.Type != scanner.Ident {
|
||||
return nil, fmt.Errorf("expected identifier for literal type constraint but got %q", token)
|
||||
}
|
||||
var ok bool
|
||||
t, ok = g.Symbols()[token.Value]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown token type %q in literal type constraint", token)
|
||||
}
|
||||
}
|
||||
return &literal{s: s, t: t, tt: g.symbolsToIDs[t]}, nil
|
||||
}
|
||||
|
||||
func indirectType(t reflect.Type) reflect.Type {
|
||||
if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice {
|
||||
return indirectType(t.Elem())
|
||||
}
|
||||
return t
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
// Package lexer defines interfaces and implementations used by Participle to perform lexing.
|
||||
//
|
||||
// The primary interfaces are Definition and Lexer. There are three implementations of these
|
||||
// interfaces:
|
||||
//
|
||||
// TextScannerLexer is based on text/scanner. This is the fastest, but least flexible, in that
|
||||
// tokens are restricted to those supported by that package. It can scan about 5M tokens/second on a
|
||||
// late 2013 15" MacBook Pro.
|
||||
//
|
||||
// The second lexer is constructed via the Regexp() function, mapping regexp capture groups
|
||||
// to tokens. The complete input source is read into memory, so it is unsuitable for large inputs.
|
||||
//
|
||||
// The final lexer provided accepts a lexical grammar in EBNF. Each capitalised production is a
|
||||
// lexical token supported by the resulting Lexer. This is very flexible, but a bit slower, scanning
|
||||
// around 730K tokens/second on the same machine, though it is currently completely unoptimised.
|
||||
// This could/should be converted to a table-based lexer.
|
||||
//
|
||||
// Lexer implementations must use Panic/Panicf to report errors.
|
||||
package lexer
|
|
@ -0,0 +1,26 @@
|
|||
package lexer
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Error represents an error while parsing.
|
||||
type Error struct {
|
||||
Message string
|
||||
Pos Position
|
||||
}
|
||||
|
||||
// Errorf creats a new Error at the given position.
|
||||
func Errorf(pos Position, format string, args ...interface{}) *Error {
|
||||
return &Error{
|
||||
Message: fmt.Sprintf(format, args...),
|
||||
Pos: pos,
|
||||
}
|
||||
}
|
||||
|
||||
// Error complies with the error interface and reports the position of an error.
|
||||
func (e *Error) Error() string {
|
||||
filename := e.Pos.Filename
|
||||
if filename == "" {
|
||||
filename = "<source>"
|
||||
}
|
||||
return fmt.Sprintf("%s:%d:%d: %s", filename, e.Pos.Line, e.Pos.Column, e.Message)
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package lexer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// EOF represents an end of file.
|
||||
EOF rune = -(iota + 1)
|
||||
)
|
||||
|
||||
// EOFToken creates a new EOF token at the given position.
|
||||
func EOFToken(pos Position) Token {
|
||||
return Token{Type: EOF, Pos: pos}
|
||||
}
|
||||
|
||||
// Definition provides the parser with metadata for a lexer.
|
||||
type Definition interface {
|
||||
// Lex an io.Reader.
|
||||
Lex(io.Reader) (Lexer, error)
|
||||
// Symbols returns a map of symbolic names to the corresponding pseudo-runes for those symbols.
|
||||
// This is the same approach as used by text/scanner. For example, "EOF" might have the rune
|
||||
// value of -1, "Ident" might be -2, and so on.
|
||||
Symbols() map[string]rune
|
||||
}
|
||||
|
||||
// A Lexer returns tokens from a source.
|
||||
type Lexer interface {
|
||||
// Next consumes and returns the next token.
|
||||
Next() (Token, error)
|
||||
}
|
||||
|
||||
// A PeekingLexer returns tokens from a source and allows peeking.
|
||||
type PeekingLexer interface {
|
||||
Lexer
|
||||
// Peek at the next token.
|
||||
Peek(n int) (Token, error)
|
||||
}
|
||||
|
||||
// SymbolsByRune returns a map of lexer symbol names keyed by rune.
|
||||
func SymbolsByRune(def Definition) map[rune]string {
|
||||
out := map[rune]string{}
|
||||
for s, r := range def.Symbols() {
|
||||
out[r] = s
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// NameOfReader attempts to retrieve the filename of a reader.
|
||||
func NameOfReader(r interface{}) string {
|
||||
if nr, ok := r.(interface{ Name() string }); ok {
|
||||
return nr.Name()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Must takes the result of a Definition constructor call and returns the definition, but panics if
|
||||
// it errors
|
||||
//
|
||||
// eg.
|
||||
//
|
||||
// lex = lexer.Must(lexer.Build(`Symbol = "symbol" .`))
|
||||
func Must(def Definition, err error) Definition {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ConsumeAll reads all tokens from a Lexer.
|
||||
func ConsumeAll(lexer Lexer) ([]Token, error) {
|
||||
tokens := []Token{}
|
||||
for {
|
||||
token, err := lexer.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
if token.Type == EOF {
|
||||
return tokens, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Position of a token.
|
||||
type Position struct {
|
||||
Filename string
|
||||
Offset int
|
||||
Line int
|
||||
Column int
|
||||
}
|
||||
|
||||
func (p Position) GoString() string {
|
||||
return fmt.Sprintf("Position{Filename: %q, Offset: %d, Line: %d, Column: %d}",
|
||||
p.Filename, p.Offset, p.Line, p.Column)
|
||||
}
|
||||
|
||||
func (p Position) String() string {
|
||||
filename := p.Filename
|
||||
if filename == "" {
|
||||
filename = "<source>"
|
||||
}
|
||||
return fmt.Sprintf("%s:%d:%d", filename, p.Line, p.Column)
|
||||
}
|
||||
|
||||
// A Token returned by a Lexer.
|
||||
type Token struct {
|
||||
// Type of token. This is the value keyed by symbol as returned by Definition.Symbols().
|
||||
Type rune
|
||||
Value string
|
||||
Pos Position
|
||||
}
|
||||
|
||||
// RuneToken represents a rune as a Token.
|
||||
func RuneToken(r rune) Token {
|
||||
return Token{Type: r, Value: string(r)}
|
||||
}
|
||||
|
||||
// EOF returns true if this Token is an EOF token.
|
||||
func (t Token) EOF() bool {
|
||||
return t.Type == EOF
|
||||
}
|
||||
|
||||
func (t Token) String() string {
|
||||
if t.EOF() {
|
||||
return "<EOF>"
|
||||
}
|
||||
return t.Value
|
||||
}
|
||||
|
||||
func (t Token) GoString() string {
|
||||
return fmt.Sprintf("Token{%d, %q}", t.Type, t.Value)
|
||||
}
|
||||
|
||||
// MakeSymbolTable builds a lookup table for checking token ID existence.
|
||||
//
|
||||
// For each symbolic name in "types", the returned map will contain the corresponding token ID as a key.
|
||||
func MakeSymbolTable(def Definition, types ...string) (map[rune]bool, error) {
|
||||
symbols := def.Symbols()
|
||||
table := map[rune]bool{}
|
||||
for _, symbol := range types {
|
||||
rn, ok := symbols[symbol]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("lexer does not support symbol %q", symbol)
|
||||
}
|
||||
table[rn] = true
|
||||
}
|
||||
return table, nil
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package lexer
|
||||
|
||||
// Upgrade a Lexer to a PeekingLexer with arbitrary lookahead.
|
||||
func Upgrade(lexer Lexer) PeekingLexer {
|
||||
if peeking, ok := lexer.(PeekingLexer); ok {
|
||||
return peeking
|
||||
}
|
||||
return &lookaheadLexer{Lexer: lexer}
|
||||
}
|
||||
|
||||
type lookaheadLexer struct {
|
||||
Lexer
|
||||
peeked []Token
|
||||
}
|
||||
|
||||
func (l *lookaheadLexer) Peek(n int) (Token, error) {
|
||||
for len(l.peeked) <= n {
|
||||
t, err := l.Lexer.Next()
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
}
|
||||
if t.EOF() {
|
||||
return t, nil
|
||||
}
|
||||
l.peeked = append(l.peeked, t)
|
||||
}
|
||||
return l.peeked[n], nil
|
||||
}
|
||||
|
||||
func (l *lookaheadLexer) Next() (Token, error) {
|
||||
if len(l.peeked) > 0 {
|
||||
t := l.peeked[0]
|
||||
l.peeked = l.peeked[1:]
|
||||
return t, nil
|
||||
}
|
||||
return l.Lexer.Next()
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package lexer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"regexp"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var eolBytes = []byte("\n")
|
||||
|
||||
type regexpDefinition struct {
|
||||
re *regexp.Regexp
|
||||
symbols map[string]rune
|
||||
}
|
||||
|
||||
// Regexp creates a lexer definition from a regular expression.
|
||||
//
|
||||
// Each named sub-expression in the regular expression matches a token. Anonymous sub-expressions
|
||||
// will be matched and discarded.
|
||||
//
|
||||
// eg.
|
||||
//
|
||||
// def, err := Regexp(`(?P<Ident>[a-z]+)|(\s+)|(?P<Number>\d+)`)
|
||||
func Regexp(pattern string) (Definition, error) {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
symbols := map[string]rune{
|
||||
"EOF": EOF,
|
||||
}
|
||||
for i, sym := range re.SubexpNames()[1:] {
|
||||
if sym != "" {
|
||||
symbols[sym] = EOF - 1 - rune(i)
|
||||
}
|
||||
}
|
||||
return ®expDefinition{re: re, symbols: symbols}, nil
|
||||
}
|
||||
|
||||
func (d *regexpDefinition) Lex(r io.Reader) (Lexer, error) {
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ®expLexer{
|
||||
pos: Position{
|
||||
Filename: NameOfReader(r),
|
||||
Line: 1,
|
||||
Column: 1,
|
||||
},
|
||||
b: b,
|
||||
re: d.re,
|
||||
names: d.re.SubexpNames(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *regexpDefinition) Symbols() map[string]rune {
|
||||
return d.symbols
|
||||
}
|
||||
|
||||
type regexpLexer struct {
|
||||
pos Position
|
||||
b []byte
|
||||
re *regexp.Regexp
|
||||
names []string
|
||||
}
|
||||
|
||||
func (r *regexpLexer) Next() (Token, error) {
|
||||
nextToken:
|
||||
for len(r.b) != 0 {
|
||||
matches := r.re.FindSubmatchIndex(r.b)
|
||||
if matches == nil || matches[0] != 0 {
|
||||
rn, _ := utf8.DecodeRune(r.b)
|
||||
return Token{}, Errorf(r.pos, "invalid token %q", rn)
|
||||
}
|
||||
match := r.b[:matches[1]]
|
||||
token := Token{
|
||||
Pos: r.pos,
|
||||
Value: string(match),
|
||||
}
|
||||
|
||||
// Update lexer state.
|
||||
r.pos.Offset += matches[1]
|
||||
lines := bytes.Count(match, eolBytes)
|
||||
r.pos.Line += lines
|
||||
// Update column.
|
||||
if lines == 0 {
|
||||
r.pos.Column += utf8.RuneCount(match)
|
||||
} else {
|
||||
r.pos.Column = utf8.RuneCount(match[bytes.LastIndex(match, eolBytes):])
|
||||
}
|
||||
// Move slice along.
|
||||
r.b = r.b[matches[1]:]
|
||||
|
||||
// Finally, assign token type. If it is not a named group, we continue to the next token.
|
||||
for i := 2; i < len(matches); i += 2 {
|
||||
if matches[i] != -1 {
|
||||
if r.names[i/2] == "" {
|
||||
continue nextToken
|
||||
}
|
||||
token.Type = EOF - rune(i/2)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
return EOFToken(r.pos), nil
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
package lexer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/scanner"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// TextScannerLexer is a lexer that uses the text/scanner module.
|
||||
var (
|
||||
TextScannerLexer Definition = &defaultDefinition{}
|
||||
|
||||
// DefaultDefinition defines properties for the default lexer.
|
||||
DefaultDefinition = TextScannerLexer
|
||||
)
|
||||
|
||||
type defaultDefinition struct{}
|
||||
|
||||
func (d *defaultDefinition) Lex(r io.Reader) (Lexer, error) {
|
||||
return Lex(r), nil
|
||||
}
|
||||
|
||||
func (d *defaultDefinition) Symbols() map[string]rune {
|
||||
return map[string]rune{
|
||||
"EOF": scanner.EOF,
|
||||
"Char": scanner.Char,
|
||||
"Ident": scanner.Ident,
|
||||
"Int": scanner.Int,
|
||||
"Float": scanner.Float,
|
||||
"String": scanner.String,
|
||||
"RawString": scanner.RawString,
|
||||
"Comment": scanner.Comment,
|
||||
}
|
||||
}
|
||||
|
||||
// textScannerLexer is a Lexer based on text/scanner.Scanner
|
||||
type textScannerLexer struct {
|
||||
scanner *scanner.Scanner
|
||||
filename string
|
||||
err error
|
||||
}
|
||||
|
||||
// Lex an io.Reader with text/scanner.Scanner.
|
||||
//
|
||||
// This provides very fast lexing of source code compatible with Go tokens.
|
||||
//
|
||||
// Note that this differs from text/scanner.Scanner in that string tokens will be unquoted.
|
||||
func Lex(r io.Reader) Lexer {
|
||||
lexer := lexWithScanner(r, &scanner.Scanner{})
|
||||
lexer.scanner.Error = func(s *scanner.Scanner, msg string) {
|
||||
// This is to support single quoted strings. Hacky.
|
||||
if msg != "illegal char literal" {
|
||||
lexer.err = Errorf(Position(lexer.scanner.Pos()), msg)
|
||||
}
|
||||
}
|
||||
return lexer
|
||||
}
|
||||
|
||||
// LexWithScanner creates a Lexer from a user-provided scanner.Scanner.
|
||||
//
|
||||
// Useful if you need to customise the Scanner.
|
||||
func LexWithScanner(r io.Reader, scan *scanner.Scanner) Lexer {
|
||||
return lexWithScanner(r, scan)
|
||||
}
|
||||
|
||||
func lexWithScanner(r io.Reader, scan *scanner.Scanner) *textScannerLexer {
|
||||
lexer := &textScannerLexer{
|
||||
filename: NameOfReader(r),
|
||||
scanner: scan,
|
||||
}
|
||||
lexer.scanner.Init(r)
|
||||
return lexer
|
||||
}
|
||||
|
||||
// LexBytes returns a new default lexer over bytes.
|
||||
func LexBytes(b []byte) Lexer {
|
||||
return Lex(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
// LexString returns a new default lexer over a string.
|
||||
func LexString(s string) Lexer {
|
||||
return Lex(strings.NewReader(s))
|
||||
}
|
||||
|
||||
func (t *textScannerLexer) Next() (Token, error) {
|
||||
typ := t.scanner.Scan()
|
||||
text := t.scanner.TokenText()
|
||||
pos := Position(t.scanner.Position)
|
||||
pos.Filename = t.filename
|
||||
if t.err != nil {
|
||||
return Token{}, t.err
|
||||
}
|
||||
return textScannerTransform(Token{
|
||||
Type: typ,
|
||||
Value: text,
|
||||
Pos: pos,
|
||||
})
|
||||
}
|
||||
|
||||
func textScannerTransform(token Token) (Token, error) {
|
||||
// Unquote strings.
|
||||
switch token.Type {
|
||||
case scanner.Char:
|
||||
// FIXME(alec): This is pretty hacky...we convert a single quoted char into a double
|
||||
// quoted string in order to support single quoted strings.
|
||||
token.Value = fmt.Sprintf("\"%s\"", token.Value[1:len(token.Value)-1])
|
||||
fallthrough
|
||||
case scanner.String:
|
||||
s, err := strconv.Unquote(token.Value)
|
||||
if err != nil {
|
||||
return Token{}, Errorf(token.Pos, "%s: %q", err.Error(), token.Value)
|
||||
}
|
||||
token.Value = s
|
||||
if token.Type == scanner.Char && utf8.RuneCountInString(s) > 1 {
|
||||
token.Type = scanner.String
|
||||
}
|
||||
case scanner.RawString:
|
||||
token.Value = token.Value[1 : len(token.Value)-1]
|
||||
}
|
||||
return token, nil
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
type mapperByToken struct {
|
||||
symbols []string
|
||||
mapper Mapper
|
||||
}
|
||||
|
||||
// DropToken can be returned by a Mapper to remove a token from the stream.
|
||||
var DropToken = errors.New("drop token") // nolint: golint
|
||||
|
||||
// Mapper function for mutating tokens before being applied to the AST.
|
||||
//
|
||||
// If the Mapper func returns an error of DropToken, the token will be removed from the stream.
|
||||
type Mapper func(token lexer.Token) (lexer.Token, error)
|
||||
|
||||
// Map is an Option that configures the Parser to apply a mapping function to each Token from the lexer.
|
||||
//
|
||||
// This can be useful to eg. upper-case all tokens of a certain type, or dequote strings.
|
||||
//
|
||||
// "symbols" specifies the token symbols that the Mapper will be applied to. If empty, all tokens will be mapped.
|
||||
func Map(mapper Mapper, symbols ...string) Option {
|
||||
return func(p *Parser) error {
|
||||
p.mappers = append(p.mappers, mapperByToken{
|
||||
mapper: mapper,
|
||||
symbols: symbols,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Unquote applies strconv.Unquote() to tokens of the given types.
|
||||
//
|
||||
// Tokens of type "String" will be unquoted if no other types are provided.
|
||||
func Unquote(types ...string) Option {
|
||||
if len(types) == 0 {
|
||||
types = []string{"String"}
|
||||
}
|
||||
return Map(func(t lexer.Token) (lexer.Token, error) {
|
||||
value, err := unquote(t.Value)
|
||||
if err != nil {
|
||||
return t, lexer.Errorf(t.Pos, "invalid quoted string %q: %s", t.Value, err.Error())
|
||||
}
|
||||
t.Value = value
|
||||
return t, nil
|
||||
}, types...)
|
||||
}
|
||||
|
||||
func unquote(s string) (string, error) {
|
||||
quote := s[0]
|
||||
s = s[1 : len(s)-1]
|
||||
out := ""
|
||||
for s != "" {
|
||||
value, _, tail, err := strconv.UnquoteChar(s, quote)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s = tail
|
||||
out += string(value)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Upper is an Option that upper-cases all tokens of the given type. Useful for case normalisation.
|
||||
func Upper(types ...string) Option {
|
||||
return Map(func(token lexer.Token) (lexer.Token, error) {
|
||||
token.Value = strings.ToUpper(token.Value)
|
||||
return token, nil
|
||||
}, types...)
|
||||
}
|
||||
|
||||
// Elide drops tokens of the specified types.
|
||||
func Elide(types ...string) Option {
|
||||
return Map(func(token lexer.Token) (lexer.Token, error) {
|
||||
return lexer.Token{}, DropToken
|
||||
}, types...)
|
||||
}
|
||||
|
||||
// Apply a Mapping to all tokens coming out of a Lexer.
|
||||
type mappingLexerDef struct {
|
||||
lexer.Definition
|
||||
mapper Mapper
|
||||
}
|
||||
|
||||
func (m *mappingLexerDef) Lex(r io.Reader) (lexer.Lexer, error) {
|
||||
lexer, err := m.Definition.Lex(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mappingLexer{lexer, m.mapper}, nil
|
||||
}
|
||||
|
||||
type mappingLexer struct {
|
||||
lexer.Lexer
|
||||
mapper Mapper
|
||||
}
|
||||
|
||||
func (m *mappingLexer) Next() (lexer.Token, error) {
|
||||
for {
|
||||
t, err := m.Lexer.Next()
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
t, err = m.mapper(t)
|
||||
if err == DropToken {
|
||||
continue
|
||||
}
|
||||
return t, err
|
||||
}
|
||||
}
|
|
@ -0,0 +1,575 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
var (
|
||||
// MaxIterations limits the number of elements capturable by {}.
|
||||
MaxIterations = 1000000
|
||||
|
||||
positionType = reflect.TypeOf(lexer.Position{})
|
||||
captureType = reflect.TypeOf((*Capture)(nil)).Elem()
|
||||
parseableType = reflect.TypeOf((*Parseable)(nil)).Elem()
|
||||
|
||||
// NextMatch should be returned by Parseable.Parse() method implementations to indicate
|
||||
// that the node did not match and that other matches should be attempted, if appropriate.
|
||||
NextMatch = errors.New("no match") // nolint: golint
|
||||
)
|
||||
|
||||
// A node in the grammar.
|
||||
type node interface {
|
||||
// Parse from scanner into value.
|
||||
//
|
||||
// Returned slice will be nil if the node does not match.
|
||||
Parse(ctx *parseContext, parent reflect.Value) ([]reflect.Value, error)
|
||||
|
||||
// Return a decent string representation of the Node.
|
||||
String() string
|
||||
}
|
||||
|
||||
func decorate(err *error, name func() string) {
|
||||
if *err == nil {
|
||||
return
|
||||
}
|
||||
switch realError := (*err).(type) {
|
||||
case *lexer.Error:
|
||||
*err = &lexer.Error{Message: name() + ": " + realError.Message, Pos: realError.Pos}
|
||||
default:
|
||||
*err = fmt.Errorf("%s: %s", name(), realError)
|
||||
}
|
||||
}
|
||||
|
||||
// A node that proxies to an implementation that implements the Parseable interface.
|
||||
type parseable struct {
|
||||
t reflect.Type
|
||||
}
|
||||
|
||||
func (p *parseable) String() string { return stringer(p) }
|
||||
|
||||
func (p *parseable) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
rv := reflect.New(p.t)
|
||||
v := rv.Interface().(Parseable)
|
||||
err = v.Parse(ctx)
|
||||
if err != nil {
|
||||
if err == NextMatch {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return []reflect.Value{rv.Elem()}, nil
|
||||
}
|
||||
|
||||
type strct struct {
|
||||
typ reflect.Type
|
||||
expr node
|
||||
}
|
||||
|
||||
func (s *strct) String() string { return stringer(s) }
|
||||
|
||||
func (s *strct) maybeInjectPos(pos lexer.Position, v reflect.Value) {
|
||||
if f := v.FieldByName("Pos"); f.IsValid() && f.Type() == positionType {
|
||||
f.Set(reflect.ValueOf(pos))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *strct) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
sv := reflect.New(s.typ).Elem()
|
||||
t, err := ctx.Peek(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.maybeInjectPos(t.Pos, sv)
|
||||
if out, err = s.expr.Parse(ctx, sv); err != nil {
|
||||
_ = ctx.Apply()
|
||||
return []reflect.Value{sv}, err
|
||||
} else if out == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return []reflect.Value{sv}, ctx.Apply()
|
||||
}
|
||||
|
||||
type groupMatchMode int
|
||||
|
||||
const (
|
||||
groupMatchOnce groupMatchMode = iota
|
||||
groupMatchZeroOrOne = iota
|
||||
groupMatchZeroOrMore = iota
|
||||
groupMatchOneOrMore = iota
|
||||
groupMatchNonEmpty = iota
|
||||
)
|
||||
|
||||
// ( <expr> ) - match once
|
||||
// ( <expr> )* - match zero or more times
|
||||
// ( <expr> )+ - match one or more times
|
||||
// ( <expr> )? - match zero or once
|
||||
// ( <expr> )! - must be a non-empty match
|
||||
//
|
||||
// The additional modifier "!" forces the content of the group to be non-empty if it does match.
|
||||
type group struct {
|
||||
expr node
|
||||
mode groupMatchMode
|
||||
}
|
||||
|
||||
func (g *group) String() string { return stringer(g) }
|
||||
func (g *group) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
// Configure min/max matches.
|
||||
min := 1
|
||||
max := 1
|
||||
switch g.mode {
|
||||
case groupMatchNonEmpty:
|
||||
out, err = g.expr.Parse(ctx, parent)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
if len(out) == 0 {
|
||||
t, _ := ctx.Peek(0)
|
||||
return out, lexer.Errorf(t.Pos, "sub-expression %s cannot be empty", g)
|
||||
}
|
||||
return out, nil
|
||||
case groupMatchOnce:
|
||||
return g.expr.Parse(ctx, parent)
|
||||
case groupMatchZeroOrOne:
|
||||
min = 0
|
||||
case groupMatchZeroOrMore:
|
||||
min = 0
|
||||
max = MaxIterations
|
||||
case groupMatchOneOrMore:
|
||||
min = 1
|
||||
max = MaxIterations
|
||||
}
|
||||
matches := 0
|
||||
for ; matches < max; matches++ {
|
||||
branch := ctx.Branch()
|
||||
v, err := g.expr.Parse(branch, parent)
|
||||
out = append(out, v...)
|
||||
if err != nil {
|
||||
// Optional part failed to match.
|
||||
if ctx.Stop(branch) {
|
||||
return out, err
|
||||
}
|
||||
break
|
||||
} else {
|
||||
ctx.Accept(branch)
|
||||
}
|
||||
if v == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
// fmt.Printf("%d < %d < %d: out == nil? %v\n", min, matches, max, out == nil)
|
||||
t, _ := ctx.Peek(0)
|
||||
if matches >= MaxIterations {
|
||||
panic(lexer.Errorf(t.Pos, "too many iterations of %s (> %d)", g, MaxIterations))
|
||||
}
|
||||
if matches < min {
|
||||
return out, lexer.Errorf(t.Pos, "sub-expression %s must match at least once", g)
|
||||
}
|
||||
// The idea here is that something like "a"? is a successful match and that parsing should proceed.
|
||||
if min == 0 && out == nil {
|
||||
out = []reflect.Value{}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// <expr> {"|" <expr>}
|
||||
type disjunction struct {
|
||||
nodes []node
|
||||
}
|
||||
|
||||
func (d *disjunction) String() string { return stringer(d) }
|
||||
|
||||
func (d *disjunction) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
var (
|
||||
deepestError = 0
|
||||
firstError error
|
||||
firstValues []reflect.Value
|
||||
)
|
||||
for _, a := range d.nodes {
|
||||
branch := ctx.Branch()
|
||||
if value, err := a.Parse(branch, parent); err != nil {
|
||||
// If this branch progressed too far and still didn't match, error out.
|
||||
if ctx.Stop(branch) {
|
||||
return value, err
|
||||
}
|
||||
// Show the closest error returned. The idea here is that the further the parser progresses
|
||||
// without error, the more difficult it is to trace the error back to its root.
|
||||
if err != nil && branch.cursor >= deepestError {
|
||||
firstError = err
|
||||
firstValues = value
|
||||
deepestError = branch.cursor
|
||||
}
|
||||
} else if value != nil {
|
||||
ctx.Accept(branch)
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
if firstError != nil {
|
||||
return firstValues, firstError
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// <node> ...
|
||||
type sequence struct {
|
||||
head bool
|
||||
node node
|
||||
next *sequence
|
||||
}
|
||||
|
||||
func (s *sequence) String() string { return stringer(s) }
|
||||
|
||||
func (s *sequence) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
for n := s; n != nil; n = n.next {
|
||||
child, err := n.node.Parse(ctx, parent)
|
||||
out = append(out, child...)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
if child == nil {
|
||||
// Early exit if first value doesn't match, otherwise all values must match.
|
||||
if n == s {
|
||||
return nil, nil
|
||||
}
|
||||
token, err := ctx.Peek(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, lexer.Errorf(token.Pos, "unexpected %q (expected %s)", token, n)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// @<expr>
|
||||
type capture struct {
|
||||
field structLexerField
|
||||
node node
|
||||
}
|
||||
|
||||
func (c *capture) String() string { return stringer(c) }
|
||||
|
||||
func (c *capture) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
token, err := ctx.Peek(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pos := token.Pos
|
||||
v, err := c.node.Parse(ctx, parent)
|
||||
if err != nil {
|
||||
if v != nil {
|
||||
ctx.Defer(pos, parent, c.field, v)
|
||||
}
|
||||
return []reflect.Value{parent}, err
|
||||
}
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
ctx.Defer(pos, parent, c.field, v)
|
||||
return []reflect.Value{parent}, nil
|
||||
}
|
||||
|
||||
// <identifier> - named lexer token reference
|
||||
type reference struct {
|
||||
typ rune
|
||||
identifier string // Used for informational purposes.
|
||||
}
|
||||
|
||||
func (r *reference) String() string { return stringer(r) }
|
||||
|
||||
func (r *reference) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
token, err := ctx.Peek(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token.Type != r.typ {
|
||||
return nil, nil
|
||||
}
|
||||
_, _ = ctx.Next()
|
||||
return []reflect.Value{reflect.ValueOf(token.Value)}, nil
|
||||
}
|
||||
|
||||
// [ <expr> ] <sequence>
|
||||
type optional struct {
|
||||
node node
|
||||
}
|
||||
|
||||
func (o *optional) String() string { return stringer(o) }
|
||||
|
||||
func (o *optional) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
branch := ctx.Branch()
|
||||
out, err = o.node.Parse(branch, parent)
|
||||
if err != nil {
|
||||
// Optional part failed to match.
|
||||
if ctx.Stop(branch) {
|
||||
return out, err
|
||||
}
|
||||
} else {
|
||||
ctx.Accept(branch)
|
||||
}
|
||||
if out == nil {
|
||||
out = []reflect.Value{}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// { <expr> } <sequence>
|
||||
type repetition struct {
|
||||
node node
|
||||
}
|
||||
|
||||
func (r *repetition) String() string { return stringer(r) }
|
||||
|
||||
// Parse a repetition. Once a repetition is encountered it will always match, so grammars
|
||||
// should ensure that branches are differentiated prior to the repetition.
|
||||
func (r *repetition) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
i := 0
|
||||
for ; i < MaxIterations; i++ {
|
||||
branch := ctx.Branch()
|
||||
v, err := r.node.Parse(branch, parent)
|
||||
out = append(out, v...)
|
||||
if err != nil {
|
||||
// Optional part failed to match.
|
||||
if ctx.Stop(branch) {
|
||||
return out, err
|
||||
}
|
||||
break
|
||||
} else {
|
||||
ctx.Accept(branch)
|
||||
}
|
||||
if v == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i >= MaxIterations {
|
||||
t, _ := ctx.Peek(0)
|
||||
panic(lexer.Errorf(t.Pos, "too many iterations of %s (> %d)", r, MaxIterations))
|
||||
}
|
||||
if out == nil {
|
||||
out = []reflect.Value{}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Match a token literal exactly "..."[:<type>].
|
||||
type literal struct {
|
||||
s string
|
||||
t rune
|
||||
tt string // Used for display purposes - symbolic name of t.
|
||||
}
|
||||
|
||||
func (l *literal) String() string { return stringer(l) }
|
||||
|
||||
func (l *literal) Parse(ctx *parseContext, parent reflect.Value) (out []reflect.Value, err error) {
|
||||
token, err := ctx.Peek(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
equal := false // nolint: ineffassign
|
||||
if ctx.caseInsensitive[token.Type] {
|
||||
equal = strings.EqualFold(token.Value, l.s)
|
||||
} else {
|
||||
equal = token.Value == l.s
|
||||
}
|
||||
if equal && (l.t == -1 || l.t == token.Type) {
|
||||
next, err := ctx.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []reflect.Value{reflect.ValueOf(next.Value)}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Attempt to transform values to given type.
|
||||
//
|
||||
// This will dereference pointers, and attempt to parse strings into integer values, floats, etc.
|
||||
func conform(t reflect.Type, values []reflect.Value) (out []reflect.Value, err error) {
|
||||
for _, v := range values {
|
||||
for t != v.Type() && t.Kind() == reflect.Ptr && v.Kind() != reflect.Ptr {
|
||||
// This can occur during partial failure.
|
||||
if !v.CanAddr() {
|
||||
return
|
||||
}
|
||||
v = v.Addr()
|
||||
}
|
||||
|
||||
// Already of the right kind, don't bother converting.
|
||||
if v.Kind() == t.Kind() {
|
||||
out = append(out, v)
|
||||
continue
|
||||
}
|
||||
|
||||
kind := t.Kind()
|
||||
switch kind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
n, err := strconv.ParseInt(v.String(), 0, sizeOfKind(kind))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid integer %q: %s", v.String(), err)
|
||||
}
|
||||
v = reflect.New(t).Elem()
|
||||
v.SetInt(n)
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
n, err := strconv.ParseUint(v.String(), 0, sizeOfKind(kind))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid integer %q: %s", v.String(), err)
|
||||
}
|
||||
v = reflect.New(t).Elem()
|
||||
v.SetUint(n)
|
||||
|
||||
case reflect.Bool:
|
||||
v = reflect.ValueOf(true)
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
n, err := strconv.ParseFloat(v.String(), sizeOfKind(kind))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid integer %q: %s", v.String(), err)
|
||||
}
|
||||
v = reflect.New(t).Elem()
|
||||
v.SetFloat(n)
|
||||
}
|
||||
|
||||
out = append(out, v)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func sizeOfKind(kind reflect.Kind) int {
|
||||
switch kind {
|
||||
case reflect.Int8, reflect.Uint8:
|
||||
return 8
|
||||
case reflect.Int16, reflect.Uint16:
|
||||
return 16
|
||||
case reflect.Int32, reflect.Uint32, reflect.Float32:
|
||||
return 32
|
||||
case reflect.Int64, reflect.Uint64, reflect.Float64:
|
||||
return 64
|
||||
case reflect.Int, reflect.Uint:
|
||||
return strconv.IntSize
|
||||
}
|
||||
panic("unsupported kind " + kind.String())
|
||||
}
|
||||
|
||||
// Set field.
|
||||
//
|
||||
// If field is a pointer the pointer will be set to the value. If field is a string, value will be
|
||||
// appended. If field is a slice, value will be appended to slice.
|
||||
//
|
||||
// For all other types, an attempt will be made to convert the string to the corresponding
|
||||
// type (int, float32, etc.).
|
||||
func setField(pos lexer.Position, strct reflect.Value, field structLexerField, fieldValue []reflect.Value) (err error) { // nolint: gocyclo
|
||||
defer decorate(&err, func() string { return pos.String() + ": " + strct.Type().String() + "." + field.Name })
|
||||
|
||||
f := strct.FieldByIndex(field.Index)
|
||||
switch f.Kind() {
|
||||
case reflect.Slice:
|
||||
fieldValue, err = conform(f.Type().Elem(), fieldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Set(reflect.Append(f, fieldValue...))
|
||||
return nil
|
||||
|
||||
case reflect.Ptr:
|
||||
if f.IsNil() {
|
||||
fv := reflect.New(f.Type().Elem()).Elem()
|
||||
f.Set(fv.Addr())
|
||||
f = fv
|
||||
} else {
|
||||
f = f.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
if f.Kind() == reflect.Struct {
|
||||
if pf := f.FieldByName("Pos"); pf.IsValid() && pf.Type() == positionType {
|
||||
pf.Set(reflect.ValueOf(pos))
|
||||
}
|
||||
}
|
||||
|
||||
if f.CanAddr() {
|
||||
if d, ok := f.Addr().Interface().(Capture); ok {
|
||||
ifv := []string{}
|
||||
for _, v := range fieldValue {
|
||||
ifv = append(ifv, v.Interface().(string))
|
||||
}
|
||||
err := d.Capture(ifv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Strings concatenate all captured tokens.
|
||||
if f.Kind() == reflect.String {
|
||||
fieldValue, err = conform(f.Type(), fieldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range fieldValue {
|
||||
f.Set(reflect.ValueOf(f.String() + v.String()).Convert(f.Type()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Coalesce multiple tokens into one. This allows eg. ["-", "10"] to be captured as separate tokens but
|
||||
// parsed as a single string "-10".
|
||||
if len(fieldValue) > 1 {
|
||||
out := []string{}
|
||||
for _, v := range fieldValue {
|
||||
out = append(out, v.String())
|
||||
}
|
||||
fieldValue = []reflect.Value{reflect.ValueOf(strings.Join(out, ""))}
|
||||
}
|
||||
|
||||
fieldValue, err = conform(f.Type(), fieldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fv := fieldValue[0]
|
||||
|
||||
switch f.Kind() {
|
||||
// Numeric types will increment if the token can not be coerced.
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if fv.Type() != f.Type() {
|
||||
f.SetInt(f.Int() + 1)
|
||||
} else {
|
||||
f.Set(fv)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if fv.Type() != f.Type() {
|
||||
f.SetUint(f.Uint() + 1)
|
||||
} else {
|
||||
f.Set(fv)
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if fv.Type() != f.Type() {
|
||||
f.SetFloat(f.Float() + 1)
|
||||
} else {
|
||||
f.Set(fv)
|
||||
}
|
||||
|
||||
case reflect.Bool, reflect.Struct:
|
||||
if fv.Type() != f.Type() {
|
||||
return fmt.Errorf("value %q is not correct type %s", fv, f.Type())
|
||||
}
|
||||
f.Set(fv)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported field type %s for field %s", f.Type(), field.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error is an error returned by the parser internally to differentiate from non-Participle errors.
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string { return string(e) }
|
|
@ -0,0 +1,39 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
// An Option to modify the behaviour of the Parser.
|
||||
type Option func(p *Parser) error
|
||||
|
||||
// Lexer is an Option that sets the lexer to use with the given grammar.
|
||||
func Lexer(def lexer.Definition) Option {
|
||||
return func(p *Parser) error {
|
||||
p.lex = def
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// UseLookahead allows branch lookahead up to "n" tokens.
|
||||
//
|
||||
// If parsing cannot be disambiguated before "n" tokens of lookahead, parsing will fail.
|
||||
//
|
||||
// Note that increasing lookahead has a minor performance impact, but also
|
||||
// reduces the accuracy of error reporting.
|
||||
func UseLookahead(n int) Option {
|
||||
return func(p *Parser) error {
|
||||
p.useLookahead = n
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CaseInsensitive allows the specified token types to be matched case-insensitively.
|
||||
func CaseInsensitive(tokens ...string) Option {
|
||||
return func(p *Parser) error {
|
||||
for _, token := range tokens {
|
||||
p.caseInsensitive[token] = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
// A Parser for a particular grammar and lexer.
|
||||
type Parser struct {
|
||||
root node
|
||||
lex lexer.Definition
|
||||
typ reflect.Type
|
||||
useLookahead int
|
||||
caseInsensitive map[string]bool
|
||||
mappers []mapperByToken
|
||||
}
|
||||
|
||||
// MustBuild calls Build(grammar, options...) and panics if an error occurs.
|
||||
func MustBuild(grammar interface{}, options ...Option) *Parser {
|
||||
parser, err := Build(grammar, options...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return parser
|
||||
}
|
||||
|
||||
// Build constructs a parser for the given grammar.
|
||||
//
|
||||
// If "Lexer()" is not provided as an option, a default lexer based on text/scanner will be used. This scans typical Go-
|
||||
// like tokens.
|
||||
//
|
||||
// See documentation for details
|
||||
func Build(grammar interface{}, options ...Option) (parser *Parser, err error) {
|
||||
// Configure Parser struct with defaults + options.
|
||||
p := &Parser{
|
||||
lex: lexer.TextScannerLexer,
|
||||
caseInsensitive: map[string]bool{},
|
||||
useLookahead: 1,
|
||||
}
|
||||
for _, option := range options {
|
||||
if option == nil {
|
||||
return nil, fmt.Errorf("nil Option passed, signature has changed; " +
|
||||
"if you intended to provide a custom Lexer, try participle.Build(grammar, participle.Lexer(lexer))")
|
||||
}
|
||||
if err = option(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(p.mappers) > 0 {
|
||||
mappers := map[rune][]Mapper{}
|
||||
symbols := p.lex.Symbols()
|
||||
for _, mapper := range p.mappers {
|
||||
if len(mapper.symbols) == 0 {
|
||||
mappers[lexer.EOF] = append(mappers[lexer.EOF], mapper.mapper)
|
||||
} else {
|
||||
for _, symbol := range mapper.symbols {
|
||||
if rn, ok := symbols[symbol]; !ok {
|
||||
return nil, fmt.Errorf("mapper %#v uses unknown token %q", mapper, symbol)
|
||||
} else { // nolint: golint
|
||||
mappers[rn] = append(mappers[rn], mapper.mapper)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
p.lex = &mappingLexerDef{p.lex, func(t lexer.Token) (lexer.Token, error) {
|
||||
combined := make([]Mapper, 0, len(mappers[t.Type])+len(mappers[lexer.EOF]))
|
||||
combined = append(combined, mappers[lexer.EOF]...)
|
||||
combined = append(combined, mappers[t.Type]...)
|
||||
|
||||
var err error
|
||||
for _, m := range combined {
|
||||
t, err = m(t)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}}
|
||||
}
|
||||
|
||||
context := newGeneratorContext(p.lex)
|
||||
v := reflect.ValueOf(grammar)
|
||||
if v.Kind() == reflect.Interface {
|
||||
v = v.Elem()
|
||||
}
|
||||
p.typ = v.Type()
|
||||
p.root, err = context.parseType(p.typ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Lex uses the parser's lexer to tokenise input.
|
||||
func (p *Parser) Lex(r io.Reader) ([]lexer.Token, error) {
|
||||
lex, err := p.lex.Lex(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return lexer.ConsumeAll(lex)
|
||||
}
|
||||
|
||||
// Parse from r into grammar v which must be of the same type as the grammar passed to
|
||||
// participle.Build().
|
||||
func (p *Parser) Parse(r io.Reader, v interface{}) (err error) {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Interface {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
var stream reflect.Value
|
||||
if rv.Kind() == reflect.Chan {
|
||||
stream = rv
|
||||
rt := rv.Type().Elem()
|
||||
rv = reflect.New(rt).Elem()
|
||||
}
|
||||
rt := rv.Type()
|
||||
if rt != p.typ {
|
||||
return fmt.Errorf("must parse into value of type %s not %T", p.typ, v)
|
||||
}
|
||||
baseLexer, err := p.lex.Lex(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lex := lexer.Upgrade(baseLexer)
|
||||
caseInsensitive := map[rune]bool{}
|
||||
for sym, rn := range p.lex.Symbols() {
|
||||
if p.caseInsensitive[sym] {
|
||||
caseInsensitive[rn] = true
|
||||
}
|
||||
}
|
||||
ctx, err := newParseContext(lex, p.useLookahead, caseInsensitive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If the grammar implements Parseable, use it.
|
||||
if parseable, ok := v.(Parseable); ok {
|
||||
return p.rootParseable(ctx, parseable)
|
||||
}
|
||||
if rt.Kind() != reflect.Ptr || rt.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("target must be a pointer to a struct, not %s", rt)
|
||||
}
|
||||
if stream.IsValid() {
|
||||
return p.parseStreaming(ctx, stream)
|
||||
}
|
||||
return p.parseOne(ctx, rv)
|
||||
}
|
||||
|
||||
func (p *Parser) parseStreaming(ctx *parseContext, rv reflect.Value) error {
|
||||
t := rv.Type().Elem().Elem()
|
||||
for {
|
||||
if token, _ := ctx.Peek(0); token.EOF() {
|
||||
rv.Close()
|
||||
return nil
|
||||
}
|
||||
v := reflect.New(t)
|
||||
if err := p.parseInto(ctx, v); err != nil {
|
||||
return err
|
||||
}
|
||||
rv.Send(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseOne(ctx *parseContext, rv reflect.Value) error {
|
||||
err := p.parseInto(ctx, rv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token, err := ctx.Peek(0)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !token.EOF() {
|
||||
return lexer.Errorf(token.Pos, "unexpected trailing token %q", token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseInto(ctx *parseContext, rv reflect.Value) error {
|
||||
if rv.IsNil() {
|
||||
return fmt.Errorf("target must be a non-nil pointer to a struct, but is a nil %s", rv.Type())
|
||||
}
|
||||
pv, err := p.root.Parse(ctx, rv.Elem())
|
||||
if len(pv) > 0 && pv[0].Type() == rv.Elem().Type() {
|
||||
rv.Elem().Set(reflect.Indirect(pv[0]))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pv == nil {
|
||||
token, _ := ctx.Peek(0)
|
||||
return lexer.Errorf(token.Pos, "invalid syntax")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) rootParseable(lex lexer.PeekingLexer, parseable Parseable) error {
|
||||
peek, err := lex.Peek(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = parseable.Parse(lex)
|
||||
if err == NextMatch {
|
||||
return lexer.Errorf(peek.Pos, "invalid syntax")
|
||||
}
|
||||
if err == nil && !peek.EOF() {
|
||||
return lexer.Errorf(peek.Pos, "unexpected token %q", peek)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ParseString is a convenience around Parse().
|
||||
func (p *Parser) ParseString(s string, v interface{}) error {
|
||||
return p.Parse(strings.NewReader(s), v)
|
||||
}
|
||||
|
||||
// ParseBytes is a convenience around Parse().
|
||||
func (p *Parser) ParseBytes(b []byte, v interface{}) error {
|
||||
return p.Parse(bytes.NewReader(b), v)
|
||||
}
|
||||
|
||||
// String representation of the grammar.
|
||||
func (p *Parser) String() string {
|
||||
return stringern(p.root, 128)
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
type stringerVisitor struct {
|
||||
bytes.Buffer
|
||||
seen map[node]bool
|
||||
}
|
||||
|
||||
func stringern(n node, depth int) string {
|
||||
v := &stringerVisitor{seen: map[node]bool{}}
|
||||
v.visit(n, depth, false)
|
||||
return v.String()
|
||||
}
|
||||
|
||||
func stringer(n node) string {
|
||||
return stringern(n, 1)
|
||||
}
|
||||
|
||||
func (s *stringerVisitor) visit(n node, depth int, disjunctions bool) {
|
||||
if s.seen[n] || depth <= 0 {
|
||||
fmt.Fprintf(s, "...")
|
||||
return
|
||||
}
|
||||
s.seen[n] = true
|
||||
|
||||
switch n := n.(type) {
|
||||
case *disjunction:
|
||||
for i, c := range n.nodes {
|
||||
if i > 0 {
|
||||
fmt.Fprint(s, " | ")
|
||||
}
|
||||
s.visit(c, depth, disjunctions || len(n.nodes) > 1)
|
||||
}
|
||||
|
||||
case *strct:
|
||||
s.visit(n.expr, depth, disjunctions)
|
||||
|
||||
case *sequence:
|
||||
c := n
|
||||
for i := 0; c != nil && depth-i > 0; c, i = c.next, i+1 {
|
||||
if c != n {
|
||||
fmt.Fprint(s, " ")
|
||||
}
|
||||
s.visit(c.node, depth-i, disjunctions)
|
||||
}
|
||||
if c != nil {
|
||||
fmt.Fprint(s, " ...")
|
||||
}
|
||||
|
||||
case *parseable:
|
||||
fmt.Fprintf(s, "<%s>", strings.ToLower(n.t.Name()))
|
||||
|
||||
case *capture:
|
||||
if _, ok := n.node.(*parseable); ok {
|
||||
fmt.Fprintf(s, "<%s>", strings.ToLower(n.field.Name))
|
||||
} else {
|
||||
if n.node == nil {
|
||||
fmt.Fprintf(s, "<%s>", strings.ToLower(n.field.Name))
|
||||
} else {
|
||||
s.visit(n.node, depth, disjunctions)
|
||||
}
|
||||
}
|
||||
|
||||
case *reference:
|
||||
fmt.Fprintf(s, "<%s>", strings.ToLower(n.identifier))
|
||||
|
||||
case *optional:
|
||||
fmt.Fprint(s, "[ ")
|
||||
s.visit(n.node, depth, disjunctions)
|
||||
fmt.Fprint(s, " ]")
|
||||
|
||||
case *repetition:
|
||||
fmt.Fprint(s, "{ ")
|
||||
s.visit(n.node, depth, disjunctions)
|
||||
fmt.Fprint(s, " }")
|
||||
|
||||
case *literal:
|
||||
fmt.Fprintf(s, "%q", n.s)
|
||||
if n.t != lexer.EOF && n.s == "" {
|
||||
fmt.Fprintf(s, ":%s", n.tt)
|
||||
}
|
||||
|
||||
case *group:
|
||||
fmt.Fprint(s, "(")
|
||||
if child, ok := n.expr.(*group); ok && child.mode == groupMatchOnce {
|
||||
s.visit(child.expr, depth, disjunctions)
|
||||
} else if child, ok := n.expr.(*capture); ok {
|
||||
if grandchild, ok := child.node.(*group); ok && grandchild.mode == groupMatchOnce {
|
||||
s.visit(grandchild.expr, depth, disjunctions)
|
||||
} else {
|
||||
s.visit(n.expr, depth, disjunctions)
|
||||
}
|
||||
} else {
|
||||
s.visit(n.expr, depth, disjunctions)
|
||||
}
|
||||
fmt.Fprint(s, ")")
|
||||
switch n.mode {
|
||||
case groupMatchNonEmpty:
|
||||
fmt.Fprintf(s, "!")
|
||||
case groupMatchZeroOrOne:
|
||||
fmt.Fprintf(s, "?")
|
||||
case groupMatchZeroOrMore:
|
||||
fmt.Fprintf(s, "*")
|
||||
case groupMatchOneOrMore:
|
||||
fmt.Fprintf(s, "+")
|
||||
}
|
||||
|
||||
default:
|
||||
panic("unsupported")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
package participle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/alecthomas/participle/lexer"
|
||||
)
|
||||
|
||||
// A structLexer lexes over the tags of struct fields while tracking the current field.
|
||||
type structLexer struct {
|
||||
s reflect.Type
|
||||
field int
|
||||
indexes [][]int
|
||||
lexer lexer.PeekingLexer
|
||||
}
|
||||
|
||||
func lexStruct(s reflect.Type) (*structLexer, error) {
|
||||
indexes, err := collectFieldIndexes(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slex := &structLexer{
|
||||
s: s,
|
||||
indexes: indexes,
|
||||
}
|
||||
if len(slex.indexes) > 0 {
|
||||
tag := fieldLexerTag(slex.Field().StructField)
|
||||
slex.lexer = lexer.Upgrade(lexer.LexString(tag))
|
||||
}
|
||||
return slex, nil
|
||||
}
|
||||
|
||||
// NumField returns the number of fields in the struct associated with this structLexer.
|
||||
func (s *structLexer) NumField() int {
|
||||
return len(s.indexes)
|
||||
}
|
||||
|
||||
type structLexerField struct {
|
||||
reflect.StructField
|
||||
Index []int
|
||||
}
|
||||
|
||||
// Field returns the field associated with the current token.
|
||||
func (s *structLexer) Field() structLexerField {
|
||||
return s.GetField(s.field)
|
||||
}
|
||||
|
||||
func (s *structLexer) GetField(field int) structLexerField {
|
||||
if field >= len(s.indexes) {
|
||||
field = len(s.indexes) - 1
|
||||
}
|
||||
return structLexerField{
|
||||
StructField: s.s.FieldByIndex(s.indexes[field]),
|
||||
Index: s.indexes[field],
|
||||
}
|
||||
}
|
||||
|
||||
func (s *structLexer) Peek() (lexer.Token, error) {
|
||||
field := s.field
|
||||
lex := s.lexer
|
||||
for {
|
||||
token, err := lex.Peek(0)
|
||||
if err != nil {
|
||||
return token, err
|
||||
}
|
||||
if !token.EOF() {
|
||||
token.Pos.Line = field + 1
|
||||
return token, nil
|
||||
}
|
||||
field++
|
||||
if field >= s.NumField() {
|
||||
return lexer.EOFToken(token.Pos), nil
|
||||
}
|
||||
tag := fieldLexerTag(s.GetField(field).StructField)
|
||||
lex = lexer.Upgrade(lexer.LexString(tag))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *structLexer) Next() (lexer.Token, error) {
|
||||
token, err := s.lexer.Next()
|
||||
if err != nil {
|
||||
return token, err
|
||||
}
|
||||
if !token.EOF() {
|
||||
token.Pos.Line = s.field + 1
|
||||
return token, nil
|
||||
}
|
||||
if s.field+1 >= s.NumField() {
|
||||
return lexer.EOFToken(token.Pos), nil
|
||||
}
|
||||
s.field++
|
||||
tag := fieldLexerTag(s.Field().StructField)
|
||||
s.lexer = lexer.Upgrade(lexer.LexString(tag))
|
||||
return s.Next()
|
||||
}
|
||||
|
||||
func fieldLexerTag(field reflect.StructField) string {
|
||||
if tag, ok := field.Tag.Lookup("parser"); ok {
|
||||
return tag
|
||||
}
|
||||
return string(field.Tag)
|
||||
}
|
||||
|
||||
// Recursively collect flattened indices for top-level fields and embedded fields.
|
||||
func collectFieldIndexes(s reflect.Type) (out [][]int, err error) {
|
||||
if s.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("expected a struct but got %q", s)
|
||||
}
|
||||
defer decorate(&err, s.String)
|
||||
for i := 0; i < s.NumField(); i++ {
|
||||
f := s.Field(i)
|
||||
if f.Anonymous {
|
||||
children, err := collectFieldIndexes(f.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, idx := range children {
|
||||
out = append(out, append(f.Index, idx...))
|
||||
}
|
||||
} else if fieldLexerTag(f) != "" {
|
||||
out = append(out, f.Index)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -88,6 +88,7 @@ type File struct {
|
|||
rowGroups []*parquet.RowGroup
|
||||
rowGroupIndex int
|
||||
|
||||
nameList []string
|
||||
columnNames set.StringSet
|
||||
columns map[string]*column
|
||||
rowIndex int64
|
||||
|
@ -100,16 +101,23 @@ func Open(getReaderFunc GetReaderFunc, columnNames set.StringSet) (*File, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
nameList := []string{}
|
||||
schemaElements := fileMeta.GetSchema()
|
||||
for _, element := range schemaElements {
|
||||
nameList = append(nameList, element.Name)
|
||||
}
|
||||
|
||||
return &File{
|
||||
getReaderFunc: getReaderFunc,
|
||||
rowGroups: fileMeta.GetRowGroups(),
|
||||
schemaElements: fileMeta.GetSchema(),
|
||||
schemaElements: schemaElements,
|
||||
nameList: nameList,
|
||||
columnNames: columnNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Read - reads single record.
|
||||
func (file *File) Read() (record map[string]Value, err error) {
|
||||
func (file *File) Read() (record *Record, err error) {
|
||||
if file.rowGroupIndex >= len(file.rowGroups) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
@ -134,10 +142,10 @@ func (file *File) Read() (record map[string]Value, err error) {
|
|||
return file.Read()
|
||||
}
|
||||
|
||||
record = make(map[string]Value)
|
||||
record = newRecord(file.nameList)
|
||||
for name := range file.columns {
|
||||
value, valueType := file.columns[name].read()
|
||||
record[name] = Value{value, valueType}
|
||||
record.set(name, Value{value, valueType})
|
||||
}
|
||||
|
||||
file.rowIndex++
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package parquet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Record - ordered parquet record.
|
||||
type Record struct {
|
||||
nameList []string
|
||||
nameValueMap map[string]Value
|
||||
}
|
||||
|
||||
// String - returns string representation of this record.
|
||||
func (r *Record) String() string {
|
||||
values := []string{}
|
||||
r.Range(func(name string, value Value) bool {
|
||||
values = append(values, fmt.Sprintf("%v:%v", name, value))
|
||||
return true
|
||||
})
|
||||
|
||||
return "map[" + strings.Join(values, " ") + "]"
|
||||
}
|
||||
|
||||
func (r *Record) set(name string, value Value) {
|
||||
r.nameValueMap[name] = value
|
||||
}
|
||||
|
||||
// Get - returns Value of name.
|
||||
func (r *Record) Get(name string) (Value, bool) {
|
||||
value, ok := r.nameValueMap[name]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Range - calls f sequentially for each name and value present in the record. If f returns false, range stops the iteration.
|
||||
func (r *Record) Range(f func(name string, value Value) bool) {
|
||||
for _, name := range r.nameList {
|
||||
value, ok := r.nameValueMap[name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if !f(name, value) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newRecord(nameList []string) *Record {
|
||||
return &Record{
|
||||
nameList: nameList,
|
||||
nameValueMap: make(map[string]Value),
|
||||
}
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
This project is originally a fork of [https://github.com/youtube/vitess](https://github.com/youtube/vitess)
|
||||
Copyright Google Inc
|
||||
|
||||
# Contributors
|
||||
Wenbin Xiao 2015
|
||||
Started this project and maintained it.
|
||||
|
||||
Andrew Brampton 2017
|
||||
Merged in multiple upstream fixes/changes.
|
|
@ -1,201 +0,0 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright 2017 Google Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
MAKEFLAGS = -s
|
||||
|
||||
sql.go: sql.y
|
||||
goyacc -o sql.go sql.y
|
||||
gofmt -w sql.go
|
||||
|
||||
clean:
|
||||
rm -f y.output sql.go
|
|
@ -1,150 +0,0 @@
|
|||
# sqlparser [![Build Status](https://img.shields.io/travis/xwb1989/sqlparser.svg)](https://travis-ci.org/xwb1989/sqlparser) [![Coverage](https://img.shields.io/coveralls/xwb1989/sqlparser.svg)](https://coveralls.io/github/xwb1989/sqlparser) [![Report card](https://goreportcard.com/badge/github.com/xwb1989/sqlparser)](https://goreportcard.com/report/github.com/xwb1989/sqlparser) [![GoDoc](https://godoc.org/github.com/xwb1989/sqlparser?status.svg)](https://godoc.org/github.com/xwb1989/sqlparser)
|
||||
|
||||
Go package for parsing MySQL SQL queries.
|
||||
|
||||
## Notice
|
||||
|
||||
The backbone of this repo is extracted from [vitessio/vitess](https://github.com/vitessio/vitess).
|
||||
|
||||
Inside vitessio/vitess there is a very nicely written sql parser. However as it's not a self-contained application, I created this one.
|
||||
It applies the same LICENSE as vitessio/vitess.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
```
|
||||
|
||||
Then use:
|
||||
|
||||
```go
|
||||
sql := "SELECT * FROM table WHERE a = 'abc'"
|
||||
stmt, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
// Do something with the err
|
||||
}
|
||||
|
||||
// Otherwise do something with stmt
|
||||
switch stmt := stmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
_ = stmt
|
||||
case *sqlparser.Insert:
|
||||
}
|
||||
```
|
||||
|
||||
Alternative to read many queries from a io.Reader:
|
||||
|
||||
```go
|
||||
r := strings.NewReader("INSERT INTO table1 VALUES (1, 'a'); INSERT INTO table2 VALUES (3, 4);")
|
||||
|
||||
tokens := sqlparser.NewTokenizer(r)
|
||||
for {
|
||||
stmt, err := sqlparser.ParseNext(tokens)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
// Do something with stmt or err.
|
||||
}
|
||||
```
|
||||
|
||||
See [parse_test.go](https://github.com/xwb1989/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/xwb1989/sqlparser).
|
||||
|
||||
|
||||
## Porting Instructions
|
||||
|
||||
You only need the below if you plan to try and keep this library up to date with [vitessio/vitess](https://github.com/vitessio/vitess).
|
||||
|
||||
### Keeping up to date
|
||||
|
||||
```bash
|
||||
shopt -s nullglob
|
||||
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
|
||||
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
|
||||
|
||||
# Create patches for everything that changed
|
||||
LASTIMPORT=1b7879cb91f1dfe1a2dfa06fea96e951e3a7aec5
|
||||
for path in ${VITESS?}/{vt/sqlparser,sqltypes,bytes2,hack}; do
|
||||
cd ${path}
|
||||
git format-patch ${LASTIMPORT?} .
|
||||
done;
|
||||
|
||||
# Apply patches to the dependencies
|
||||
cd ${XWB1989?}
|
||||
git am --directory dependency -p2 ${VITESS?}/{sqltypes,bytes2,hack}/*.patch
|
||||
|
||||
# Apply the main patches to the repo
|
||||
cd ${XWB1989?}
|
||||
git am -p4 ${VITESS?}/vt/sqlparser/*.patch
|
||||
|
||||
# If you encounter diff failures, manually fix them with
|
||||
patch -p4 < .git/rebase-apply/patch
|
||||
...
|
||||
git add name_of_files
|
||||
git am --continue
|
||||
|
||||
# Cleanup
|
||||
rm ${VITESS?}/{sqltypes,bytes2,hack}/*.patch ${VITESS?}/*.patch
|
||||
|
||||
# and Finally update the LASTIMPORT in this README.
|
||||
```
|
||||
|
||||
### Fresh install
|
||||
|
||||
TODO: Change these instructions to use git to copy the files, that'll make later patching easier.
|
||||
|
||||
```bash
|
||||
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
|
||||
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
|
||||
|
||||
cd ${XWB1989?}
|
||||
|
||||
# Copy all the code
|
||||
cp -pr ${VITESS?}/vt/sqlparser/ .
|
||||
cp -pr ${VITESS?}/sqltypes dependency
|
||||
cp -pr ${VITESS?}/bytes2 dependency
|
||||
cp -pr ${VITESS?}/hack dependency
|
||||
|
||||
# Delete some code we haven't ported
|
||||
rm dependency/sqltypes/arithmetic.go dependency/sqltypes/arithmetic_test.go dependency/sqltypes/event_token.go dependency/sqltypes/event_token_test.go dependency/sqltypes/proto3.go dependency/sqltypes/proto3_test.go dependency/sqltypes/query_response.go dependency/sqltypes/result.go dependency/sqltypes/result_test.go
|
||||
|
||||
# Some automated fixes
|
||||
|
||||
# Fix imports
|
||||
sed -i '.bak' 's_vitess.io/vitess/go/vt/proto/query_github.com/xwb1989/sqlparser/dependency/querypb_g' *.go dependency/sqltypes/*.go
|
||||
sed -i '.bak' 's_vitess.io/vitess/go/_github.com/xwb1989/sqlparser/dependency/_g' *.go dependency/sqltypes/*.go
|
||||
|
||||
# Copy the proto, but basically drop everything we don't want
|
||||
cp -pr ${VITESS?}/vt/proto/query dependency/querypb
|
||||
|
||||
sed -i '.bak' 's_.*Descriptor.*__g' dependency/querypb/*.go
|
||||
sed -i '.bak' 's_.*ProtoMessage.*__g' dependency/querypb/*.go
|
||||
|
||||
sed -i '.bak' 's/proto.CompactTextString(m)/"TODO"/g' dependency/querypb/*.go
|
||||
sed -i '.bak' 's/proto.EnumName/EnumName/g' dependency/querypb/*.go
|
||||
|
||||
sed -i '.bak' 's/proto.Equal/reflect.DeepEqual/g' dependency/sqltypes/*.go
|
||||
|
||||
# Remove the error library
|
||||
sed -i '.bak' 's/vterrors.Errorf([^,]*, /fmt.Errorf(/g' *.go dependency/sqltypes/*.go
|
||||
sed -i '.bak' 's/vterrors.New([^,]*, /errors.New(/g' *.go dependency/sqltypes/*.go
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
|
||||
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
|
||||
|
||||
cd ${XWB1989?}
|
||||
|
||||
# Test, fix and repeat
|
||||
go test ./...
|
||||
|
||||
# Finally make some diffs (for later reference)
|
||||
diff -u ${VITESS?}/sqltypes/ ${XWB1989?}/dependency/sqltypes/ > ${XWB1989?}/patches/sqltypes.patch
|
||||
diff -u ${VITESS?}/bytes2/ ${XWB1989?}/dependency/bytes2/ > ${XWB1989?}/patches/bytes2.patch
|
||||
diff -u ${VITESS?}/vt/proto/query/ ${XWB1989?}/dependency/querypb/ > ${XWB1989?}/patches/querypb.patch
|
||||
diff -u ${VITESS?}/vt/sqlparser/ ${XWB1989?}/ > ${XWB1989?}/patches/sqlparser.patch
|
||||
```
|
|
@ -1,343 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
// analyzer.go contains utility analysis functions.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
// These constants are used to identify the SQL statement type.
|
||||
const (
|
||||
StmtSelect = iota
|
||||
StmtStream
|
||||
StmtInsert
|
||||
StmtReplace
|
||||
StmtUpdate
|
||||
StmtDelete
|
||||
StmtDDL
|
||||
StmtBegin
|
||||
StmtCommit
|
||||
StmtRollback
|
||||
StmtSet
|
||||
StmtShow
|
||||
StmtUse
|
||||
StmtOther
|
||||
StmtUnknown
|
||||
StmtComment
|
||||
)
|
||||
|
||||
// Preview analyzes the beginning of the query using a simpler and faster
|
||||
// textual comparison to identify the statement type.
|
||||
func Preview(sql string) int {
|
||||
trimmed := StripLeadingComments(sql)
|
||||
|
||||
firstWord := trimmed
|
||||
if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 {
|
||||
firstWord = trimmed[:end]
|
||||
}
|
||||
firstWord = strings.TrimLeftFunc(firstWord, func(r rune) bool { return !unicode.IsLetter(r) })
|
||||
// Comparison is done in order of priority.
|
||||
loweredFirstWord := strings.ToLower(firstWord)
|
||||
switch loweredFirstWord {
|
||||
case "select":
|
||||
return StmtSelect
|
||||
case "stream":
|
||||
return StmtStream
|
||||
case "insert":
|
||||
return StmtInsert
|
||||
case "replace":
|
||||
return StmtReplace
|
||||
case "update":
|
||||
return StmtUpdate
|
||||
case "delete":
|
||||
return StmtDelete
|
||||
}
|
||||
// For the following statements it is not sufficient to rely
|
||||
// on loweredFirstWord. This is because they are not statements
|
||||
// in the grammar and we are relying on Preview to parse them.
|
||||
// For instance, we don't want: "BEGIN JUNK" to be parsed
|
||||
// as StmtBegin.
|
||||
trimmedNoComments, _ := SplitMarginComments(trimmed)
|
||||
switch strings.ToLower(trimmedNoComments) {
|
||||
case "begin", "start transaction":
|
||||
return StmtBegin
|
||||
case "commit":
|
||||
return StmtCommit
|
||||
case "rollback":
|
||||
return StmtRollback
|
||||
}
|
||||
switch loweredFirstWord {
|
||||
case "create", "alter", "rename", "drop", "truncate":
|
||||
return StmtDDL
|
||||
case "set":
|
||||
return StmtSet
|
||||
case "show":
|
||||
return StmtShow
|
||||
case "use":
|
||||
return StmtUse
|
||||
case "analyze", "describe", "desc", "explain", "repair", "optimize":
|
||||
return StmtOther
|
||||
}
|
||||
if strings.Index(trimmed, "/*!") == 0 {
|
||||
return StmtComment
|
||||
}
|
||||
return StmtUnknown
|
||||
}
|
||||
|
||||
// StmtType returns the statement type as a string
|
||||
func StmtType(stmtType int) string {
|
||||
switch stmtType {
|
||||
case StmtSelect:
|
||||
return "SELECT"
|
||||
case StmtStream:
|
||||
return "STREAM"
|
||||
case StmtInsert:
|
||||
return "INSERT"
|
||||
case StmtReplace:
|
||||
return "REPLACE"
|
||||
case StmtUpdate:
|
||||
return "UPDATE"
|
||||
case StmtDelete:
|
||||
return "DELETE"
|
||||
case StmtDDL:
|
||||
return "DDL"
|
||||
case StmtBegin:
|
||||
return "BEGIN"
|
||||
case StmtCommit:
|
||||
return "COMMIT"
|
||||
case StmtRollback:
|
||||
return "ROLLBACK"
|
||||
case StmtSet:
|
||||
return "SET"
|
||||
case StmtShow:
|
||||
return "SHOW"
|
||||
case StmtUse:
|
||||
return "USE"
|
||||
case StmtOther:
|
||||
return "OTHER"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
|
||||
func IsDML(sql string) bool {
|
||||
switch Preview(sql) {
|
||||
case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetTableName returns the table name from the SimpleTableExpr
|
||||
// only if it's a simple expression. Otherwise, it returns "".
|
||||
func GetTableName(node SimpleTableExpr) TableIdent {
|
||||
if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
|
||||
return n.Name
|
||||
}
|
||||
// sub-select or '.' expression
|
||||
return NewTableIdent("")
|
||||
}
|
||||
|
||||
// IsColName returns true if the Expr is a *ColName.
|
||||
func IsColName(node Expr) bool {
|
||||
_, ok := node.(*ColName)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsValue returns true if the Expr is a string, integral or value arg.
|
||||
// NULL is not considered to be a value.
|
||||
func IsValue(node Expr) bool {
|
||||
switch v := node.(type) {
|
||||
case *SQLVal:
|
||||
switch v.Type {
|
||||
case StrVal, HexVal, IntVal, ValArg:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsNull returns true if the Expr is SQL NULL
|
||||
func IsNull(node Expr) bool {
|
||||
switch node.(type) {
|
||||
case *NullVal:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSimpleTuple returns true if the Expr is a ValTuple that
|
||||
// contains simple values or if it's a list arg.
|
||||
func IsSimpleTuple(node Expr) bool {
|
||||
switch vals := node.(type) {
|
||||
case ValTuple:
|
||||
for _, n := range vals {
|
||||
if !IsValue(n) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case ListArg:
|
||||
return true
|
||||
}
|
||||
// It's a subquery
|
||||
return false
|
||||
}
|
||||
|
||||
// NewPlanValue builds a sqltypes.PlanValue from an Expr.
|
||||
func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
|
||||
switch node := node.(type) {
|
||||
case *SQLVal:
|
||||
switch node.Type {
|
||||
case ValArg:
|
||||
return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
|
||||
case IntVal:
|
||||
n, err := sqltypes.NewIntegral(string(node.Val))
|
||||
if err != nil {
|
||||
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
|
||||
}
|
||||
return sqltypes.PlanValue{Value: n}, nil
|
||||
case StrVal:
|
||||
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
|
||||
case HexVal:
|
||||
v, err := node.HexDecode()
|
||||
if err != nil {
|
||||
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
|
||||
}
|
||||
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil
|
||||
}
|
||||
case ListArg:
|
||||
return sqltypes.PlanValue{ListKey: string(node[2:])}, nil
|
||||
case ValTuple:
|
||||
pv := sqltypes.PlanValue{
|
||||
Values: make([]sqltypes.PlanValue, 0, len(node)),
|
||||
}
|
||||
for _, val := range node {
|
||||
innerpv, err := NewPlanValue(val)
|
||||
if err != nil {
|
||||
return sqltypes.PlanValue{}, err
|
||||
}
|
||||
if innerpv.ListKey != "" || innerpv.Values != nil {
|
||||
return sqltypes.PlanValue{}, errors.New("unsupported: nested lists")
|
||||
}
|
||||
pv.Values = append(pv.Values, innerpv)
|
||||
}
|
||||
return pv, nil
|
||||
case *NullVal:
|
||||
return sqltypes.PlanValue{}, nil
|
||||
}
|
||||
return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node))
|
||||
}
|
||||
|
||||
// StringIn is a convenience function that returns
|
||||
// true if str matches any of the values.
|
||||
func StringIn(str string, values ...string) bool {
|
||||
for _, val := range values {
|
||||
if str == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetKey is the extracted key from one SetExpr
|
||||
type SetKey struct {
|
||||
Key string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// ExtractSetValues returns a map of key-value pairs
|
||||
// if the query is a SET statement. Values can be bool, int64 or string.
|
||||
// Since set variable names are case insensitive, all keys are returned
|
||||
// as lower case.
|
||||
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
|
||||
stmt, err := Parse(sql)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
setStmt, ok := stmt.(*Set)
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
|
||||
}
|
||||
result := make(map[SetKey]interface{})
|
||||
for _, expr := range setStmt.Exprs {
|
||||
scope := SessionStr
|
||||
key := expr.Name.Lowered()
|
||||
switch {
|
||||
case strings.HasPrefix(key, "@@global."):
|
||||
scope = GlobalStr
|
||||
key = strings.TrimPrefix(key, "@@global.")
|
||||
case strings.HasPrefix(key, "@@session."):
|
||||
key = strings.TrimPrefix(key, "@@session.")
|
||||
case strings.HasPrefix(key, "@@"):
|
||||
key = strings.TrimPrefix(key, "@@")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr.Name.Lowered(), "@@") {
|
||||
if setStmt.Scope != "" && scope != "" {
|
||||
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
|
||||
}
|
||||
_, out := NewStringTokenizer(key).Scan()
|
||||
key = string(out)
|
||||
}
|
||||
|
||||
setKey := SetKey{
|
||||
Key: key,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
switch expr := expr.Expr.(type) {
|
||||
case *SQLVal:
|
||||
switch expr.Type {
|
||||
case StrVal:
|
||||
result[setKey] = strings.ToLower(string(expr.Val))
|
||||
case IntVal:
|
||||
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
result[setKey] = num
|
||||
default:
|
||||
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
|
||||
}
|
||||
case BoolVal:
|
||||
var val int64
|
||||
if expr {
|
||||
val = 1
|
||||
}
|
||||
result[setKey] = val
|
||||
case *ColName:
|
||||
result[setKey] = expr.Name.String()
|
||||
case *NullVal:
|
||||
result[setKey] = nil
|
||||
case *Default:
|
||||
result[setKey] = "default"
|
||||
default:
|
||||
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
|
||||
}
|
||||
}
|
||||
return result, strings.ToLower(setStmt.Scope), nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,293 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const (
|
||||
// DirectiveMultiShardAutocommit is the query comment directive to allow
|
||||
// single round trip autocommit with a multi-shard statement.
|
||||
DirectiveMultiShardAutocommit = "MULTI_SHARD_AUTOCOMMIT"
|
||||
// DirectiveSkipQueryPlanCache skips query plan cache when set.
|
||||
DirectiveSkipQueryPlanCache = "SKIP_QUERY_PLAN_CACHE"
|
||||
// DirectiveQueryTimeout sets a query timeout in vtgate. Only supported for SELECTS.
|
||||
DirectiveQueryTimeout = "QUERY_TIMEOUT_MS"
|
||||
)
|
||||
|
||||
func isNonSpace(r rune) bool {
|
||||
return !unicode.IsSpace(r)
|
||||
}
|
||||
|
||||
// leadingCommentEnd returns the first index after all leading comments, or
|
||||
// 0 if there are no leading comments.
|
||||
func leadingCommentEnd(text string) (end int) {
|
||||
hasComment := false
|
||||
pos := 0
|
||||
for pos < len(text) {
|
||||
// Eat up any whitespace. Trailing whitespace will be considered part of
|
||||
// the leading comments.
|
||||
nextVisibleOffset := strings.IndexFunc(text[pos:], isNonSpace)
|
||||
if nextVisibleOffset < 0 {
|
||||
break
|
||||
}
|
||||
pos += nextVisibleOffset
|
||||
remainingText := text[pos:]
|
||||
|
||||
// Found visible characters. Look for '/*' at the beginning
|
||||
// and '*/' somewhere after that.
|
||||
if len(remainingText) < 4 || remainingText[:2] != "/*" {
|
||||
break
|
||||
}
|
||||
commentLength := 4 + strings.Index(remainingText[2:], "*/")
|
||||
if commentLength < 4 {
|
||||
// Missing end comment :/
|
||||
break
|
||||
}
|
||||
|
||||
hasComment = true
|
||||
pos += commentLength
|
||||
}
|
||||
|
||||
if hasComment {
|
||||
return pos
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// trailingCommentStart returns the first index of trailing comments.
|
||||
// If there are no trailing comments, returns the length of the input string.
|
||||
func trailingCommentStart(text string) (start int) {
|
||||
hasComment := false
|
||||
reducedLen := len(text)
|
||||
for reducedLen > 0 {
|
||||
// Eat up any whitespace. Leading whitespace will be considered part of
|
||||
// the trailing comments.
|
||||
nextReducedLen := strings.LastIndexFunc(text[:reducedLen], isNonSpace) + 1
|
||||
if nextReducedLen == 0 {
|
||||
break
|
||||
}
|
||||
reducedLen = nextReducedLen
|
||||
if reducedLen < 4 || text[reducedLen-2:reducedLen] != "*/" {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the beginning of the comment
|
||||
startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*")
|
||||
if startCommentPos < 0 {
|
||||
// Badly formatted sql :/
|
||||
break
|
||||
}
|
||||
|
||||
hasComment = true
|
||||
reducedLen = startCommentPos
|
||||
}
|
||||
|
||||
if hasComment {
|
||||
return reducedLen
|
||||
}
|
||||
return len(text)
|
||||
}
|
||||
|
||||
// MarginComments holds the leading and trailing comments that surround a query.
|
||||
type MarginComments struct {
|
||||
Leading string
|
||||
Trailing string
|
||||
}
|
||||
|
||||
// SplitMarginComments pulls out any leading or trailing comments from a raw sql query.
|
||||
// This function also trims leading (if there's a comment) and trailing whitespace.
|
||||
func SplitMarginComments(sql string) (query string, comments MarginComments) {
|
||||
trailingStart := trailingCommentStart(sql)
|
||||
leadingEnd := leadingCommentEnd(sql[:trailingStart])
|
||||
comments = MarginComments{
|
||||
Leading: strings.TrimLeftFunc(sql[:leadingEnd], unicode.IsSpace),
|
||||
Trailing: strings.TrimRightFunc(sql[trailingStart:], unicode.IsSpace),
|
||||
}
|
||||
return strings.TrimFunc(sql[leadingEnd:trailingStart], unicode.IsSpace), comments
|
||||
}
|
||||
|
||||
// StripLeadingComments trims the SQL string and removes any leading comments
|
||||
func StripLeadingComments(sql string) string {
|
||||
sql = strings.TrimFunc(sql, unicode.IsSpace)
|
||||
|
||||
for hasCommentPrefix(sql) {
|
||||
switch sql[0] {
|
||||
case '/':
|
||||
// Multi line comment
|
||||
index := strings.Index(sql, "*/")
|
||||
if index <= 1 {
|
||||
return sql
|
||||
}
|
||||
// don't strip /*! ... */ or /*!50700 ... */
|
||||
if len(sql) > 2 && sql[2] == '!' {
|
||||
return sql
|
||||
}
|
||||
sql = sql[index+2:]
|
||||
case '-':
|
||||
// Single line comment
|
||||
index := strings.Index(sql, "\n")
|
||||
if index == -1 {
|
||||
return sql
|
||||
}
|
||||
sql = sql[index+1:]
|
||||
}
|
||||
|
||||
sql = strings.TrimFunc(sql, unicode.IsSpace)
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func hasCommentPrefix(sql string) bool {
|
||||
return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-'))
|
||||
}
|
||||
|
||||
// ExtractMysqlComment extracts the version and SQL from a comment-only query
|
||||
// such as /*!50708 sql here */
|
||||
func ExtractMysqlComment(sql string) (version string, innerSQL string) {
|
||||
sql = sql[3 : len(sql)-2]
|
||||
|
||||
digitCount := 0
|
||||
endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
|
||||
digitCount++
|
||||
return !unicode.IsDigit(c) || digitCount == 6
|
||||
})
|
||||
version = sql[0:endOfVersionIndex]
|
||||
innerSQL = strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
|
||||
|
||||
return version, innerSQL
|
||||
}
|
||||
|
||||
const commentDirectivePreamble = "/*vt+"
|
||||
|
||||
// CommentDirectives is the parsed representation for execution directives
|
||||
// conveyed in query comments
|
||||
type CommentDirectives map[string]interface{}
|
||||
|
||||
// ExtractCommentDirectives parses the comment list for any execution directives
|
||||
// of the form:
|
||||
//
|
||||
// /*vt+ OPTION_ONE=1 OPTION_TWO OPTION_THREE=abcd */
|
||||
//
|
||||
// It returns the map of the directive values or nil if there aren't any.
|
||||
func ExtractCommentDirectives(comments Comments) CommentDirectives {
|
||||
if comments == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var vals map[string]interface{}
|
||||
|
||||
for _, comment := range comments {
|
||||
commentStr := string(comment)
|
||||
if commentStr[0:5] != commentDirectivePreamble {
|
||||
continue
|
||||
}
|
||||
|
||||
if vals == nil {
|
||||
vals = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Split on whitespace and ignore the first and last directive
|
||||
// since they contain the comment start/end
|
||||
directives := strings.Fields(commentStr)
|
||||
for i := 1; i < len(directives)-1; i++ {
|
||||
directive := directives[i]
|
||||
sep := strings.IndexByte(directive, '=')
|
||||
|
||||
// No value is equivalent to a true boolean
|
||||
if sep == -1 {
|
||||
vals[directive] = true
|
||||
continue
|
||||
}
|
||||
|
||||
strVal := directive[sep+1:]
|
||||
directive = directive[:sep]
|
||||
|
||||
intVal, err := strconv.Atoi(strVal)
|
||||
if err == nil {
|
||||
vals[directive] = intVal
|
||||
continue
|
||||
}
|
||||
|
||||
boolVal, err := strconv.ParseBool(strVal)
|
||||
if err == nil {
|
||||
vals[directive] = boolVal
|
||||
continue
|
||||
}
|
||||
|
||||
vals[directive] = strVal
|
||||
}
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
// IsSet checks the directive map for the named directive and returns
|
||||
// true if the directive is set and has a true/false or 0/1 value
|
||||
func (d CommentDirectives) IsSet(key string) bool {
|
||||
if d == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
val, ok := d[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
boolVal, ok := val.(bool)
|
||||
if ok {
|
||||
return boolVal
|
||||
}
|
||||
|
||||
intVal, ok := val.(int)
|
||||
if ok {
|
||||
return intVal == 1
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SkipQueryPlanCacheDirective returns true if skip query plan cache directive is set to true in query.
|
||||
func SkipQueryPlanCacheDirective(stmt Statement) bool {
|
||||
switch stmt := stmt.(type) {
|
||||
case *Select:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
case *Insert:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
case *Update:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
case *Delete:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
// This file contains types that are 'Encodable'.
|
||||
|
||||
// Encodable defines the interface for types that can
|
||||
// be custom-encoded into SQL.
|
||||
type Encodable interface {
|
||||
EncodeSQL(buf *bytes.Buffer)
|
||||
}
|
||||
|
||||
// InsertValues is a custom SQL encoder for the values of
|
||||
// an insert statement.
|
||||
type InsertValues [][]sqltypes.Value
|
||||
|
||||
// EncodeSQL performs the SQL encoding for InsertValues.
|
||||
func (iv InsertValues) EncodeSQL(buf *bytes.Buffer) {
|
||||
for i, rows := range iv {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
buf.WriteByte('(')
|
||||
for j, bv := range rows {
|
||||
if j != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
bv.EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
// TupleEqualityList is for generating equality constraints
|
||||
// for tables that have composite primary keys.
|
||||
type TupleEqualityList struct {
|
||||
Columns []ColIdent
|
||||
Rows [][]sqltypes.Value
|
||||
}
|
||||
|
||||
// EncodeSQL generates the where clause constraints for the tuple
|
||||
// equality.
|
||||
func (tpl *TupleEqualityList) EncodeSQL(buf *bytes.Buffer) {
|
||||
if len(tpl.Columns) == 1 {
|
||||
tpl.encodeAsIn(buf)
|
||||
return
|
||||
}
|
||||
tpl.encodeAsEquality(buf)
|
||||
}
|
||||
|
||||
func (tpl *TupleEqualityList) encodeAsIn(buf *bytes.Buffer) {
|
||||
Append(buf, tpl.Columns[0])
|
||||
buf.WriteString(" in (")
|
||||
for i, r := range tpl.Rows {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
r[0].EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
|
||||
func (tpl *TupleEqualityList) encodeAsEquality(buf *bytes.Buffer) {
|
||||
for i, r := range tpl.Rows {
|
||||
if i != 0 {
|
||||
buf.WriteString(" or ")
|
||||
}
|
||||
buf.WriteString("(")
|
||||
for j, c := range tpl.Columns {
|
||||
if j != 0 {
|
||||
buf.WriteString(" and ")
|
||||
}
|
||||
Append(buf, c)
|
||||
buf.WriteString(" = ")
|
||||
r[j].EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreedto in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
// FormatImpossibleQuery creates an impossible query in a TrackedBuffer.
|
||||
// An impossible query is a modified version of a query where all selects have where clauses that are
|
||||
// impossible for mysql to resolve. This is used in the vtgate and vttablet:
|
||||
//
|
||||
// - In the vtgate it's used for joins: if the first query returns no result, then vtgate uses the impossible
|
||||
// query just to fetch field info from vttablet
|
||||
// - In the vttablet, it's just an optimization: the field info is fetched once form MySQL, cached and reused
|
||||
// for subsequent queries
|
||||
func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) {
|
||||
switch node := node.(type) {
|
||||
case *Select:
|
||||
buf.Myprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From)
|
||||
if node.GroupBy != nil {
|
||||
node.GroupBy.Format(buf)
|
||||
}
|
||||
case *Union:
|
||||
buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right)
|
||||
default:
|
||||
node.Format(buf)
|
||||
}
|
||||
}
|
|
@ -1,224 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
// Normalize changes the statement to use bind values, and
|
||||
// updates the bind vars to those values. The supplied prefix
|
||||
// is used to generate the bind var names. The function ensures
|
||||
// that there are no collisions with existing bind vars.
|
||||
// Within Select constructs, bind vars are deduped. This allows
|
||||
// us to identify vindex equality. Otherwise, every value is
|
||||
// treated as distinct.
|
||||
func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) {
|
||||
nz := newNormalizer(stmt, bindVars, prefix)
|
||||
_ = Walk(nz.WalkStatement, stmt)
|
||||
}
|
||||
|
||||
type normalizer struct {
|
||||
stmt Statement
|
||||
bindVars map[string]*querypb.BindVariable
|
||||
prefix string
|
||||
reserved map[string]struct{}
|
||||
counter int
|
||||
vals map[string]string
|
||||
}
|
||||
|
||||
func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer {
|
||||
return &normalizer{
|
||||
stmt: stmt,
|
||||
bindVars: bindVars,
|
||||
prefix: prefix,
|
||||
reserved: GetBindvars(stmt),
|
||||
counter: 1,
|
||||
vals: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// WalkStatement is the top level walk function.
|
||||
// If it encounters a Select, it switches to a mode
|
||||
// where variables are deduped.
|
||||
func (nz *normalizer) WalkStatement(node SQLNode) (bool, error) {
|
||||
switch node := node.(type) {
|
||||
case *Select:
|
||||
_ = Walk(nz.WalkSelect, node)
|
||||
// Don't continue
|
||||
return false, nil
|
||||
case *SQLVal:
|
||||
nz.convertSQLVal(node)
|
||||
case *ComparisonExpr:
|
||||
nz.convertComparison(node)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// WalkSelect normalizes the AST in Select mode.
|
||||
func (nz *normalizer) WalkSelect(node SQLNode) (bool, error) {
|
||||
switch node := node.(type) {
|
||||
case *SQLVal:
|
||||
nz.convertSQLValDedup(node)
|
||||
case *ComparisonExpr:
|
||||
nz.convertComparison(node)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (nz *normalizer) convertSQLValDedup(node *SQLVal) {
|
||||
// If value is too long, don't dedup.
|
||||
// Such values are most likely not for vindexes.
|
||||
// We save a lot of CPU because we avoid building
|
||||
// the key for them.
|
||||
if len(node.Val) > 256 {
|
||||
nz.convertSQLVal(node)
|
||||
return
|
||||
}
|
||||
|
||||
// Make the bindvar
|
||||
bval := nz.sqlToBindvar(node)
|
||||
if bval == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if there's a bindvar for that value already.
|
||||
var key string
|
||||
if bval.Type == sqltypes.VarBinary {
|
||||
// Prefixing strings with "'" ensures that a string
|
||||
// and number that have the same representation don't
|
||||
// collide.
|
||||
key = "'" + string(node.Val)
|
||||
} else {
|
||||
key = string(node.Val)
|
||||
}
|
||||
bvname, ok := nz.vals[key]
|
||||
if !ok {
|
||||
// If there's no such bindvar, make a new one.
|
||||
bvname = nz.newName()
|
||||
nz.vals[key] = bvname
|
||||
nz.bindVars[bvname] = bval
|
||||
}
|
||||
|
||||
// Modify the AST node to a bindvar.
|
||||
node.Type = ValArg
|
||||
node.Val = append([]byte(":"), bvname...)
|
||||
}
|
||||
|
||||
// convertSQLVal converts an SQLVal without the dedup.
|
||||
func (nz *normalizer) convertSQLVal(node *SQLVal) {
|
||||
bval := nz.sqlToBindvar(node)
|
||||
if bval == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bvname := nz.newName()
|
||||
nz.bindVars[bvname] = bval
|
||||
|
||||
node.Type = ValArg
|
||||
node.Val = append([]byte(":"), bvname...)
|
||||
}
|
||||
|
||||
// convertComparison attempts to convert IN clauses to
|
||||
// use the list bind var construct. If it fails, it returns
|
||||
// with no change made. The walk function will then continue
|
||||
// and iterate on converting each individual value into separate
|
||||
// bind vars.
|
||||
func (nz *normalizer) convertComparison(node *ComparisonExpr) {
|
||||
if node.Operator != InStr && node.Operator != NotInStr {
|
||||
return
|
||||
}
|
||||
tupleVals, ok := node.Right.(ValTuple)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// The RHS is a tuple of values.
|
||||
// Make a list bindvar.
|
||||
bvals := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
}
|
||||
for _, val := range tupleVals {
|
||||
bval := nz.sqlToBindvar(val)
|
||||
if bval == nil {
|
||||
return
|
||||
}
|
||||
bvals.Values = append(bvals.Values, &querypb.Value{
|
||||
Type: bval.Type,
|
||||
Value: bval.Value,
|
||||
})
|
||||
}
|
||||
bvname := nz.newName()
|
||||
nz.bindVars[bvname] = bvals
|
||||
// Modify RHS to be a list bindvar.
|
||||
node.Right = ListArg(append([]byte("::"), bvname...))
|
||||
}
|
||||
|
||||
func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable {
|
||||
if node, ok := node.(*SQLVal); ok {
|
||||
var v sqltypes.Value
|
||||
var err error
|
||||
switch node.Type {
|
||||
case StrVal:
|
||||
v, err = sqltypes.NewValue(sqltypes.VarBinary, node.Val)
|
||||
case IntVal:
|
||||
v, err = sqltypes.NewValue(sqltypes.Int64, node.Val)
|
||||
case FloatVal:
|
||||
v, err = sqltypes.NewValue(sqltypes.Float64, node.Val)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return sqltypes.ValueBindVariable(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nz *normalizer) newName() string {
|
||||
for {
|
||||
newName := fmt.Sprintf("%s%d", nz.prefix, nz.counter)
|
||||
if _, ok := nz.reserved[newName]; !ok {
|
||||
nz.reserved[newName] = struct{}{}
|
||||
return newName
|
||||
}
|
||||
nz.counter++
|
||||
}
|
||||
}
|
||||
|
||||
// GetBindvars returns a map of the bind vars referenced in the statement.
|
||||
// TODO(sougou); This function gets called again from vtgate/planbuilder.
|
||||
// Ideally, this should be done only once.
|
||||
func GetBindvars(stmt Statement) map[string]struct{} {
|
||||
bindvars := make(map[string]struct{})
|
||||
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
|
||||
switch node := node.(type) {
|
||||
case *SQLVal:
|
||||
if node.Type == ValArg {
|
||||
bindvars[string(node.Val[1:])] = struct{}{}
|
||||
}
|
||||
case ListArg:
|
||||
bindvars[string(node[2:])] = struct{}{}
|
||||
}
|
||||
return true, nil
|
||||
}, stmt)
|
||||
return bindvars
|
||||
}
|
|
@ -1,119 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
// ParsedQuery represents a parsed query where
|
||||
// bind locations are precompued for fast substitutions.
|
||||
type ParsedQuery struct {
|
||||
Query string
|
||||
bindLocations []bindLocation
|
||||
}
|
||||
|
||||
type bindLocation struct {
|
||||
offset, length int
|
||||
}
|
||||
|
||||
// NewParsedQuery returns a ParsedQuery of the ast.
|
||||
func NewParsedQuery(node SQLNode) *ParsedQuery {
|
||||
buf := NewTrackedBuffer(nil)
|
||||
buf.Myprintf("%v", node)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
// GenerateQuery generates a query by substituting the specified
|
||||
// bindVariables. The extras parameter specifies special parameters
|
||||
// that can perform custom encoding.
|
||||
func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) ([]byte, error) {
|
||||
if len(pq.bindLocations) == 0 {
|
||||
return []byte(pq.Query), nil
|
||||
}
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query)))
|
||||
current := 0
|
||||
for _, loc := range pq.bindLocations {
|
||||
buf.WriteString(pq.Query[current:loc.offset])
|
||||
name := pq.Query[loc.offset : loc.offset+loc.length]
|
||||
if encodable, ok := extras[name[1:]]; ok {
|
||||
encodable.EncodeSQL(buf)
|
||||
} else {
|
||||
supplied, _, err := FetchBindVar(name, bindVariables)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
EncodeValue(buf, supplied)
|
||||
}
|
||||
current = loc.offset + loc.length
|
||||
}
|
||||
buf.WriteString(pq.Query[current:])
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// EncodeValue encodes one bind variable value into the query.
|
||||
func EncodeValue(buf *bytes.Buffer, value *querypb.BindVariable) {
|
||||
if value.Type != querypb.Type_TUPLE {
|
||||
// Since we already check for TUPLE, we don't expect an error.
|
||||
v, _ := sqltypes.BindVariableToValue(value)
|
||||
v.EncodeSQL(buf)
|
||||
return
|
||||
}
|
||||
|
||||
// It's a TUPLE.
|
||||
buf.WriteByte('(')
|
||||
for i, bv := range value.Values {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
sqltypes.ProtoToValue(bv).EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
|
||||
// FetchBindVar resolves the bind variable by fetching it from bindVariables.
|
||||
func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) {
|
||||
name = name[1:]
|
||||
if name[0] == ':' {
|
||||
name = name[1:]
|
||||
isList = true
|
||||
}
|
||||
supplied, ok := bindVariables[name]
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("missing bind var %s", name)
|
||||
}
|
||||
|
||||
if isList {
|
||||
if supplied.Type != querypb.Type_TUPLE {
|
||||
return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name)
|
||||
}
|
||||
if len(supplied.Values) == 0 {
|
||||
return nil, false, fmt.Errorf("empty list supplied for %s", name)
|
||||
}
|
||||
return supplied, true, nil
|
||||
}
|
||||
|
||||
if supplied.Type == querypb.Type_TUPLE {
|
||||
return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name)
|
||||
}
|
||||
|
||||
return supplied, false, nil
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package sqlparser
|
||||
|
||||
import querypb "github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
|
||||
// RedactSQLQuery returns a sql string with the params stripped out for display
|
||||
func RedactSQLQuery(sql string) (string, error) {
|
||||
bv := map[string]*querypb.BindVariable{}
|
||||
sqlStripped, comments := SplitMarginComments(sql)
|
||||
|
||||
stmt, err := Parse(sqlStripped)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
prefix := "redacted"
|
||||
Normalize(stmt, bv, prefix)
|
||||
|
||||
return comments.Leading + String(stmt) + comments.Trailing, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,950 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/bytes2"
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBufSize = 4096
|
||||
eofChar = 0x100
|
||||
)
|
||||
|
||||
// Tokenizer is the struct used to generate SQL
|
||||
// tokens for the parser.
|
||||
type Tokenizer struct {
|
||||
InStream io.Reader
|
||||
AllowComments bool
|
||||
ForceEOF bool
|
||||
lastChar uint16
|
||||
Position int
|
||||
lastToken []byte
|
||||
LastError error
|
||||
posVarIndex int
|
||||
ParseTree Statement
|
||||
partialDDL *DDL
|
||||
nesting int
|
||||
multi bool
|
||||
specialComment *Tokenizer
|
||||
|
||||
buf []byte
|
||||
bufPos int
|
||||
bufSize int
|
||||
}
|
||||
|
||||
// NewStringTokenizer creates a new Tokenizer for the
|
||||
// sql string.
|
||||
func NewStringTokenizer(sql string) *Tokenizer {
|
||||
buf := []byte(sql)
|
||||
return &Tokenizer{
|
||||
buf: buf,
|
||||
bufSize: len(buf),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTokenizer creates a new Tokenizer reading a sql
|
||||
// string from the io.Reader.
|
||||
func NewTokenizer(r io.Reader) *Tokenizer {
|
||||
return &Tokenizer{
|
||||
InStream: r,
|
||||
buf: make([]byte, defaultBufSize),
|
||||
}
|
||||
}
|
||||
|
||||
// keywords is a map of mysql keywords that fall into two categories:
|
||||
// 1) keywords considered reserved by MySQL
|
||||
// 2) keywords for us to handle specially in sql.y
|
||||
//
|
||||
// Those marked as UNUSED are likely reserved keywords. We add them here so that
|
||||
// when rewriting queries we can properly backtick quote them so they don't cause issues
|
||||
//
|
||||
// NOTE: If you add new keywords, add them also to the reserved_keywords or
|
||||
// non_reserved_keywords grammar in sql.y -- this will allow the keyword to be used
|
||||
// in identifiers. See the docs for each grammar to determine which one to put it into.
|
||||
var keywords = map[string]int{
|
||||
"accessible": UNUSED,
|
||||
"add": ADD,
|
||||
"against": AGAINST,
|
||||
"all": ALL,
|
||||
"alter": ALTER,
|
||||
"analyze": ANALYZE,
|
||||
"and": AND,
|
||||
"as": AS,
|
||||
"asc": ASC,
|
||||
"asensitive": UNUSED,
|
||||
"auto_increment": AUTO_INCREMENT,
|
||||
"before": UNUSED,
|
||||
"begin": BEGIN,
|
||||
"between": BETWEEN,
|
||||
"bigint": BIGINT,
|
||||
"binary": BINARY,
|
||||
"_binary": UNDERSCORE_BINARY,
|
||||
"bit": BIT,
|
||||
"blob": BLOB,
|
||||
"bool": BOOL,
|
||||
"boolean": BOOLEAN,
|
||||
"both": UNUSED,
|
||||
"by": BY,
|
||||
"call": UNUSED,
|
||||
"cascade": UNUSED,
|
||||
"case": CASE,
|
||||
"cast": CAST,
|
||||
"change": UNUSED,
|
||||
"char": CHAR,
|
||||
"character": CHARACTER,
|
||||
"charset": CHARSET,
|
||||
"check": UNUSED,
|
||||
"collate": COLLATE,
|
||||
"column": COLUMN,
|
||||
"comment": COMMENT_KEYWORD,
|
||||
"committed": COMMITTED,
|
||||
"commit": COMMIT,
|
||||
"condition": UNUSED,
|
||||
"constraint": CONSTRAINT,
|
||||
"continue": UNUSED,
|
||||
"convert": CONVERT,
|
||||
"substr": SUBSTR,
|
||||
"substring": SUBSTRING,
|
||||
"create": CREATE,
|
||||
"cross": CROSS,
|
||||
"current_date": CURRENT_DATE,
|
||||
"current_time": CURRENT_TIME,
|
||||
"current_timestamp": CURRENT_TIMESTAMP,
|
||||
"current_user": UNUSED,
|
||||
"cursor": UNUSED,
|
||||
"database": DATABASE,
|
||||
"databases": DATABASES,
|
||||
"day_hour": UNUSED,
|
||||
"day_microsecond": UNUSED,
|
||||
"day_minute": UNUSED,
|
||||
"day_second": UNUSED,
|
||||
"date": DATE,
|
||||
"datetime": DATETIME,
|
||||
"dec": UNUSED,
|
||||
"decimal": DECIMAL,
|
||||
"declare": UNUSED,
|
||||
"default": DEFAULT,
|
||||
"delayed": UNUSED,
|
||||
"delete": DELETE,
|
||||
"desc": DESC,
|
||||
"describe": DESCRIBE,
|
||||
"deterministic": UNUSED,
|
||||
"distinct": DISTINCT,
|
||||
"distinctrow": UNUSED,
|
||||
"div": DIV,
|
||||
"double": DOUBLE,
|
||||
"drop": DROP,
|
||||
"duplicate": DUPLICATE,
|
||||
"each": UNUSED,
|
||||
"else": ELSE,
|
||||
"elseif": UNUSED,
|
||||
"enclosed": UNUSED,
|
||||
"end": END,
|
||||
"enum": ENUM,
|
||||
"escape": ESCAPE,
|
||||
"escaped": UNUSED,
|
||||
"exists": EXISTS,
|
||||
"exit": UNUSED,
|
||||
"explain": EXPLAIN,
|
||||
"expansion": EXPANSION,
|
||||
"extended": EXTENDED,
|
||||
"false": FALSE,
|
||||
"fetch": UNUSED,
|
||||
"float": FLOAT_TYPE,
|
||||
"float4": UNUSED,
|
||||
"float8": UNUSED,
|
||||
"for": FOR,
|
||||
"force": FORCE,
|
||||
"foreign": FOREIGN,
|
||||
"from": FROM,
|
||||
"full": FULL,
|
||||
"fulltext": FULLTEXT,
|
||||
"generated": UNUSED,
|
||||
"geometry": GEOMETRY,
|
||||
"geometrycollection": GEOMETRYCOLLECTION,
|
||||
"get": UNUSED,
|
||||
"global": GLOBAL,
|
||||
"grant": UNUSED,
|
||||
"group": GROUP,
|
||||
"group_concat": GROUP_CONCAT,
|
||||
"having": HAVING,
|
||||
"high_priority": UNUSED,
|
||||
"hour_microsecond": UNUSED,
|
||||
"hour_minute": UNUSED,
|
||||
"hour_second": UNUSED,
|
||||
"if": IF,
|
||||
"ignore": IGNORE,
|
||||
"in": IN,
|
||||
"index": INDEX,
|
||||
"infile": UNUSED,
|
||||
"inout": UNUSED,
|
||||
"inner": INNER,
|
||||
"insensitive": UNUSED,
|
||||
"insert": INSERT,
|
||||
"int": INT,
|
||||
"int1": UNUSED,
|
||||
"int2": UNUSED,
|
||||
"int3": UNUSED,
|
||||
"int4": UNUSED,
|
||||
"int8": UNUSED,
|
||||
"integer": INTEGER,
|
||||
"interval": INTERVAL,
|
||||
"into": INTO,
|
||||
"io_after_gtids": UNUSED,
|
||||
"is": IS,
|
||||
"isolation": ISOLATION,
|
||||
"iterate": UNUSED,
|
||||
"join": JOIN,
|
||||
"json": JSON,
|
||||
"key": KEY,
|
||||
"keys": KEYS,
|
||||
"key_block_size": KEY_BLOCK_SIZE,
|
||||
"kill": UNUSED,
|
||||
"language": LANGUAGE,
|
||||
"last_insert_id": LAST_INSERT_ID,
|
||||
"leading": UNUSED,
|
||||
"leave": UNUSED,
|
||||
"left": LEFT,
|
||||
"less": LESS,
|
||||
"level": LEVEL,
|
||||
"like": LIKE,
|
||||
"limit": LIMIT,
|
||||
"linear": UNUSED,
|
||||
"lines": UNUSED,
|
||||
"linestring": LINESTRING,
|
||||
"load": UNUSED,
|
||||
"localtime": LOCALTIME,
|
||||
"localtimestamp": LOCALTIMESTAMP,
|
||||
"lock": LOCK,
|
||||
"long": UNUSED,
|
||||
"longblob": LONGBLOB,
|
||||
"longtext": LONGTEXT,
|
||||
"loop": UNUSED,
|
||||
"low_priority": UNUSED,
|
||||
"master_bind": UNUSED,
|
||||
"match": MATCH,
|
||||
"maxvalue": MAXVALUE,
|
||||
"mediumblob": MEDIUMBLOB,
|
||||
"mediumint": MEDIUMINT,
|
||||
"mediumtext": MEDIUMTEXT,
|
||||
"middleint": UNUSED,
|
||||
"minute_microsecond": UNUSED,
|
||||
"minute_second": UNUSED,
|
||||
"mod": MOD,
|
||||
"mode": MODE,
|
||||
"modifies": UNUSED,
|
||||
"multilinestring": MULTILINESTRING,
|
||||
"multipoint": MULTIPOINT,
|
||||
"multipolygon": MULTIPOLYGON,
|
||||
"names": NAMES,
|
||||
"natural": NATURAL,
|
||||
"nchar": NCHAR,
|
||||
"next": NEXT,
|
||||
"not": NOT,
|
||||
"no_write_to_binlog": UNUSED,
|
||||
"null": NULL,
|
||||
"numeric": NUMERIC,
|
||||
"offset": OFFSET,
|
||||
"on": ON,
|
||||
"only": ONLY,
|
||||
"optimize": OPTIMIZE,
|
||||
"optimizer_costs": UNUSED,
|
||||
"option": UNUSED,
|
||||
"optionally": UNUSED,
|
||||
"or": OR,
|
||||
"order": ORDER,
|
||||
"out": UNUSED,
|
||||
"outer": OUTER,
|
||||
"outfile": UNUSED,
|
||||
"partition": PARTITION,
|
||||
"point": POINT,
|
||||
"polygon": POLYGON,
|
||||
"precision": UNUSED,
|
||||
"primary": PRIMARY,
|
||||
"processlist": PROCESSLIST,
|
||||
"procedure": PROCEDURE,
|
||||
"query": QUERY,
|
||||
"range": UNUSED,
|
||||
"read": READ,
|
||||
"reads": UNUSED,
|
||||
"read_write": UNUSED,
|
||||
"real": REAL,
|
||||
"references": UNUSED,
|
||||
"regexp": REGEXP,
|
||||
"release": UNUSED,
|
||||
"rename": RENAME,
|
||||
"reorganize": REORGANIZE,
|
||||
"repair": REPAIR,
|
||||
"repeat": UNUSED,
|
||||
"repeatable": REPEATABLE,
|
||||
"replace": REPLACE,
|
||||
"require": UNUSED,
|
||||
"resignal": UNUSED,
|
||||
"restrict": UNUSED,
|
||||
"return": UNUSED,
|
||||
"revoke": UNUSED,
|
||||
"right": RIGHT,
|
||||
"rlike": REGEXP,
|
||||
"rollback": ROLLBACK,
|
||||
"schema": SCHEMA,
|
||||
"schemas": UNUSED,
|
||||
"second_microsecond": UNUSED,
|
||||
"select": SELECT,
|
||||
"sensitive": UNUSED,
|
||||
"separator": SEPARATOR,
|
||||
"serializable": SERIALIZABLE,
|
||||
"session": SESSION,
|
||||
"set": SET,
|
||||
"share": SHARE,
|
||||
"show": SHOW,
|
||||
"signal": UNUSED,
|
||||
"signed": SIGNED,
|
||||
"smallint": SMALLINT,
|
||||
"spatial": SPATIAL,
|
||||
"specific": UNUSED,
|
||||
"sql": UNUSED,
|
||||
"sqlexception": UNUSED,
|
||||
"sqlstate": UNUSED,
|
||||
"sqlwarning": UNUSED,
|
||||
"sql_big_result": UNUSED,
|
||||
"sql_cache": SQL_CACHE,
|
||||
"sql_calc_found_rows": UNUSED,
|
||||
"sql_no_cache": SQL_NO_CACHE,
|
||||
"sql_small_result": UNUSED,
|
||||
"ssl": UNUSED,
|
||||
"start": START,
|
||||
"starting": UNUSED,
|
||||
"status": STATUS,
|
||||
"stored": UNUSED,
|
||||
"straight_join": STRAIGHT_JOIN,
|
||||
"stream": STREAM,
|
||||
"table": TABLE,
|
||||
"tables": TABLES,
|
||||
"terminated": UNUSED,
|
||||
"text": TEXT,
|
||||
"than": THAN,
|
||||
"then": THEN,
|
||||
"time": TIME,
|
||||
"timestamp": TIMESTAMP,
|
||||
"tinyblob": TINYBLOB,
|
||||
"tinyint": TINYINT,
|
||||
"tinytext": TINYTEXT,
|
||||
"to": TO,
|
||||
"trailing": UNUSED,
|
||||
"transaction": TRANSACTION,
|
||||
"trigger": TRIGGER,
|
||||
"true": TRUE,
|
||||
"truncate": TRUNCATE,
|
||||
"uncommitted": UNCOMMITTED,
|
||||
"undo": UNUSED,
|
||||
"union": UNION,
|
||||
"unique": UNIQUE,
|
||||
"unlock": UNUSED,
|
||||
"unsigned": UNSIGNED,
|
||||
"update": UPDATE,
|
||||
"usage": UNUSED,
|
||||
"use": USE,
|
||||
"using": USING,
|
||||
"utc_date": UTC_DATE,
|
||||
"utc_time": UTC_TIME,
|
||||
"utc_timestamp": UTC_TIMESTAMP,
|
||||
"values": VALUES,
|
||||
"variables": VARIABLES,
|
||||
"varbinary": VARBINARY,
|
||||
"varchar": VARCHAR,
|
||||
"varcharacter": UNUSED,
|
||||
"varying": UNUSED,
|
||||
"virtual": UNUSED,
|
||||
"vindex": VINDEX,
|
||||
"vindexes": VINDEXES,
|
||||
"view": VIEW,
|
||||
"vitess_keyspaces": VITESS_KEYSPACES,
|
||||
"vitess_shards": VITESS_SHARDS,
|
||||
"vitess_tablets": VITESS_TABLETS,
|
||||
"vschema_tables": VSCHEMA_TABLES,
|
||||
"when": WHEN,
|
||||
"where": WHERE,
|
||||
"while": UNUSED,
|
||||
"with": WITH,
|
||||
"write": WRITE,
|
||||
"xor": UNUSED,
|
||||
"year": YEAR,
|
||||
"year_month": UNUSED,
|
||||
"zerofill": ZEROFILL,
|
||||
}
|
||||
|
||||
// keywordStrings contains the reverse mapping of token to keyword strings
|
||||
var keywordStrings = map[int]string{}
|
||||
|
||||
func init() {
|
||||
for str, id := range keywords {
|
||||
if id == UNUSED {
|
||||
continue
|
||||
}
|
||||
keywordStrings[id] = str
|
||||
}
|
||||
}
|
||||
|
||||
// KeywordString returns the string corresponding to the given keyword
|
||||
func KeywordString(id int) string {
|
||||
str, ok := keywordStrings[id]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Lex returns the next token form the Tokenizer.
|
||||
// This function is used by go yacc.
|
||||
func (tkn *Tokenizer) Lex(lval *yySymType) int {
|
||||
typ, val := tkn.Scan()
|
||||
for typ == COMMENT {
|
||||
if tkn.AllowComments {
|
||||
break
|
||||
}
|
||||
typ, val = tkn.Scan()
|
||||
}
|
||||
lval.bytes = val
|
||||
tkn.lastToken = val
|
||||
return typ
|
||||
}
|
||||
|
||||
// Error is called by go yacc if there's a parsing error.
|
||||
func (tkn *Tokenizer) Error(err string) {
|
||||
buf := &bytes2.Buffer{}
|
||||
if tkn.lastToken != nil {
|
||||
fmt.Fprintf(buf, "%s at position %v near '%s'", err, tkn.Position, tkn.lastToken)
|
||||
} else {
|
||||
fmt.Fprintf(buf, "%s at position %v", err, tkn.Position)
|
||||
}
|
||||
tkn.LastError = errors.New(buf.String())
|
||||
|
||||
// Try and re-sync to the next statement
|
||||
if tkn.lastChar != ';' {
|
||||
tkn.skipStatement()
|
||||
}
|
||||
}
|
||||
|
||||
// Scan scans the tokenizer for the next token and returns
|
||||
// the token type and an optional value.
|
||||
func (tkn *Tokenizer) Scan() (int, []byte) {
|
||||
if tkn.specialComment != nil {
|
||||
// Enter specialComment scan mode.
|
||||
// for scanning such kind of comment: /*! MySQL-specific code */
|
||||
specialComment := tkn.specialComment
|
||||
tok, val := specialComment.Scan()
|
||||
if tok != 0 {
|
||||
// return the specialComment scan result as the result
|
||||
return tok, val
|
||||
}
|
||||
// leave specialComment scan mode after all stream consumed.
|
||||
tkn.specialComment = nil
|
||||
}
|
||||
if tkn.lastChar == 0 {
|
||||
tkn.next()
|
||||
}
|
||||
|
||||
if tkn.ForceEOF {
|
||||
tkn.skipStatement()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tkn.skipBlank()
|
||||
switch ch := tkn.lastChar; {
|
||||
case isLetter(ch):
|
||||
tkn.next()
|
||||
if ch == 'X' || ch == 'x' {
|
||||
if tkn.lastChar == '\'' {
|
||||
tkn.next()
|
||||
return tkn.scanHex()
|
||||
}
|
||||
}
|
||||
if ch == 'B' || ch == 'b' {
|
||||
if tkn.lastChar == '\'' {
|
||||
tkn.next()
|
||||
return tkn.scanBitLiteral()
|
||||
}
|
||||
}
|
||||
isDbSystemVariable := false
|
||||
if ch == '@' && tkn.lastChar == '@' {
|
||||
isDbSystemVariable = true
|
||||
}
|
||||
return tkn.scanIdentifier(byte(ch), isDbSystemVariable)
|
||||
case isDigit(ch):
|
||||
return tkn.scanNumber(false)
|
||||
case ch == ':':
|
||||
return tkn.scanBindVar()
|
||||
case ch == ';' && tkn.multi:
|
||||
return 0, nil
|
||||
default:
|
||||
tkn.next()
|
||||
switch ch {
|
||||
case eofChar:
|
||||
return 0, nil
|
||||
case '=', ',', ';', '(', ')', '+', '*', '%', '^', '~':
|
||||
return int(ch), nil
|
||||
case '&':
|
||||
if tkn.lastChar == '&' {
|
||||
tkn.next()
|
||||
return AND, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '|':
|
||||
if tkn.lastChar == '|' {
|
||||
tkn.next()
|
||||
return OR, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '?':
|
||||
tkn.posVarIndex++
|
||||
buf := new(bytes2.Buffer)
|
||||
fmt.Fprintf(buf, ":v%d", tkn.posVarIndex)
|
||||
return VALUE_ARG, buf.Bytes()
|
||||
case '.':
|
||||
if isDigit(tkn.lastChar) {
|
||||
return tkn.scanNumber(true)
|
||||
}
|
||||
return int(ch), nil
|
||||
case '/':
|
||||
switch tkn.lastChar {
|
||||
case '/':
|
||||
tkn.next()
|
||||
return tkn.scanCommentType1("//")
|
||||
case '*':
|
||||
tkn.next()
|
||||
switch tkn.lastChar {
|
||||
case '!':
|
||||
return tkn.scanMySQLSpecificComment()
|
||||
default:
|
||||
return tkn.scanCommentType2()
|
||||
}
|
||||
default:
|
||||
return int(ch), nil
|
||||
}
|
||||
case '#':
|
||||
return tkn.scanCommentType1("#")
|
||||
case '-':
|
||||
switch tkn.lastChar {
|
||||
case '-':
|
||||
tkn.next()
|
||||
return tkn.scanCommentType1("--")
|
||||
case '>':
|
||||
tkn.next()
|
||||
if tkn.lastChar == '>' {
|
||||
tkn.next()
|
||||
return JSON_UNQUOTE_EXTRACT_OP, nil
|
||||
}
|
||||
return JSON_EXTRACT_OP, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '<':
|
||||
switch tkn.lastChar {
|
||||
case '>':
|
||||
tkn.next()
|
||||
return NE, nil
|
||||
case '<':
|
||||
tkn.next()
|
||||
return SHIFT_LEFT, nil
|
||||
case '=':
|
||||
tkn.next()
|
||||
switch tkn.lastChar {
|
||||
case '>':
|
||||
tkn.next()
|
||||
return NULL_SAFE_EQUAL, nil
|
||||
default:
|
||||
return LE, nil
|
||||
}
|
||||
default:
|
||||
return int(ch), nil
|
||||
}
|
||||
case '>':
|
||||
switch tkn.lastChar {
|
||||
case '=':
|
||||
tkn.next()
|
||||
return GE, nil
|
||||
case '>':
|
||||
tkn.next()
|
||||
return SHIFT_RIGHT, nil
|
||||
default:
|
||||
return int(ch), nil
|
||||
}
|
||||
case '!':
|
||||
if tkn.lastChar == '=' {
|
||||
tkn.next()
|
||||
return NE, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '\'', '"':
|
||||
return tkn.scanString(ch, STRING)
|
||||
case '`':
|
||||
return tkn.scanLiteralIdentifier()
|
||||
default:
|
||||
return LEX_ERROR, []byte{byte(ch)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// skipStatement scans until the EOF, or end of statement is encountered.
|
||||
func (tkn *Tokenizer) skipStatement() {
|
||||
ch := tkn.lastChar
|
||||
for ch != ';' && ch != eofChar {
|
||||
tkn.next()
|
||||
ch = tkn.lastChar
|
||||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) skipBlank() {
|
||||
ch := tkn.lastChar
|
||||
for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' {
|
||||
tkn.next()
|
||||
ch = tkn.lastChar
|
||||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanIdentifier(firstByte byte, isDbSystemVariable bool) (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteByte(firstByte)
|
||||
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || (isDbSystemVariable && isCarat(tkn.lastChar)) {
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
lowered := bytes.ToLower(buffer.Bytes())
|
||||
loweredStr := string(lowered)
|
||||
if keywordID, found := keywords[loweredStr]; found {
|
||||
return keywordID, lowered
|
||||
}
|
||||
// dual must always be case-insensitive
|
||||
if loweredStr == "dual" {
|
||||
return ID, lowered
|
||||
}
|
||||
return ID, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanHex() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
tkn.scanMantissa(16, buffer)
|
||||
if tkn.lastChar != '\'' {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.next()
|
||||
if buffer.Len()%2 != 0 {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
return HEX, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
tkn.scanMantissa(2, buffer)
|
||||
if tkn.lastChar != '\'' {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.next()
|
||||
return BIT_LITERAL, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
backTickSeen := false
|
||||
for {
|
||||
if backTickSeen {
|
||||
if tkn.lastChar != '`' {
|
||||
break
|
||||
}
|
||||
backTickSeen = false
|
||||
buffer.WriteByte('`')
|
||||
tkn.next()
|
||||
continue
|
||||
}
|
||||
// The previous char was not a backtick.
|
||||
switch tkn.lastChar {
|
||||
case '`':
|
||||
backTickSeen = true
|
||||
case eofChar:
|
||||
// Premature EOF.
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
default:
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
}
|
||||
tkn.next()
|
||||
}
|
||||
if buffer.Len() == 0 {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
return ID, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanBindVar() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
token := VALUE_ARG
|
||||
tkn.next()
|
||||
if tkn.lastChar == ':' {
|
||||
token = LIST_ARG
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
if !isLetter(tkn.lastChar) {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || tkn.lastChar == '.' {
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
return token, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes2.Buffer) {
|
||||
for digitVal(tkn.lastChar) < base {
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) {
|
||||
token := INTEGRAL
|
||||
buffer := &bytes2.Buffer{}
|
||||
if seenDecimalPoint {
|
||||
token = FLOAT
|
||||
buffer.WriteByte('.')
|
||||
tkn.scanMantissa(10, buffer)
|
||||
goto exponent
|
||||
}
|
||||
|
||||
// 0x construct.
|
||||
if tkn.lastChar == '0' {
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == 'x' || tkn.lastChar == 'X' {
|
||||
token = HEXNUM
|
||||
tkn.consumeNext(buffer)
|
||||
tkn.scanMantissa(16, buffer)
|
||||
goto exit
|
||||
}
|
||||
}
|
||||
|
||||
tkn.scanMantissa(10, buffer)
|
||||
|
||||
if tkn.lastChar == '.' {
|
||||
token = FLOAT
|
||||
tkn.consumeNext(buffer)
|
||||
tkn.scanMantissa(10, buffer)
|
||||
}
|
||||
|
||||
exponent:
|
||||
if tkn.lastChar == 'e' || tkn.lastChar == 'E' {
|
||||
token = FLOAT
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == '+' || tkn.lastChar == '-' {
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
tkn.scanMantissa(10, buffer)
|
||||
}
|
||||
|
||||
exit:
|
||||
// A letter cannot immediately follow a number.
|
||||
if isLetter(tkn.lastChar) {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
|
||||
return token, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
|
||||
var buffer bytes2.Buffer
|
||||
for {
|
||||
ch := tkn.lastChar
|
||||
if ch == eofChar {
|
||||
// Unterminated string.
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
|
||||
if ch != delim && ch != '\\' {
|
||||
buffer.WriteByte(byte(ch))
|
||||
|
||||
// Scan ahead to the next interesting character.
|
||||
start := tkn.bufPos
|
||||
for ; tkn.bufPos < tkn.bufSize; tkn.bufPos++ {
|
||||
ch = uint16(tkn.buf[tkn.bufPos])
|
||||
if ch == delim || ch == '\\' {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buffer.Write(tkn.buf[start:tkn.bufPos])
|
||||
tkn.Position += (tkn.bufPos - start)
|
||||
|
||||
if tkn.bufPos >= tkn.bufSize {
|
||||
// Reached the end of the buffer without finding a delim or
|
||||
// escape character.
|
||||
tkn.next()
|
||||
continue
|
||||
}
|
||||
|
||||
tkn.bufPos++
|
||||
tkn.Position++
|
||||
}
|
||||
tkn.next() // Read one past the delim or escape character.
|
||||
|
||||
if ch == '\\' {
|
||||
if tkn.lastChar == eofChar {
|
||||
// String terminates mid escape character.
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DontEscape {
|
||||
ch = tkn.lastChar
|
||||
} else {
|
||||
ch = uint16(decodedChar)
|
||||
}
|
||||
|
||||
} else if ch == delim && tkn.lastChar != delim {
|
||||
// Correctly terminated string, which is not a double delim.
|
||||
break
|
||||
}
|
||||
|
||||
buffer.WriteByte(byte(ch))
|
||||
tkn.next()
|
||||
}
|
||||
|
||||
return typ, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteString(prefix)
|
||||
for tkn.lastChar != eofChar {
|
||||
if tkn.lastChar == '\n' {
|
||||
tkn.consumeNext(buffer)
|
||||
break
|
||||
}
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
return COMMENT, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanCommentType2() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteString("/*")
|
||||
for {
|
||||
if tkn.lastChar == '*' {
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == '/' {
|
||||
tkn.consumeNext(buffer)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tkn.lastChar == eofChar {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
return COMMENT, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteString("/*!")
|
||||
tkn.next()
|
||||
for {
|
||||
if tkn.lastChar == '*' {
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == '/' {
|
||||
tkn.consumeNext(buffer)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tkn.lastChar == eofChar {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
_, sql := ExtractMysqlComment(buffer.String())
|
||||
tkn.specialComment = NewStringTokenizer(sql)
|
||||
return tkn.Scan()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) consumeNext(buffer *bytes2.Buffer) {
|
||||
if tkn.lastChar == eofChar {
|
||||
// This should never happen.
|
||||
panic("unexpected EOF")
|
||||
}
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) next() {
|
||||
if tkn.bufPos >= tkn.bufSize && tkn.InStream != nil {
|
||||
// Try and refill the buffer
|
||||
var err error
|
||||
tkn.bufPos = 0
|
||||
if tkn.bufSize, err = tkn.InStream.Read(tkn.buf); err != io.EOF && err != nil {
|
||||
tkn.LastError = err
|
||||
}
|
||||
}
|
||||
|
||||
if tkn.bufPos >= tkn.bufSize {
|
||||
if tkn.lastChar != eofChar {
|
||||
tkn.Position++
|
||||
tkn.lastChar = eofChar
|
||||
}
|
||||
} else {
|
||||
tkn.Position++
|
||||
tkn.lastChar = uint16(tkn.buf[tkn.bufPos])
|
||||
tkn.bufPos++
|
||||
}
|
||||
}
|
||||
|
||||
// reset clears any internal state.
|
||||
func (tkn *Tokenizer) reset() {
|
||||
tkn.ParseTree = nil
|
||||
tkn.partialDDL = nil
|
||||
tkn.specialComment = nil
|
||||
tkn.posVarIndex = 0
|
||||
tkn.nesting = 0
|
||||
tkn.ForceEOF = false
|
||||
}
|
||||
|
||||
func isLetter(ch uint16) bool {
|
||||
return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '@'
|
||||
}
|
||||
|
||||
func isCarat(ch uint16) bool {
|
||||
return ch == '.' || ch == '\'' || ch == '"' || ch == '`'
|
||||
}
|
||||
|
||||
func digitVal(ch uint16) int {
|
||||
switch {
|
||||
case '0' <= ch && ch <= '9':
|
||||
return int(ch) - '0'
|
||||
case 'a' <= ch && ch <= 'f':
|
||||
return int(ch) - 'a' + 10
|
||||
case 'A' <= ch && ch <= 'F':
|
||||
return int(ch) - 'A' + 10
|
||||
}
|
||||
return 16 // larger than any legal digit val
|
||||
}
|
||||
|
||||
func isDigit(ch uint16) bool {
|
||||
return '0' <= ch && ch <= '9'
|
||||
}
|
|
@ -1,140 +0,0 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NodeFormatter defines the signature of a custom node formatter
|
||||
// function that can be given to TrackedBuffer for code generation.
|
||||
type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
|
||||
|
||||
// TrackedBuffer is used to rebuild a query from the ast.
|
||||
// bindLocations keeps track of locations in the buffer that
|
||||
// use bind variables for efficient future substitutions.
|
||||
// nodeFormatter is the formatting function the buffer will
|
||||
// use to format a node. By default(nil), it's FormatNode.
|
||||
// But you can supply a different formatting function if you
|
||||
// want to generate a query that's different from the default.
|
||||
type TrackedBuffer struct {
|
||||
*bytes.Buffer
|
||||
bindLocations []bindLocation
|
||||
nodeFormatter NodeFormatter
|
||||
}
|
||||
|
||||
// NewTrackedBuffer creates a new TrackedBuffer.
|
||||
func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
|
||||
return &TrackedBuffer{
|
||||
Buffer: new(bytes.Buffer),
|
||||
nodeFormatter: nodeFormatter,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteNode function, initiates the writing of a single SQLNode tree by passing
|
||||
// through to Myprintf with a default format string
|
||||
func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
|
||||
buf.Myprintf("%v", node)
|
||||
return buf
|
||||
}
|
||||
|
||||
// Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
|
||||
// Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
|
||||
// which case it adds tracking info for future substitutions.
|
||||
//
|
||||
// The name must be something other than the usual Printf() to avoid "go vet"
|
||||
// warnings due to our custom format specifiers.
|
||||
func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
|
||||
end := len(format)
|
||||
fieldnum := 0
|
||||
for i := 0; i < end; {
|
||||
lasti := i
|
||||
for i < end && format[i] != '%' {
|
||||
i++
|
||||
}
|
||||
if i > lasti {
|
||||
buf.WriteString(format[lasti:i])
|
||||
}
|
||||
if i >= end {
|
||||
break
|
||||
}
|
||||
i++ // '%'
|
||||
switch format[i] {
|
||||
case 'c':
|
||||
switch v := values[fieldnum].(type) {
|
||||
case byte:
|
||||
buf.WriteByte(v)
|
||||
case rune:
|
||||
buf.WriteRune(v)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
|
||||
}
|
||||
case 's':
|
||||
switch v := values[fieldnum].(type) {
|
||||
case []byte:
|
||||
buf.Write(v)
|
||||
case string:
|
||||
buf.WriteString(v)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
|
||||
}
|
||||
case 'v':
|
||||
node := values[fieldnum].(SQLNode)
|
||||
if buf.nodeFormatter == nil {
|
||||
node.Format(buf)
|
||||
} else {
|
||||
buf.nodeFormatter(buf, node)
|
||||
}
|
||||
case 'a':
|
||||
buf.WriteArg(values[fieldnum].(string))
|
||||
default:
|
||||
panic("unexpected")
|
||||
}
|
||||
fieldnum++
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// WriteArg writes a value argument into the buffer along with
|
||||
// tracking information for future substitutions. arg must contain
|
||||
// the ":" or "::" prefix.
|
||||
func (buf *TrackedBuffer) WriteArg(arg string) {
|
||||
buf.bindLocations = append(buf.bindLocations, bindLocation{
|
||||
offset: buf.Len(),
|
||||
length: len(arg),
|
||||
})
|
||||
buf.WriteString(arg)
|
||||
}
|
||||
|
||||
// ParsedQuery returns a ParsedQuery that contains bind
|
||||
// locations for easy substitution.
|
||||
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
|
||||
return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
|
||||
}
|
||||
|
||||
// HasBindVars returns true if the parsed query uses bind vars.
|
||||
func (buf *TrackedBuffer) HasBindVars() bool {
|
||||
return len(buf.bindLocations) != 0
|
||||
}
|
||||
|
||||
// BuildParsedQuery builds a ParsedQuery from the input.
|
||||
func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
|
||||
buf := NewTrackedBuffer(nil)
|
||||
buf.Myprintf(in, vars...)
|
||||
return buf.ParsedQuery()
|
||||
}
|
|
@ -97,6 +97,18 @@
|
|||
"revision": "f6be1abbb5abd0517522f850dd785990d373da7e",
|
||||
"revisionTime": "2017-09-13T22:19:17Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "Xmp7mYQyG/1fyIahOyTyN9yZamY=",
|
||||
"path": "github.com/alecthomas/participle",
|
||||
"revision": "bf8340a459bd383e5eb7d44a9a1b3af23b6cf8cd",
|
||||
"revisionTime": "2019-01-03T08:53:15Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "0R8Lqt4DtU8+7Eq1mL7Hd+cjDOI=",
|
||||
"path": "github.com/alecthomas/participle/lexer",
|
||||
"revision": "bf8340a459bd383e5eb7d44a9a1b3af23b6cf8cd",
|
||||
"revisionTime": "2019-01-03T08:53:15Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "tX0Bq1gzqskL98nnB1X2rDqxH18=",
|
||||
"path": "github.com/aliyun/aliyun-oss-go-sdk/oss",
|
||||
|
@ -644,10 +656,10 @@
|
|||
"revisionTime": "2019-01-20T10:05:29Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "pxgHNx36gpRdhSqtaE5fqp7lrAA=",
|
||||
"checksumSHA1": "ik77jlf0oMQTlSndP85DlIVOnOY=",
|
||||
"path": "github.com/minio/parquet-go",
|
||||
"revision": "1014bfb4d0e323e3fbf6683e3519a98b0721f5cc",
|
||||
"revisionTime": "2019-01-14T09:43:57Z"
|
||||
"revision": "7a17a919eeed02c393f3117a9ed1ac6df0da9aa5",
|
||||
"revisionTime": "2019-01-18T04:40:39Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "N4WRPw4p3AN958RH/O53kUsJacQ=",
|
||||
|
@ -888,12 +900,6 @@
|
|||
"revision": "ceec8f93295a060cdb565ec25e4ccf17941dbd55",
|
||||
"revisionTime": "2016-11-14T21:01:44Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "6ksZHYhLc3yOzTbcWKb3bDENhD4=",
|
||||
"path": "github.com/xwb1989/sqlparser",
|
||||
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
|
||||
"revisionTime": "2018-06-06T15:21:19Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "L/Q8Ylbo+wnj5whDFfMxxwyxmdo=",
|
||||
"path": "github.com/xwb1989/sqlparser/dependency/bytes2",
|
||||
|
|
Loading…
Reference in New Issue